Repository: drakkan/sftpgo Branch: main Commit: dda97dd75835 Files: 405 Total size: 12.6 MB Directory structure: gitextract_3usfk2f2/ ├── .cirrus.yml ├── .github/ │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.yml │ │ ├── config.yml │ │ └── feature_request.yml │ ├── PULL_REQUEST_TEMPLATE.md │ ├── dependabot.yml │ └── workflows/ │ ├── .editorconfig │ ├── codeql.yml │ ├── development.yml │ ├── docker.yml │ └── release.yml ├── .gitignore ├── .golangci.yml ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── Dockerfile.alpine ├── Dockerfile.distroless ├── LICENSE ├── NOTICE ├── README.md ├── SECURITY.md ├── crowdin.yml ├── docker/ │ └── scripts/ │ └── download-plugins.sh ├── examples/ │ ├── OTP/ │ │ └── authy/ │ │ ├── README.md │ │ ├── checkpwd/ │ │ │ ├── README.md │ │ │ ├── go.mod │ │ │ └── main.go │ │ ├── extauth/ │ │ │ ├── README.md │ │ │ ├── go.mod │ │ │ └── main.go │ │ └── keyint/ │ │ ├── README.md │ │ ├── go.mod │ │ └── main.go │ ├── backup/ │ │ ├── README.md │ │ └── backup │ ├── bulkupdate/ │ │ ├── README.md │ │ └── bulkuserupdate │ ├── convertusers/ │ │ ├── README.md │ │ └── convertusers │ ├── ldapauth/ │ │ ├── README.md │ │ ├── go.mod │ │ ├── go.sum │ │ └── main.go │ ├── ldapauthserver/ │ │ ├── README.md │ │ ├── cmd/ │ │ │ ├── root.go │ │ │ └── serve.go │ │ ├── config/ │ │ │ └── config.go │ │ ├── go.mod │ │ ├── go.sum │ │ ├── httpd/ │ │ │ ├── auth.go │ │ │ ├── httpd.go │ │ │ ├── ldapauth.go │ │ │ ├── models.go │ │ │ └── tlsutils.go │ │ ├── ldapauth.toml │ │ ├── logger/ │ │ │ ├── logger.go │ │ │ ├── request_logger.go │ │ │ └── sync_wrapper.go │ │ ├── main.go │ │ └── utils/ │ │ ├── utils.go │ │ └── version.go │ ├── php-activedirectory-http-server/ │ │ └── README.md │ └── quotascan/ │ ├── README.md │ └── scanuserquota ├── go.mod ├── go.sum ├── init/ │ ├── com.github.drakkan.sftpgo.plist │ └── sftpgo.service ├── internal/ │ ├── acme/ │ │ ├── account.go │ │ └── acme.go │ ├── bundle/ │ │ └── bundle.go │ ├── cmd/ │ │ ├── acme.go │ │ ├── gen.go │ │ ├── 
gencompletion.go │ │ ├── genman.go │ │ ├── initprovider.go │ │ ├── install_windows.go │ │ ├── ping.go │ │ ├── portable.go │ │ ├── portable_disabled.go │ │ ├── reload_windows.go │ │ ├── resetprovider.go │ │ ├── resetpwd.go │ │ ├── revertprovider.go │ │ ├── root.go │ │ ├── rotatelogs_windows.go │ │ ├── serve.go │ │ ├── service_windows.go │ │ ├── signals_unix.go │ │ ├── signals_windows.go │ │ ├── smtptest.go │ │ ├── start_windows.go │ │ ├── status_windows.go │ │ ├── stop_windows.go │ │ └── uninstall_windows.go │ ├── command/ │ │ ├── command.go │ │ └── command_test.go │ ├── common/ │ │ ├── actions.go │ │ ├── actions_test.go │ │ ├── clientsmap.go │ │ ├── clientsmap_test.go │ │ ├── common.go │ │ ├── common_test.go │ │ ├── connection.go │ │ ├── connection_test.go │ │ ├── dataretention.go │ │ ├── dataretention_test.go │ │ ├── defender.go │ │ ├── defender_test.go │ │ ├── defenderdb.go │ │ ├── defenderdb_test.go │ │ ├── defendermem.go │ │ ├── eventmanager.go │ │ ├── eventmanager_test.go │ │ ├── eventscheduler.go │ │ ├── httpauth.go │ │ ├── httpauth_test.go │ │ ├── protocol_test.go │ │ ├── ratelimiter.go │ │ ├── ratelimiter_test.go │ │ ├── tlsutils.go │ │ ├── tlsutils_test.go │ │ ├── transfer.go │ │ ├── transfer_test.go │ │ ├── transferschecker.go │ │ └── transferschecker_test.go │ ├── config/ │ │ ├── config.go │ │ ├── config_darwin.go │ │ ├── config_fallback.go │ │ ├── config_linux.go │ │ └── config_test.go │ ├── dataprovider/ │ │ ├── actions.go │ │ ├── admin.go │ │ ├── apikey.go │ │ ├── bolt.go │ │ ├── bolt_disabled.go │ │ ├── cachedpassword.go │ │ ├── cacheduser.go │ │ ├── configs.go │ │ ├── dataprovider.go │ │ ├── eventrule.go │ │ ├── group.go │ │ ├── iplist.go │ │ ├── memory.go │ │ ├── mysql.go │ │ ├── mysql_disabled.go │ │ ├── node.go │ │ ├── pgsql.go │ │ ├── pgsql_disabled.go │ │ ├── quotaupdater.go │ │ ├── role.go │ │ ├── scheduler.go │ │ ├── session.go │ │ ├── share.go │ │ ├── sqlcommon.go │ │ ├── sqlite.go │ │ ├── sqlite_disabled.go │ │ ├── sqlqueries.go │ │ ├── 
unixcrypt.go │ │ ├── unixcrypt_disabled.go │ │ └── user.go │ ├── ftpd/ │ │ ├── cryptfs_test.go │ │ ├── ftpd.go │ │ ├── ftpd_test.go │ │ ├── handler.go │ │ ├── internal_test.go │ │ ├── server.go │ │ └── transfer.go │ ├── httpclient/ │ │ └── httpclient.go │ ├── httpd/ │ │ ├── api_admin.go │ │ ├── api_configs.go │ │ ├── api_defender.go │ │ ├── api_eventrule.go │ │ ├── api_events.go │ │ ├── api_folder.go │ │ ├── api_group.go │ │ ├── api_http_user.go │ │ ├── api_iplist.go │ │ ├── api_keys.go │ │ ├── api_maintenance.go │ │ ├── api_mfa.go │ │ ├── api_quota.go │ │ ├── api_retention.go │ │ ├── api_role.go │ │ ├── api_shares.go │ │ ├── api_user.go │ │ ├── api_utils.go │ │ ├── auth_utils.go │ │ ├── file.go │ │ ├── flash.go │ │ ├── flash_test.go │ │ ├── handler.go │ │ ├── httpd.go │ │ ├── httpd_test.go │ │ ├── internal_test.go │ │ ├── middleware.go │ │ ├── oauth2.go │ │ ├── oauth2_test.go │ │ ├── oidc.go │ │ ├── oidc_test.go │ │ ├── oidcmanager.go │ │ ├── resetcode.go │ │ ├── resources.go │ │ ├── resources_embedded.go │ │ ├── server.go │ │ ├── token.go │ │ ├── web.go │ │ ├── webadmin.go │ │ ├── webclient.go │ │ ├── webtask.go │ │ └── webtask_test.go │ ├── httpdtest/ │ │ ├── httpdtest.go │ │ └── httpfsimpl.go │ ├── jwt/ │ │ ├── jwt.go │ │ └── jwt_test.go │ ├── kms/ │ │ ├── basesecret.go │ │ ├── builtin.go │ │ ├── kms.go │ │ └── local.go │ ├── logger/ │ │ ├── hclog.go │ │ ├── lego.go │ │ ├── logger.go │ │ ├── mail.go │ │ ├── request_logger.go │ │ ├── slog.go │ │ └── sync_wrapper.go │ ├── metric/ │ │ ├── metric.go │ │ └── metric_disabled.go │ ├── mfa/ │ │ ├── mfa.go │ │ ├── mfa_test.go │ │ └── totp.go │ ├── plugin/ │ │ ├── auth.go │ │ ├── ipfilter.go │ │ ├── kms.go │ │ ├── notifier.go │ │ ├── plugin.go │ │ ├── searcher.go │ │ └── util.go │ ├── service/ │ │ ├── service.go │ │ ├── service_portable.go │ │ ├── service_windows.go │ │ ├── signals_unix.go │ │ └── signals_windows.go │ ├── sftpd/ │ │ ├── cryptfs_test.go │ │ ├── handler.go │ │ ├── httpfs_test.go │ │ ├── internal_test.go │ 
│ ├── lister.go │ │ ├── scp.go │ │ ├── server.go │ │ ├── sftpd.go │ │ ├── sftpd_test.go │ │ ├── ssh_cmd.go │ │ └── transfer.go │ ├── smtp/ │ │ ├── oauth2.go │ │ └── smtp.go │ ├── telemetry/ │ │ ├── router.go │ │ ├── telemetry.go │ │ └── telemetry_test.go │ ├── util/ │ │ ├── errors.go │ │ ├── i18n.go │ │ ├── resources.go │ │ ├── resources_embedded.go │ │ ├── util.go │ │ ├── util_fallback.go │ │ └── util_unix.go │ ├── version/ │ │ └── version.go │ ├── vfs/ │ │ ├── azblobfs.go │ │ ├── azblobfs_disabled.go │ │ ├── cryptfs.go │ │ ├── fileinfo.go │ │ ├── filesystem.go │ │ ├── folder.go │ │ ├── gcsfs.go │ │ ├── gcsfs_disabled.go │ │ ├── httpfs.go │ │ ├── osfs.go │ │ ├── s3fs.go │ │ ├── s3fs_disabled.go │ │ ├── sftpfs.go │ │ ├── statvfs_fallback.go │ │ ├── statvfs_linux.go │ │ ├── statvfs_unix.go │ │ ├── sys_unix.go │ │ ├── sys_windows.go │ │ └── vfs.go │ └── webdavd/ │ ├── file.go │ ├── handler.go │ ├── internal_test.go │ ├── mimecache.go │ ├── server.go │ ├── webdavd.go │ └── webdavd_test.go ├── main.go ├── openapi/ │ ├── httpfs.yaml │ ├── openapi.yaml │ └── swagger-ui/ │ ├── index.css │ ├── index.html │ ├── swagger-initializer.js │ ├── swagger-ui-bundle.js │ ├── swagger-ui-standalone-preset.js │ └── swagger-ui.css ├── pkgs/ │ ├── build.sh │ ├── choco/ │ │ ├── sftpgo.nuspec │ │ └── tools/ │ │ └── ChocolateyInstall.ps1 │ ├── debian/ │ │ ├── changelog │ │ ├── compat │ │ ├── control │ │ ├── copyright │ │ ├── patches/ │ │ │ ├── config.diff │ │ │ └── series │ │ ├── postinst │ │ ├── rules │ │ ├── sftpgo-docs.docs │ │ ├── sftpgo.dirs │ │ ├── sftpgo.install │ │ ├── sftpgo.install.arm64 │ │ ├── sftpgo.install.armhf │ │ ├── sftpgo.install.ppc64el │ │ └── source/ │ │ └── format │ └── scripts/ │ ├── deb/ │ │ ├── postinstall.sh │ │ ├── postremove.sh │ │ └── preremove.sh │ └── rpm/ │ ├── postinstall │ ├── postremove │ └── preremove ├── sftpgo.json ├── static/ │ ├── assets/ │ │ ├── css/ │ │ │ └── style.bundle.css │ │ ├── js/ │ │ │ └── scripts.bundle.js │ │ └── plugins/ │ │ ├── custom/ 
│ │ │ ├── datatables/ │ │ │ │ ├── datatables.bundle.css │ │ │ │ └── datatables.bundle.js │ │ │ ├── flatpickr/ │ │ │ │ └── l10n/ │ │ │ │ ├── de.js │ │ │ │ ├── es.js │ │ │ │ ├── fr.js │ │ │ │ ├── it.js │ │ │ │ └── zh.js │ │ │ └── formrepeater/ │ │ │ └── formrepeater.bundle.js │ │ └── global/ │ │ ├── plugins.bundle.css │ │ └── plugins.bundle.js │ └── locales/ │ ├── de/ │ │ └── translation.json │ ├── en/ │ │ └── translation.json │ ├── es/ │ │ └── translation.json │ ├── fr/ │ │ └── translation.json │ ├── it/ │ │ └── translation.json │ └── zh-CN/ │ └── translation.json ├── templates/ │ ├── common/ │ │ ├── base.html │ │ ├── baselogin.html │ │ ├── changepassword.html │ │ ├── forgot-password.html │ │ ├── login.html │ │ ├── message.html │ │ ├── reset-password.html │ │ ├── twofactor-recovery.html │ │ └── twofactor.html │ ├── email/ │ │ ├── password-expiration.html │ │ └── reset-password.html │ ├── webadmin/ │ │ ├── admin.html │ │ ├── admins.html │ │ ├── adminsetup.html │ │ ├── base.html │ │ ├── configs.html │ │ ├── connections.html │ │ ├── defender.html │ │ ├── eventaction.html │ │ ├── eventactions.html │ │ ├── eventrule.html │ │ ├── eventrules.html │ │ ├── events.html │ │ ├── folder.html │ │ ├── folders.html │ │ ├── fsconfig.html │ │ ├── group.html │ │ ├── groups.html │ │ ├── iplist.html │ │ ├── iplists.html │ │ ├── maintenance.html │ │ ├── mfa.html │ │ ├── profile.html │ │ ├── role.html │ │ ├── roles.html │ │ ├── status.html │ │ ├── user.html │ │ └── users.html │ └── webclient/ │ ├── base.html │ ├── editfile.html │ ├── files.html │ ├── mfa.html │ ├── profile.html │ ├── share.html │ ├── sharedownload.html │ ├── sharelogin.html │ ├── shares.html │ ├── shareupload.html │ └── viewpdf.html ├── tests/ │ ├── eventsearcher/ │ │ ├── go.mod │ │ ├── go.sum │ │ └── main.go │ └── ipfilter/ │ ├── go.mod │ ├── go.sum │ └── main.go └── windows-installer/ ├── LICENSE_with_NOTICE.txt ├── README.txt └── sftpgo.iss ================================================ FILE CONTENTS 
================================================ ================================================ FILE: .cirrus.yml ================================================ freebsd_task: name: FreeBSD matrix: - name: FreeBSD 14.3 freebsd_instance: image_family: freebsd-14-3 pkginstall_script: - pkg update -f - pkg install -y go125 - pkg install -y git setup_script: - ln -s /usr/local/bin/go125 /usr/local/bin/go - pw groupadd sftpgo - pw useradd sftpgo -g sftpgo -w none -m - mkdir /home/sftpgo/sftpgo - cp -R . /home/sftpgo/sftpgo - chown -R sftpgo:sftpgo /home/sftpgo/sftpgo compile_script: - su sftpgo -c 'cd ~/sftpgo && go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' - su sftpgo -c 'cd ~/sftpgo/tests/eventsearcher && go build -trimpath -ldflags "-s -w" -o eventsearcher' - su sftpgo -c 'cd ~/sftpgo/tests/ipfilter && go build -trimpath -ldflags "-s -w" -o ipfilter' check_script: - su sftpgo -c 'cd ~/sftpgo && ./sftpgo initprovider && ./sftpgo resetprovider --force' test_script: - su sftpgo -c 'cd ~/sftpgo && go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 20m ./... 
-coverprofile=coverage.txt -covermode=atomic' ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: [drakkan] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.yml ================================================ name: Open Source Bug Report description: "Submit a report and help us improve SFTPGo" title: "[Bug]: " labels: ["bug"] body: - type: markdown attributes: value: | ### 👍 Thank you for contributing to our project! Before asking for help please check our [support policy](https://github.com/drakkan/sftpgo?tab=readme-ov-file#support). If you are a [commercial user](https://sftpgo.com/) please contact us using the dedicated [email address](mailto:support@sftpgo.com). If you'd like to contribute code, please make sure to read and understand our [Contributor License Agreement (CLA)](https://sftpgo.com/cla.html). You’ll be asked to accept it when submitting a pull request. - type: checkboxes id: before-posting attributes: label: "⚠️ This issue respects the following points: ⚠️" description: All conditions are **required**. options: - label: This is a **bug**, not a question or a configuration issue. 
required: true - label: This issue is **not** already reported on Github _(I've searched it)_. required: true - type: textarea id: bug-description attributes: label: Bug description description: | Provide a description of the bug you're experiencing. Don't just expect someone will guess what your specific problem is and provide full details. validations: required: true - type: textarea id: reproduce attributes: label: Steps to reproduce description: | Describe the steps to reproduce the bug. The better your description is the fastest you'll get an _(accurate)_ answer. value: | 1. 2. 3. validations: required: true - type: textarea id: expected-behavior attributes: label: Expected behavior description: Describe what you expected to happen instead. validations: required: true - type: input id: version attributes: label: SFTPGo version validations: required: true - type: input id: data-provider attributes: label: Data provider validations: required: true - type: dropdown id: install-method attributes: label: Installation method description: | Select installation method you've used. _Describe the method in the "Additional info" section if you chose "Other"._ options: - "Community Docker image" - "Community Deb package" - "Community RPM package" - "Other" validations: required: true - type: textarea attributes: label: Configuration description: "Describe your customizations to the configuration: both config file changes and overrides via environment variables" value: config validations: required: true - type: textarea id: logs attributes: label: Relevant log output description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. render: shell - type: dropdown id: usecase attributes: label: What are you using SFTPGo for? 
description: We'd like to understand your SFTPGo usecase more multiple: true options: - "Private user, home usecase (home backup/VPS)" - "Professional user, 1 person business" - "Small business (3-person firm with file exchange?)" - "Medium business" - "Enterprise" validations: required: true - type: textarea id: additional-info attributes: label: Additional info description: Any additional information related to the issue. ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: Commercial Support url: https://sftpgo.com/ about: > If you need Professional support, so your reports are prioritized and resolved more quickly. - name: GitHub Community Discussions url: https://github.com/drakkan/sftpgo/discussions about: Please ask and answer questions here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ name: 🚀 Feature request description: Suggest an idea for SFTPGo labels: ["suggestion"] body: - type: markdown attributes: value: | ### 👍 Thank you for contributing to our project! Before asking for help please check our [support policy](https://github.com/drakkan/sftpgo?tab=readme-ov-file#support). If you are a [commercial user](https://sftpgo.com/) please contact us using the dedicated [email address](mailto:support@sftpgo.com). If you'd like to contribute code, please make sure to read and understand our [Contributor License Agreement (CLA)](https://sftpgo.com/cla.html). You’ll be asked to accept it when submitting a pull request. - type: textarea attributes: label: Is your feature request related to a problem? Please describe. description: A clear and concise description of what the problem is. 
validations: required: false - type: textarea attributes: label: Describe the solution you'd like description: A clear and concise description of what you want to happen. validations: required: true - type: textarea attributes: label: Describe alternatives you've considered description: A clear and concise description of any alternative solutions or features you've considered. validations: required: false - type: dropdown id: usecase attributes: label: What are you using SFTPGo for? description: We'd like to understand your SFTPGo usecase more multiple: true options: - "Private user, home usecase (home backup/VPS)" - "Professional user, 1 person business" - "Small business (3-person firm with file exchange?)" - "Medium business" - "Enterprise" validations: required: true - type: textarea attributes: label: Additional context description: Add any other context or screenshots about the feature request here. validations: required: false ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ # Checklist for Pull Requests - [ ] Have you signed the [Contributor License Agreement](https://sftpgo.com/cla.html)? 
--- ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: #- package-ecosystem: "gomod" # directory: "/" # schedule: # interval: "weekly" # open-pull-requests-limit: 2 - package-ecosystem: "docker" directory: "/" schedule: interval: "weekly" open-pull-requests-limit: 2 - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" open-pull-requests-limit: 2 ================================================ FILE: .github/workflows/.editorconfig ================================================ [*.yml] indent_size = 2 ================================================ FILE: .github/workflows/codeql.yml ================================================ name: "Code scanning - action" on: push: pull_request: schedule: - cron: '30 1 * * 6' jobs: CodeQL-Build: runs-on: ubuntu-latest permissions: security-events: write steps: - name: Checkout repository uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v6 with: go-version: '1.25' - name: Initialize CodeQL uses: github/codeql-action/init@v4 with: languages: go - name: Autobuild uses: github/codeql-action/autobuild@v4 - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v4 ================================================ FILE: .github/workflows/development.yml ================================================ name: CI on: push: branches: [main] pull_request: permissions: id-token: write contents: read jobs: test-deploy: name: Test and deploy runs-on: ${{ matrix.os }} strategy: matrix: go: ['1.26'] os: [ubuntu-latest, macos-latest] steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - name: Build for Linux/macOS x86_64 run: | go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe 
--always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo cd tests/eventsearcher go build -trimpath -ldflags "-s -w" -o eventsearcher cd - cd tests/ipfilter go build -trimpath -ldflags "-s -w" -o ipfilter cd - ./sftpgo initprovider ./sftpgo resetprovider --force - name: Build for macOS arm64 if: startsWith(matrix.os, 'macos-') == true run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64 - name: Run test cases using SQLite provider run: go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -coverprofile=coverage.txt -covermode=atomic - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: files: ./coverage.txt fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} - name: Run test cases using bolt provider run: | go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/config -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/common -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/httpd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 8m ./internal/sftpd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/ftpd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/webdavd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/telemetry -covermode=atomic go test -v -tags 
nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/mfa -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/command -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: bolt SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db' - name: Run test cases using memory provider run: go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: memory SFTPGO_DATA_PROVIDER__NAME: '' - name: Prepare build artifact for macOS if: startsWith(matrix.os, 'macos-') == true run: | mkdir -p output/{init,bash_completion,zsh_completion} cp sftpgo output/sftpgo_x86_64 cp sftpgo_arm64 output/ cp sftpgo.json output/ cp -r templates output/ cp -r static output/ cp -r openapi output/ cp init/com.github.drakkan.sftpgo.plist output/init/ ./sftpgo gen completion bash > output/bash_completion/sftpgo ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo ./sftpgo gen man -d output/man/man1 gzip output/man/man1/* - name: Upload build artifact if: startsWith(matrix.os, 'ubuntu-') != true uses: actions/upload-artifact@v7 with: name: sftpgo-${{ matrix.os }}-go-${{ matrix.go }} path: output test-deploy-windows: name: Test and deploy Windows runs-on: windows-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v6 with: go-version: '1.26' - name: Run test cases using SQLite provider run: | cd tests/eventsearcher go build -trimpath -ldflags "-s -w" -o eventsearcher.exe cd ../.. cd tests/ipfilter go build -trimpath -ldflags "-s -w" -o ipfilter.exe cd ../.. go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-coverprofile=coverage.txt -covermode=atomic - name: Run test cases using bolt provider run: | go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/config -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/common -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/httpd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 8m ./internal/sftpd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/ftpd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/webdavd -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/telemetry -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/mfa -covermode=atomic go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/command -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: bolt SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db' - name: Run test cases using memory provider run: go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: memory SFTPGO_DATA_PROVIDER__NAME: '' - name: Build run: | $GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String $DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String $LATEST_TAG = ((git describe --tags $(git rev-list --tags --max-count=1)) | Out-String).Trim() $REV_LIST=$LATEST_TAG+"..HEAD" $COMMITS_FROM_TAG= ((git rev-list $REV_LIST --count) | Out-String).Trim() $FILE_VERSION = $LATEST_TAG.substring(1) + "." 
+ $COMMITS_FROM_TAG go install github.com/tc-hib/go-winres@latest go-winres simply --arch amd64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe mkdir arm64 $Env:CGO_ENABLED='0' $Env:GOOS='windows' $Env:GOARCH='arm64' go-winres simply --arch arm64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe mkdir x86 $Env:GOARCH='386' go-winres simply --arch 386 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe Remove-Item Env:\CGO_ENABLED Remove-Item Env:\GOOS Remove-Item Env:\GOARCH - name: Initialize data provider run: | rm sftpgo.db ./sftpgo initprovider shell: bash - name: Prepare Windows installers if: ${{ github.event_name != 'pull_request' }} run: | choco install innosetup Remove-Item 
-LiteralPath "output" -Force -Recurse -ErrorAction Ignore mkdir output copy .\sftpgo.exe .\output copy .\sftpgo.json .\output copy .\sftpgo.db .\output copy .\LICENSE .\output\LICENSE.txt copy .\NOTICE .\output\NOTICE.txt mkdir output\templates xcopy .\templates .\output\templates\ /E mkdir output\static xcopy .\static .\output\static\ /E mkdir output\openapi xcopy .\openapi .\output\openapi\ /E $LATEST_TAG = ((git describe --tags $(git rev-list --tags --max-count=1)) | Out-String).Trim() $REV_LIST=$LATEST_TAG+"..HEAD" $COMMITS_FROM_TAG= ((git rev-list $REV_LIST --count) | Out-String).Trim() $Env:SFTPGO_ISS_DEV_VERSION = $LATEST_TAG + "." + $COMMITS_FROM_TAG iscc .\windows-installer\sftpgo.iss rm .\output\sftpgo.exe rm .\output\sftpgo.db copy .\arm64\sftpgo.exe .\output (Get-Content .\output\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\output\sftpgo.json $Env:SFTPGO_DATA_PROVIDER__DRIVER='bolt' $Env:SFTPGO_DATA_PROVIDER__NAME='.\output\sftpgo.db' .\sftpgo.exe initprovider Remove-Item Env:\SFTPGO_DATA_PROVIDER__DRIVER Remove-Item Env:\SFTPGO_DATA_PROVIDER__NAME $Env:SFTPGO_ISS_ARCH='arm64' iscc .\windows-installer\sftpgo.iss rm .\output\sftpgo.exe copy .\x86\sftpgo.exe .\output $Env:SFTPGO_ISS_ARCH='x86' iscc .\windows-installer\sftpgo.iss - name: Upload Windows installer x86_64 artifact if: ${{ github.event_name != 'pull_request' }} uses: actions/upload-artifact@v7 with: name: sftpgo_windows_installer_x86_64 path: ./sftpgo_windows_x86_64.exe - name: Upload Windows installer arm64 artifact if: ${{ github.event_name != 'pull_request' }} uses: actions/upload-artifact@v7 with: name: sftpgo_windows_installer_arm64 path: ./sftpgo_windows_arm64.exe - name: Upload Windows installer x86 artifact if: ${{ github.event_name != 'pull_request' }} uses: actions/upload-artifact@v7 with: name: sftpgo_windows_installer_x86 path: ./sftpgo_windows_x86.exe - name: Prepare build artifact for Windows run: | Remove-Item -LiteralPath "output" -Force -Recurse -ErrorAction 
Ignore mkdir output copy .\sftpgo.exe .\output mkdir output\arm64 copy .\arm64\sftpgo.exe .\output\arm64 mkdir output\x86 copy .\x86\sftpgo.exe .\output\x86 copy .\sftpgo.json .\output (Get-Content .\output\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\output\sftpgo.json mkdir output\templates xcopy .\templates .\output\templates\ /E mkdir output\static xcopy .\static .\output\static\ /E mkdir output\openapi xcopy .\openapi .\output\openapi\ /E - name: Upload build artifact uses: actions/upload-artifact@v7 with: name: sftpgo-windows-portable path: output test-build-flags: name: Test build flags runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Set up Go uses: actions/setup-go@v6 with: go-version: '1.26' - name: Build run: | go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nogcs,nos3,noportable,nobolt,nomysql,nopgsql,nosqlite,nometrics,noazblob,unixcrypt -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo ./sftpgo -v cp -r openapi static templates internal/bundle/ go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,bundle -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo ./sftpgo -v test-postgresql-mysql-crdb: name: Test with PgSQL/MySQL/Cockroach runs-on: ubuntu-latest services: postgres: image: postgres:latest env: POSTGRES_PASSWORD: postgres POSTGRES_DB: sftpgo options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 ports: - 5432:5432 mariadb: image: mariadb:latest env: MYSQL_ROOT_PASSWORD: mysql MYSQL_DATABASE: sftpgo MYSQL_USER: sftpgo MYSQL_PASSWORD: sftpgo options: >- --health-cmd "mariadb-admin status -h 127.0.0.1 -P 3306 -u root 
-p$MYSQL_ROOT_PASSWORD" --health-interval 10s --health-timeout 5s --health-retries 6 ports: - 3307:3306 mysql: image: mysql:latest env: MYSQL_ROOT_PASSWORD: mysql MYSQL_DATABASE: sftpgo MYSQL_USER: sftpgo MYSQL_PASSWORD: sftpgo options: >- --health-cmd "mysqladmin status -h 127.0.0.1 -P 3306 -u root -p$MYSQL_ROOT_PASSWORD" --health-interval 10s --health-timeout 5s --health-retries 6 ports: - 3308:3306 steps: - uses: actions/checkout@v6 - name: Set up Go uses: actions/setup-go@v6 with: go-version: '1.26' - name: Build run: | go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo cd tests/eventsearcher go build -trimpath -ldflags "-s -w" -o eventsearcher cd - cd tests/ipfilter go build -trimpath -ldflags "-s -w" -o ipfilter cd - - name: Run tests using MySQL provider run: | ./sftpgo initprovider ./sftpgo resetprovider --force go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: mysql SFTPGO_DATA_PROVIDER__NAME: sftpgo SFTPGO_DATA_PROVIDER__HOST: localhost SFTPGO_DATA_PROVIDER__PORT: 3308 SFTPGO_DATA_PROVIDER__USERNAME: sftpgo SFTPGO_DATA_PROVIDER__PASSWORD: sftpgo - name: Run tests using PostgreSQL provider run: | ./sftpgo initprovider ./sftpgo resetprovider --force go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: postgresql SFTPGO_DATA_PROVIDER__NAME: sftpgo SFTPGO_DATA_PROVIDER__HOST: localhost SFTPGO_DATA_PROVIDER__PORT: 5432 SFTPGO_DATA_PROVIDER__USERNAME: postgres SFTPGO_DATA_PROVIDER__PASSWORD: postgres - name: Run tests using MariaDB provider run: | ./sftpgo initprovider ./sftpgo resetprovider --force go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: mysql SFTPGO_DATA_PROVIDER__NAME: sftpgo SFTPGO_DATA_PROVIDER__HOST: localhost SFTPGO_DATA_PROVIDER__PORT: 3307 SFTPGO_DATA_PROVIDER__USERNAME: sftpgo SFTPGO_DATA_PROVIDER__PASSWORD: sftpgo SFTPGO_DATA_PROVIDER__SQL_TABLES_PREFIX: prefix_ - name: Run tests using CockroachDB provider run: | docker run --rm --name crdb --health-cmd "curl -I http://127.0.0.1:8080" --health-interval 10s --health-timeout 5s --health-retries 6 -p 26257:26257 -d cockroachdb/cockroach:latest start-single-node --insecure --listen-addr :26257 sleep 10 docker exec crdb cockroach sql --insecure -e 'create database "sftpgo"' ./sftpgo initprovider ./sftpgo resetprovider --force go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-covermode=atomic docker stop crdb env: SFTPGO_DATA_PROVIDER__DRIVER: cockroachdb SFTPGO_DATA_PROVIDER__NAME: sftpgo SFTPGO_DATA_PROVIDER__HOST: localhost SFTPGO_DATA_PROVIDER__PORT: 26257 SFTPGO_DATA_PROVIDER__USERNAME: root SFTPGO_DATA_PROVIDER__PASSWORD: SFTPGO_DATA_PROVIDER__TARGET_SESSION_ATTRS: any SFTPGO_DATA_PROVIDER__SQL_TABLES_PREFIX: prefix_ build-linux-packages: name: Build Linux packages runs-on: ubuntu-latest strategy: matrix: include: - arch: amd64 distro: ubuntu:18.04 go: latest go-arch: amd64 - arch: aarch64 distro: ubuntu18.04 go: latest go-arch: arm64 - arch: ppc64le distro: ubuntu18.04 go: latest go-arch: ppc64le - arch: armv7 distro: ubuntu18.04 go: latest go-arch: arm7 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Get commit SHA id: get_commit run: echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT shell: bash - name: Build on amd64 if: ${{ matrix.arch == 'amd64' }} run: | echo '#!/bin/bash' > build.sh echo '' >> build.sh echo 'set -e' >> build.sh echo 'apt-get update -q -y' >> build.sh echo 'apt-get install -q -y curl gcc' >> build.sh if [ ${{ matrix.go }} == 'latest' ] then echo 'GO_VERSION=$(curl -L https://go.dev/VERSION?m=text | head -n 1)' >> build.sh else echo 'GO_VERSION=${{ matrix.go }}' >> build.sh fi echo 'GO_DOWNLOAD_ARCH=${{ matrix.go-arch }}' >> build.sh echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/${GO_VERSION}.linux-${GO_DOWNLOAD_ARCH}.tar.gz' >> build.sh echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh echo 'go version' >> build.sh echo 'cd /usr/local/src' >> build.sh echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh chmod 755 build.sh docker run --rm --name 
ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh mkdir -p output/{init,bash_completion,zsh_completion} cp sftpgo.json output/ cp -r templates output/ cp -r static output/ cp -r openapi output/ cp init/sftpgo.service output/init/ ./sftpgo gen completion bash > output/bash_completion/sftpgo ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo ./sftpgo gen man -d output/man/man1 gzip output/man/man1/* cp sftpgo output/ - uses: uraimo/run-on-arch-action@v3 if: ${{ matrix.arch != 'amd64' }} name: Build for ${{ matrix.arch }} id: build with: arch: ${{ matrix.arch }} distro: ${{ matrix.distro }} setup: | mkdir -p "${PWD}/output" dockerRunArgs: | --volume "${PWD}/output:/output" shell: /bin/bash install: | apt-get update -q -y apt-get install -q -y curl gcc if [ ${{ matrix.go }} == 'latest' ] then GO_VERSION=$(curl -L https://go.dev/VERSION?m=text | head -n 1) else GO_VERSION=${{ matrix.go }} fi GO_DOWNLOAD_ARCH=${{ matrix.go-arch }} if [ ${{ matrix.arch}} == 'armv7' ] then GO_DOWNLOAD_ARCH=armv6l fi curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/${GO_VERSION}.linux-${GO_DOWNLOAD_ARCH}.tar.gz tar -C /usr/local -xzf go.tar.gz run: | export PATH=$PATH:/usr/local/go/bin go version if [ ${{ matrix.arch}} == 'armv7' ] then export GOARM=7 fi go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo mkdir -p output/{init,bash_completion,zsh_completion} cp sftpgo.json output/ cp -r templates output/ cp -r static output/ cp -r openapi output/ cp init/sftpgo.service output/init/ ./sftpgo gen completion bash > output/bash_completion/sftpgo ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo ./sftpgo gen man -d output/man/man1 gzip output/man/man1/* 
cp sftpgo output/ - name: Upload build artifact uses: actions/upload-artifact@v7 with: name: sftpgo-linux-${{ matrix.arch }}-go-${{ matrix.go }} path: output - name: Build Packages id: build_linux_pkgs run: | export NFPM_ARCH=${{ matrix.go-arch }} cd pkgs ./build.sh PKG_VERSION=$(cat dist/version) echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT - name: Upload Debian Package uses: actions/upload-artifact@v7 with: name: sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-${{ matrix.go-arch }}-deb path: pkgs/dist/deb/* - name: Upload RPM Package uses: actions/upload-artifact@v7 with: name: sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-${{ matrix.go-arch }}-rpm path: pkgs/dist/rpm/* golangci-lint: name: golangci-lint runs-on: ubuntu-latest steps: - name: Set up Go uses: actions/setup-go@v6 with: go-version: '1.26' - uses: actions/checkout@v6 - name: Run golangci-lint uses: golangci/golangci-lint-action@v9 with: version: latest ================================================ FILE: .github/workflows/docker.yml ================================================ name: Docker on: #schedule: # - cron: '0 4 * * *' # everyday at 4:00 AM UTC push: branches: - main tags: - v* pull_request: jobs: build: name: Build runs-on: ${{ matrix.os }} strategy: matrix: os: - ubuntu-latest docker_pkg: - debian - alpine optional_deps: - true - false include: - os: ubuntu-latest docker_pkg: distroless optional_deps: false - os: ubuntu-latest docker_pkg: debian-plugins optional_deps: true steps: - name: Checkout uses: actions/checkout@v6 - name: Gather image information id: info run: | VERSION=noop DOCKERFILE=Dockerfile MINOR="" MAJOR="" FEATURES="nopgxregisterdefaulttypes,disable_grpc_modules" if [ "${{ github.event_name }}" = "schedule" ]; then VERSION=nightly elif [[ $GITHUB_REF == refs/tags/* ]]; then VERSION=${GITHUB_REF#refs/tags/} elif [[ $GITHUB_REF == refs/heads/* ]]; then VERSION=$(echo ${GITHUB_REF#refs/heads/} | sed -r 's#/+#-#g') if [ "${{ 
github.event.repository.default_branch }}" = "$VERSION" ]; then VERSION=edge fi elif [[ $GITHUB_REF == refs/pull/* ]]; then VERSION=pr-${{ github.event.number }} fi if [[ $VERSION =~ ^v[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then MINOR=${VERSION%.*} MAJOR=${MINOR%.*} fi VERSION_SLIM="${VERSION}-slim" if [[ $DOCKER_PKG == alpine ]]; then VERSION="${VERSION}-alpine" VERSION_SLIM="${VERSION}-slim" DOCKERFILE=Dockerfile.alpine elif [[ $DOCKER_PKG == distroless ]]; then VERSION="${VERSION}-distroless" VERSION_SLIM="${VERSION}-slim" DOCKERFILE=Dockerfile.distroless FEATURES="${FEATURES},nosqlite" elif [[ $DOCKER_PKG == debian-plugins ]]; then VERSION="${VERSION}-plugins" VERSION_SLIM="${VERSION}-slim" FEATURES="${FEATURES},unixcrypt" elif [[ $DOCKER_PKG == debian ]]; then FEATURES="${FEATURES},unixcrypt" fi DOCKER_IMAGES=("drakkan/sftpgo" "ghcr.io/drakkan/sftpgo") TAGS="${DOCKER_IMAGES[0]}:${VERSION}" TAGS_SLIM="${DOCKER_IMAGES[0]}:${VERSION_SLIM}" for DOCKER_IMAGE in ${DOCKER_IMAGES[@]}; do if [[ ${DOCKER_IMAGE} != ${DOCKER_IMAGES[0]} ]]; then TAGS="${TAGS},${DOCKER_IMAGE}:${VERSION}" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${VERSION_SLIM}" fi if [[ $GITHUB_REF == refs/tags/* ]]; then if [[ $DOCKER_PKG == debian ]]; then if [[ -n $MAJOR && -n $MINOR ]]; then TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR},${DOCKER_IMAGE}:${MAJOR}" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-slim,${DOCKER_IMAGE}:${MAJOR}-slim" fi TAGS="${TAGS},${DOCKER_IMAGE}:latest" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:slim" elif [[ $DOCKER_PKG == distroless ]]; then if [[ -n $MAJOR && -n $MINOR ]]; then TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-distroless,${DOCKER_IMAGE}:${MAJOR}-distroless" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-distroless-slim,${DOCKER_IMAGE}:${MAJOR}-distroless-slim" fi TAGS="${TAGS},${DOCKER_IMAGE}:distroless" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:distroless-slim" elif [[ $DOCKER_PKG == debian-plugins ]]; then if [[ -n $MAJOR && -n $MINOR ]]; then 
TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-plugins,${DOCKER_IMAGE}:${MAJOR}-plugins" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-plugins-slim,${DOCKER_IMAGE}:${MAJOR}-plugins-slim" fi TAGS="${TAGS},${DOCKER_IMAGE}:plugins" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:plugins-slim" else if [[ -n $MAJOR && -n $MINOR ]]; then TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-alpine,${DOCKER_IMAGE}:${MAJOR}-alpine" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-alpine-slim,${DOCKER_IMAGE}:${MAJOR}-alpine-slim" fi TAGS="${TAGS},${DOCKER_IMAGE}:alpine" TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:alpine-slim" fi fi done if [[ $OPTIONAL_DEPS == true ]]; then echo "version=${VERSION}" >> $GITHUB_OUTPUT echo "tags=${TAGS}" >> $GITHUB_OUTPUT echo "full=true" >> $GITHUB_OUTPUT else echo "version=${VERSION_SLIM}" >> $GITHUB_OUTPUT echo "tags=${TAGS_SLIM}" >> $GITHUB_OUTPUT echo "full=false" >> $GITHUB_OUTPUT fi if [[ $DOCKER_PKG == debian-plugins ]]; then echo "plugins=true" >> $GITHUB_OUTPUT else echo "plugins=false" >> $GITHUB_OUTPUT fi echo "dockerfile=${DOCKERFILE}" >> $GITHUB_OUTPUT echo "features=${FEATURES}" >> $GITHUB_OUTPUT echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT echo "sha=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT env: DOCKER_PKG: ${{ matrix.docker_pkg }} OPTIONAL_DEPS: ${{ matrix.optional_deps }} - name: Set up QEMU uses: docker/setup-qemu-action@v4 - name: Set up builder uses: docker/setup-buildx-action@v4 id: builder - name: Login to Docker Hub uses: docker/login-action@v4 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} if: ${{ github.event_name != 'pull_request' }} - name: Login to GitHub Container Registry uses: docker/login-action@v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} if: ${{ github.event_name != 'pull_request' }} - name: Build and push uses: docker/build-push-action@v7 with: context: . 
builder: ${{ steps.builder.outputs.name }} file: ./${{ steps.info.outputs.dockerfile }} platforms: linux/amd64,linux/arm64,linux/ppc64le,linux/arm/v7 push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.info.outputs.tags }} build-args: | COMMIT_SHA=${{ steps.info.outputs.sha }} INSTALL_OPTIONAL_PACKAGES=${{ steps.info.outputs.full }} DOWNLOAD_PLUGINS=${{ steps.info.outputs.plugins }} FEATURES=${{ steps.info.outputs.features }} labels: | org.opencontainers.image.title=SFTPGo org.opencontainers.image.description=Full-featured and highly configurable file transfer server: SFTP, HTTP/S, FTP/S, WebDAV org.opencontainers.image.url=https://github.com/drakkan/sftpgo org.opencontainers.image.documentation=https://github.com/drakkan/sftpgo/blob/${{ github.sha }}/docker/README.md org.opencontainers.image.source=https://github.com/drakkan/sftpgo org.opencontainers.image.version=${{ steps.info.outputs.version }} org.opencontainers.image.created=${{ steps.info.outputs.created }} org.opencontainers.image.revision=${{ github.sha }} org.opencontainers.image.licenses=AGPL-3.0-only ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: push: tags: 'v*' permissions: id-token: write contents: write env: GO_VERSION: 1.25.8 jobs: prepare-sources-with-deps: name: Prepare sources with deps runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Set up Go uses: actions/setup-go@v6 with: go-version: ${{ env.GO_VERSION }} - name: Get SFTPGo version id: get_version run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT - name: Prepare release run: | go mod vendor echo "${SFTPGO_VERSION}" > VERSION.txt echo "${GITHUB_SHA::8}" >> VERSION.txt tar cJvf sftpgo_${SFTPGO_VERSION}_src_with_deps.tar.xz * env: SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }} - name: Upload build artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ 
steps.get_version.outputs.VERSION }}_src_with_deps.tar.xz path: ./sftpgo_${{ steps.get_version.outputs.VERSION }}_src_with_deps.tar.xz retention-days: 1 prepare-windows: name: Prepare Windows binaries runs-on: windows-2022 steps: - uses: actions/checkout@v6 - name: Set up Go uses: actions/setup-go@v6 with: go-version: ${{ env.GO_VERSION }} - name: Get SFTPGo version id: get_version run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT shell: bash - name: Build run: | $GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String $DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String $FILE_VERSION = $Env:SFTPGO_VERSION.substring(1) + ".0" go install github.com/tc-hib/go-winres@latest go-winres simply --arch amd64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe mkdir arm64 $Env:CGO_ENABLED='0' $Env:GOOS='windows' $Env:GOARCH='arm64' go-winres simply --arch arm64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe mkdir x86 $Env:GOARCH='386' go-winres simply --arch 386 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version "$FILE_VERSION" 
--file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe Remove-Item Env:\CGO_ENABLED Remove-Item Env:\GOOS Remove-Item Env:\GOARCH env: SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }} - name: Initialize data provider run: ./sftpgo initprovider shell: bash - name: Prepare Release run: | mkdir output copy .\sftpgo.exe .\output copy .\sftpgo.json .\output copy .\sftpgo.db .\output copy .\LICENSE .\output\LICENSE.txt copy .\NOTICE .\output\NOTICE.txt mkdir output\templates xcopy .\templates .\output\templates\ /E mkdir output\static xcopy .\static .\output\static\ /E mkdir output\openapi xcopy .\openapi .\output\openapi\ /E iscc .\windows-installer\sftpgo.iss rm .\output\sftpgo.exe rm .\output\sftpgo.db copy .\arm64\sftpgo.exe .\output (Get-Content .\output\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\output\sftpgo.json $Env:SFTPGO_DATA_PROVIDER__DRIVER='bolt' $Env:SFTPGO_DATA_PROVIDER__NAME='.\output\sftpgo.db' .\sftpgo.exe initprovider Remove-Item Env:\SFTPGO_DATA_PROVIDER__DRIVER Remove-Item Env:\SFTPGO_DATA_PROVIDER__NAME $Env:SFTPGO_ISS_ARCH='arm64' iscc .\windows-installer\sftpgo.iss rm .\output\sftpgo.exe copy .\x86\sftpgo.exe .\output $Env:SFTPGO_ISS_ARCH='x86' iscc .\windows-installer\sftpgo.iss env: SFTPGO_ISS_VERSION: ${{ steps.get_version.outputs.VERSION }} - name: Prepare Portable Release run: | mkdir win-portable copy .\sftpgo.exe .\win-portable mkdir win-portable\arm64 copy .\arm64\sftpgo.exe .\win-portable\arm64 mkdir win-portable\x86 copy .\x86\sftpgo.exe .\win-portable\x86 copy .\sftpgo.json .\win-portable (Get-Content .\win-portable\sftpgo.json).replace('"sqlite"', '"bolt"') | 
Set-Content .\win-portable\sftpgo.json copy .\output\sftpgo.db .\win-portable copy .\LICENSE .\win-portable\LICENSE.txt copy .\NOTICE .\win-portable\NOTICE.txt mkdir win-portable\templates xcopy .\templates .\win-portable\templates\ /E mkdir win-portable\static xcopy .\static .\win-portable\static\ /E mkdir win-portable\openapi xcopy .\openapi .\win-portable\openapi\ /E Compress-Archive .\win-portable\* sftpgo_portable.zip - name: Upload Windows installer x86_64 artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_x86_64.exe path: ./sftpgo_windows_x86_64.exe retention-days: 1 - name: Upload Windows installer arm64 artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_arm64.exe path: ./sftpgo_windows_arm64.exe retention-days: 1 - name: Upload Windows installer x86 artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_x86.exe path: ./sftpgo_windows_x86.exe retention-days: 1 - name: Upload Windows portable artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_portable.zip path: ./sftpgo_portable.zip retention-days: 1 prepare-mac: name: Prepare macOS binaries runs-on: macos-14 steps: - uses: actions/checkout@v6 - name: Set up Go uses: actions/setup-go@v6 with: go-version: ${{ env.GO_VERSION }} - name: Get SFTPGo version id: get_version run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT shell: bash - name: Build for macOS x86_64 run: go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo - name: Build for macOS arm64 run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath 
-tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64 - name: Initialize data provider run: ./sftpgo initprovider shell: bash - name: Prepare Release run: | mkdir -p output/{init,sqlite,bash_completion,zsh_completion} echo "For documentation please take a look here:" > output/README.txt echo "" >> output/README.txt echo "https://docs.sftpgo.com" >> output/README.txt cp LICENSE output/ cp NOTICE output/ cp sftpgo output/ cp sftpgo.json output/ cp sftpgo.db output/sqlite/ cp -r static output/ cp -r openapi output/ cp -r templates output/ cp init/com.github.drakkan.sftpgo.plist output/init/ ./sftpgo gen completion bash > output/bash_completion/sftpgo ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo ./sftpgo gen man -d output/man/man1 gzip output/man/man1/* cd output tar cJvf ../sftpgo_${SFTPGO_VERSION}_macOS_x86_64.tar.xz * cd .. cp sftpgo_arm64 output/sftpgo cd output tar cJvf ../sftpgo_${SFTPGO_VERSION}_macOS_arm64.tar.xz * cd .. 
env: SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }} - name: Upload macOS x86_64 artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_x86_64.tar.xz path: ./sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_x86_64.tar.xz retention-days: 1 - name: Upload macOS arm64 artifact uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_arm64.tar.xz path: ./sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_arm64.tar.xz retention-days: 1 prepare-linux: name: Prepare Linux binaries runs-on: ubuntu-latest strategy: matrix: include: - arch: amd64 distro: ubuntu:18.04 go-arch: amd64 deb-arch: amd64 rpm-arch: x86_64 tar-arch: x86_64 - arch: aarch64 distro: ubuntu18.04 go-arch: arm64 deb-arch: arm64 rpm-arch: aarch64 tar-arch: arm64 - arch: ppc64le distro: ubuntu18.04 go-arch: ppc64le deb-arch: ppc64el rpm-arch: ppc64le tar-arch: ppc64le - arch: armv7 distro: ubuntu18.04 go-arch: arm7 deb-arch: armhf rpm-arch: armv7hl tar-arch: armv7 steps: - uses: actions/checkout@v6 - name: Get versions id: get_version run: | echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT echo "GO_VERSION=${GO_VERSION}" >> $GITHUB_OUTPUT echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT shell: bash env: GO_VERSION: ${{ env.GO_VERSION }} - name: Build on amd64 if: ${{ matrix.arch == 'amd64' }} run: | echo '#!/bin/bash' > build.sh echo '' >> build.sh echo 'set -e' >> build.sh echo 'apt-get update -q -y' >> build.sh echo 'apt-get install -q -y curl gcc' >> build.sh echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/go${{ steps.get_version.outputs.GO_VERSION }}.linux-${{ matrix.go-arch }}.tar.gz' >> build.sh echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh echo 'go version' >> build.sh echo 'cd /usr/local/src' >> build.sh echo 'go build -buildvcs=false -trimpath -tags 
nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh chmod 755 build.sh docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh mkdir -p output/{init,sqlite,bash_completion,zsh_completion} echo "For documentation please take a look here:" > output/README.txt echo "" >> output/README.txt echo "https://github.com/drakkan/sftpgo/blob/${SFTPGO_VERSION}/README.md" >> output/README.txt cp LICENSE output/ cp NOTICE output/ cp sftpgo.json output/ cp -r templates output/ cp -r static output/ cp -r openapi output/ cp init/sftpgo.service output/init/ ./sftpgo initprovider ./sftpgo gen completion bash > output/bash_completion/sftpgo ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo ./sftpgo gen man -d output/man/man1 gzip output/man/man1/* cp sftpgo output/ cp sftpgo.db output/sqlite/ cd output tar cJvf sftpgo_${SFTPGO_VERSION}_linux_${{ matrix.tar-arch }}.tar.xz * cd .. 
env: SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} - uses: uraimo/run-on-arch-action@v3 if: ${{ matrix.arch != 'amd64' }} name: Build for ${{ matrix.arch }} id: build with: arch: ${{ matrix.arch }} distro: ${{ matrix.distro }} setup: | mkdir -p "${PWD}/output" dockerRunArgs: | --volume "${PWD}/output:/output" shell: /bin/bash install: | apt-get update -q -y apt-get install -q -y curl gcc xz-utils GO_DOWNLOAD_ARCH=${{ matrix.go-arch }} if [ ${{ matrix.arch}} == 'armv7' ] then GO_DOWNLOAD_ARCH=armv6l fi curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/go${{ steps.get_version.outputs.GO_VERSION }}.linux-${GO_DOWNLOAD_ARCH}.tar.gz tar -C /usr/local -xzf go.tar.gz run: | export PATH=$PATH:/usr/local/go/bin go version if [ ${{ matrix.arch}} == 'armv7' ] then export GOARM=7 fi go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo mkdir -p output/{init,sqlite,bash_completion,zsh_completion} echo "For documentation please take a look here:" > output/README.txt echo "" >> output/README.txt echo "https://github.com/drakkan/sftpgo/blob/${{ steps.get_version.outputs.SFTPGO_VERSION }}/README.md" >> output/README.txt cp LICENSE output/ cp NOTICE output/ cp sftpgo.json output/ cp -r templates output/ cp -r static output/ cp -r openapi output/ cp init/sftpgo.service output/init/ ./sftpgo initprovider ./sftpgo gen completion bash > output/bash_completion/sftpgo ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo ./sftpgo gen man -d output/man/man1 gzip output/man/man1/* cp sftpgo output/ cp sftpgo.db output/sqlite/ cd output tar cJvf sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_${{ matrix.tar-arch }}.tar.xz * cd .. 
- name: Upload build artifact for ${{ matrix.arch }} uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_${{ matrix.tar-arch }}.tar.xz path: ./output/sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_${{ matrix.tar-arch }}.tar.xz retention-days: 1 - name: Build Packages id: build_linux_pkgs run: | export NFPM_ARCH=${{ matrix.go-arch }} cd pkgs ./build.sh PKG_VERSION=${SFTPGO_VERSION:1} echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT env: SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} - name: Upload Deb Package uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.build_linux_pkgs.outputs.pkg-version }}-1_${{ matrix.deb-arch}}.deb path: ./pkgs/dist/deb/sftpgo_${{ steps.build_linux_pkgs.outputs.pkg-version }}-1_${{ matrix.deb-arch}}.deb retention-days: 1 - name: Upload RPM Package uses: actions/upload-artifact@v7 with: name: sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-1.${{ matrix.rpm-arch}}.rpm path: ./pkgs/dist/rpm/sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-1.${{ matrix.rpm-arch}}.rpm retention-days: 1 prepare-linux-bundle: name: Prepare Linux bundle needs: prepare-linux runs-on: ubuntu-latest steps: - name: Get versions id: get_version run: | echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT shell: bash - name: Download amd64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_x86_64.tar.xz - name: Download arm64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_arm64.tar.xz - name: Download ppc64le artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_ppc64le.tar.xz - name: Download armv7 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_armv7.tar.xz - name: Build 
bundle shell: bash run: | mkdir -p bundle/{arm64,ppc64le,armv7} cd bundle tar xvf ../sftpgo_${SFTPGO_VERSION}_linux_x86_64.tar.xz cd arm64 tar xvf ../../sftpgo_${SFTPGO_VERSION}_linux_arm64.tar.xz sftpgo cd ../ppc64le tar xvf ../../sftpgo_${SFTPGO_VERSION}_linux_ppc64le.tar.xz sftpgo cd ../armv7 tar xvf ../../sftpgo_${SFTPGO_VERSION}_linux_armv7.tar.xz sftpgo cd .. tar cJvf sftpgo_${SFTPGO_VERSION}_linux_bundle.tar.xz * cd .. env: SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} - name: Upload Linux bundle uses: actions/upload-artifact@v7 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_bundle.tar.xz path: ./bundle/sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_bundle.tar.xz retention-days: 1 create-release: name: Release needs: [prepare-linux-bundle, prepare-sources-with-deps, prepare-mac, prepare-windows] runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Get versions id: get_version run: | SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//} PKG_VERSION=${SFTPGO_VERSION:1} echo "SFTPGO_VERSION=${SFTPGO_VERSION}" >> $GITHUB_OUTPUT echo "PKG_VERSION=${PKG_VERSION}" >> $GITHUB_OUTPUT shell: bash - name: Download amd64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_x86_64.tar.xz - name: Download arm64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_arm64.tar.xz - name: Download ppc64le artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_ppc64le.tar.xz - name: Download armv7 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_armv7.tar.xz - name: Download Linux bundle artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_bundle.tar.xz - name: Download Deb amd64 artifact uses: 
actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_amd64.deb - name: Download Deb arm64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_arm64.deb - name: Download Deb ppc64le artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_ppc64el.deb - name: Download Deb armv7 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_armhf.deb - name: Download RPM x86_64 artifact uses: actions/download-artifact@v8 with: name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.x86_64.rpm - name: Download RPM aarch64 artifact uses: actions/download-artifact@v8 with: name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.aarch64.rpm - name: Download RPM ppc64le artifact uses: actions/download-artifact@v8 with: name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.ppc64le.rpm - name: Download RPM armv7 artifact uses: actions/download-artifact@v8 with: name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.armv7hl.rpm - name: Download macOS x86_64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_macOS_x86_64.tar.xz - name: Download macOS arm64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_macOS_arm64.tar.xz - name: Download Windows installer x86_64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_x86_64.exe - name: Download Windows installer arm64 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_arm64.exe - name: Download Windows installer x86 artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_x86.exe - name: Download 
Windows portable artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_portable.zip - name: Download source with deps artifact uses: actions/download-artifact@v8 with: name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_src_with_deps.tar.xz - name: Create release run: | mv sftpgo_windows_x86_64.exe sftpgo_${SFTPGO_VERSION}_windows_x86_64.exe mv sftpgo_windows_arm64.exe sftpgo_${SFTPGO_VERSION}_windows_arm64.exe mv sftpgo_windows_x86.exe sftpgo_${SFTPGO_VERSION}_windows_x86.exe mv sftpgo_portable.zip sftpgo_${SFTPGO_VERSION}_windows_portable.zip gh release create "${SFTPGO_VERSION}" -t "${SFTPGO_VERSION}" gh release upload "${SFTPGO_VERSION}" sftpgo_*.xz --clobber gh release upload "${SFTPGO_VERSION}" sftpgo-*.rpm --clobber gh release upload "${SFTPGO_VERSION}" sftpgo_*.deb --clobber gh release upload "${SFTPGO_VERSION}" sftpgo_*.exe --clobber gh release upload "${SFTPGO_VERSION}" sftpgo_*.zip --clobber gh release view "${SFTPGO_VERSION}" env: GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} ================================================ FILE: .gitignore ================================================ # compilation output sftpgo sftpgo.exe ================================================ FILE: .golangci.yml ================================================ version: "2" run: issues-exit-code: 1 tests: true linters: enable: - bodyclose - dogsled - dupl - goconst - gocyclo - misspell - revive - rowserrcheck - unconvert - unparam - whitespace settings: dupl: threshold: 150 errcheck: check-type-assertions: false check-blank: false goconst: min-len: 3 min-occurrences: 3 gocyclo: min-complexity: 15 # https://golangci-lint.run/usage/linters/#revive revive: rules: - name: var-naming severity: warning disabled: true exclude: [""] arguments: - ["ID"] # AllowList - ["VM"] # DenyList - - upper-case-const: true - - skip-package-name-checks: true 
exclusions: generated: lax presets: - common-false-positives - legacy - std-error-handling paths: - third_party$ - builtin$ - examples$ formatters: enable: - gofmt - goimports settings: gofmt: simplify: true goimports: local-prefixes: - github.com/drakkan/sftpgo exclusions: generated: lax paths: - third_party$ - builtin$ - examples$ ================================================ FILE: CODEOWNERS ================================================ * @drakkan ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at support@sftpgo.com. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: Dockerfile ================================================ FROM golang:1.26-trixie AS builder ENV GOFLAGS="-mod=readonly" RUN apt-get update && apt-get -y upgrade && rm -rf /var/lib/apt/lists/* RUN mkdir -p /workspace WORKDIR /workspace ARG GOPROXY COPY go.mod go.sum ./ RUN go mod download && go mod verify ARG COMMIT_SHA # This ARG allows disabling some optional features, which can be useful if you build the image yourself. # For example, you can disable S3 and GCS support like this: # --build-arg FEATURES=nos3,nogcs ARG FEATURES COPY . . 
RUN set -xe && \ export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \ go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo # Set to "true" to download the "official" plugins in /usr/local/bin ARG DOWNLOAD_PLUGINS=false RUN if [ "${DOWNLOAD_PLUGINS}" = "true" ]; then apt-get update && apt-get install --no-install-recommends -y curl && ./docker/scripts/download-plugins.sh; fi FROM debian:trixie-slim # Set to "true" to install jq ARG INSTALL_OPTIONAL_PACKAGES=false RUN apt-get update && apt-get -y upgrade && apt-get install --no-install-recommends -y ca-certificates media-types && rm -rf /var/lib/apt/lists/* RUN if [ "${INSTALL_OPTIONAL_PACKAGES}" = "true" ]; then apt-get update && apt-get install --no-install-recommends -y jq && rm -rf /var/lib/apt/lists/*; fi RUN mkdir -p /etc/sftpgo /var/lib/sftpgo /usr/share/sftpgo /srv/sftpgo/data /srv/sftpgo/backups RUN groupadd --system -g 1000 sftpgo && \ useradd --system --gid sftpgo --no-create-home \ --home-dir /var/lib/sftpgo --shell /usr/sbin/nologin \ --comment "SFTPGo user" --uid 1000 sftpgo COPY --from=builder /workspace/sftpgo.json /etc/sftpgo/sftpgo.json COPY --from=builder /workspace/templates /usr/share/sftpgo/templates COPY --from=builder /workspace/static /usr/share/sftpgo/static COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi COPY --from=builder /workspace/sftpgo /usr/local/bin/sftpgo-plugin-* /usr/local/bin/ # Log to the stdout so the logs will be available using docker logs ENV SFTPGO_LOG_FILE_PATH="" # Modify the default configuration file RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' /etc/sftpgo/sftpgo.json && \ sed -i 's|"backups"|"/srv/sftpgo/backups"|' /etc/sftpgo/sftpgo.json RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown 
sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups WORKDIR /var/lib/sftpgo USER 1000:1000 CMD ["sftpgo", "serve"] ================================================ FILE: Dockerfile.alpine ================================================ FROM golang:1.26-alpine3.23 AS builder ENV GOFLAGS="-mod=readonly" RUN apk -U upgrade --no-cache && apk add --update --no-cache bash ca-certificates curl git gcc g++ RUN mkdir -p /workspace WORKDIR /workspace ARG GOPROXY COPY go.mod go.sum ./ RUN go mod download && go mod verify ARG COMMIT_SHA # This ARG allows disabling some optional features, which can be useful if you build the image yourself. # For example, you can disable S3 and GCS support like this: # --build-arg FEATURES=nos3,nogcs ARG FEATURES COPY . . RUN set -xe && \ export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \ go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo FROM alpine:3.23 # Set to "true" to install jq ARG INSTALL_OPTIONAL_PACKAGES=false RUN apk -U upgrade --no-cache && apk add --update --no-cache ca-certificates tzdata mailcap RUN if [ "${INSTALL_OPTIONAL_PACKAGES}" = "true" ]; then apk add --update --no-cache jq; fi RUN mkdir -p /etc/sftpgo /var/lib/sftpgo /usr/share/sftpgo /srv/sftpgo/data /srv/sftpgo/backups RUN addgroup -g 1000 -S sftpgo && \ adduser -u 1000 -h /var/lib/sftpgo -s /sbin/nologin -G sftpgo -S -D -H -g "SFTPGo user" sftpgo COPY --from=builder /workspace/sftpgo.json /etc/sftpgo/sftpgo.json COPY --from=builder /workspace/templates /usr/share/sftpgo/templates COPY --from=builder /workspace/static /usr/share/sftpgo/static COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi COPY --from=builder /workspace/sftpgo /usr/local/bin/ # Log to the stdout so the logs will be available using docker logs 
ENV SFTPGO_LOG_FILE_PATH="" # Modify the default configuration file RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' /etc/sftpgo/sftpgo.json && \ sed -i 's|"backups"|"/srv/sftpgo/backups"|' /etc/sftpgo/sftpgo.json RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups WORKDIR /var/lib/sftpgo USER 1000:1000 CMD ["sftpgo", "serve"] ================================================ FILE: Dockerfile.distroless ================================================ FROM golang:1.26-trixie AS builder ENV CGO_ENABLED=0 GOFLAGS="-mod=readonly" RUN apt-get update && apt-get -y upgrade && apt-get install --no-install-recommends -y media-types && rm -rf /var/lib/apt/lists/* RUN mkdir -p /workspace WORKDIR /workspace ARG GOPROXY COPY go.mod go.sum ./ RUN go mod download && go mod verify ARG COMMIT_SHA # This ARG allows disabling some optional features, which can be useful if you build the image yourself. # SQLite support is disabled in this variant because it requires CGO, and therefore a C runtime, which is not installed # in distroless/static-* images ARG FEATURES COPY . . 
RUN set -xe && \ export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \ go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo # Modify the default configuration file RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' sftpgo.json && \ sed -i 's|"backups"|"/srv/sftpgo/backups"|' sftpgo.json && \ sed -i 's|"sqlite"|"bolt"|' sftpgo.json RUN mkdir /etc/sftpgo /var/lib/sftpgo /srv/sftpgo FROM gcr.io/distroless/static-debian13 COPY --from=builder --chown=1000:1000 /etc/sftpgo /etc/sftpgo COPY --from=builder --chown=1000:1000 /srv/sftpgo /srv/sftpgo COPY --from=builder --chown=1000:1000 /var/lib/sftpgo /var/lib/sftpgo COPY --from=builder --chown=1000:1000 /workspace/sftpgo.json /etc/sftpgo/sftpgo.json COPY --from=builder /workspace/templates /usr/share/sftpgo/templates COPY --from=builder /workspace/static /usr/share/sftpgo/static COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi COPY --from=builder /workspace/sftpgo /usr/local/bin/ COPY --from=builder /etc/mime.types /etc/mime.types # Log to the stdout so the logs will be available using docker logs ENV SFTPGO_LOG_FILE_PATH="" # These env vars are required to avoid the following error when calling user.Current(): # unable to get the current user: user: Current requires cgo or $USER set in environment ENV USER=sftpgo ENV HOME=/var/lib/sftpgo WORKDIR /var/lib/sftpgo USER 1000:1000 CMD ["sftpgo", "serve"] ================================================ FILE: LICENSE ================================================ GNU AFFERO GENERAL PUBLIC LICENSE Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. 
Preamble The GNU Affero General Public License is a free, copyleft license for software and other kinds of works, specifically designed to ensure cooperation with the community in the case of network server software. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. Developers that use our General Public Licenses protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License which gives you legal permission to copy, distribute and/or modify the software. A secondary benefit of defending all users' freedom is that improvements made in alternate versions of the program, if they receive widespread use, become available for other developers to incorporate. Many developers of free software are heartened and encouraged by the resulting cooperation. However, in the case of software used on network servers, this result may fail to come about. The GNU General Public License permits making a modified version and letting the public access it on a server without ever releasing its source code to the public. The GNU Affero General Public License is designed specifically to ensure that, in such cases, the modified source code becomes available to the community. 
It requires the operator of a network server to provide the source code of the modified version running there to the users of that server. Therefore, public use of a modified version, on a publicly accessible server, gives the public access to the source code of the modified version. An older license, called the Affero General Public License and published by Affero, was designed to accomplish similar goals. This is a different license, not a version of the Affero GPL, but Affero has released a new version of the Affero GPL which permits relicensing under this license. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. 
Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. 
However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. 
Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. 
b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. 
b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. 
A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. 
But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. 
If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. 
Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. 
The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. 
"Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. 
If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Remote Network Interaction; Use with the GNU General Public License. Notwithstanding any other provision of this License, if you modify the Program, your modified version must prominently offer all users interacting with it remotely through a computer network (if your version supports such interaction) an opportunity to receive the Corresponding Source of your version by providing access to the Corresponding Source from a network server at no charge, through some standard or customary means of facilitating copying of software. This Corresponding Source shall include the Corresponding Source for any work covered by version 3 of the GNU General Public License that is incorporated pursuant to the following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the work with which it is combined will remain governed by version 3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU Affero General Public License from time to time. 
Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. <one line to give the program's name and a brief idea of what it does.> Copyright (C) <year> <name of author> This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. Also add information on how to contact you by electronic and paper mail. If your software can interact with users remotely through a computer network, you should also make sure that it provides a way for users to get its source. For example, if your program is a web application, its interface could display a "Source" link that leads users to an archive of the code. There are many ways you could offer source, and different solutions will be better for different programs; see section 13 for the specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU AGPL, see <https://www.gnu.org/licenses/>. ================================================ FILE: NOTICE ================================================ Additional terms under GNU AGPL version 3 section 7.3(b) and 13.1: If you have included SFTPGo so that it is offered through any network interactions, including by means of an external user interface, or any other integration, even without modifying its source code and then SFTPGo is partially, fully or optionally configured via your frontend, you must provide reasonable but clear attribution to the SFTPGo project and its author(s), not imply any endorsement by or affiliation with the SFTPGo project, and you must prominently offer all users interacting with it remotely through a computer network an opportunity to receive the Corresponding Source of the SFTPGo version you include by providing a link to the Corresponding Source in the SFTPGo source code repository.
================================================ FILE: README.md ================================================ # SFTPGo [![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg)](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg) [![License: AGPL-3.0-only](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0) [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) Full-featured and highly configurable event-driven file transfer solution. Server protocols: SFTP, HTTP/S, FTP/S, WebDAV. Storage backends: local filesystem, encrypted local filesystem, S3 (compatible) Object Storage, Google Cloud Storage, Azure Blob Storage, other SFTP servers. With SFTPGo you can leverage local and cloud storage backends for exchanging and storing files internally or with business partners using the same tools and processes you are already familiar with. ## Project Status & Editions SFTPGo is an open-source project with a sustainable business model. We offer two editions to suit different requirements, ensuring the project remains healthy and maintained for everyone. ### Open Source (Community) Free, Copyleft (AGPLv3), Community Supported. The Community edition is a fully functional, production-ready solution widely adopted worldwide. It includes all the core protocols, storage backends, and the WebAdmin/WebClient UIs. It is ideal for: - Standard file transfer needs. - Integrating storage backends (S3, GCS, Azure Blob) with legacy protocols. - Projects that are comfortable with AGPLv3 licensing. ### SFTPGo Enterprise Commercial License, Professional Support, ISO 27001 Vendor. The Enterprise edition is built on the same core but extends it for mission-critical environments, compliance-heavy industries, and advanced workflows. It is a drop-in replacement (seamless upgrade). 
| Feature | Open Source (Community) | Enterprise Edition |
| :--- | :--- | :--- |
| **License Type** | AGPLv3 (Copyleft) | **Commercial License**<br>Proprietary/No Copyleft |
| **Vendor Compliance** | Not Applicable<br>Community Project | **Certified Vendor**<br>ISO 27001 & Supply Chain Validation |
| **Support** | Community (GitHub) | **Direct from Authors** |
| **Cloud Storage Engine** | Standard | **High Performance & Scalable**<br>In-memory streaming (no local temp files) and up to 70% faster |
| **High Availability (HA)** | Standard<br>Shared DB & Storage | **Advanced**<br>Enhanced event handling and optimized instance coordination |
| **Automation Logic** | Simple Placeholders | **Dynamic Logic & Virtual Folders**<br>Conditions, loops, route data across storage backends |
| **Data Lifecycle** | Delete / Retain | **Smart Archiving**<br>Move data to external Cloud/SFTP storage via Virtual Folders |
| **Email Data Ingestion** | - | **Native IMAP Integration**<br>Auto-extract attachments from email to storage |
| **Public Sharing** | Standard Links | **Advanced & Collaborative**<br>Email Authentication & Group Delegation |
| **Data Protection** | - | **Encryption & Scanning**<br>Automated PGP, Antivirus & DLP via ICAP |
| **Advanced Identity (SSO)** | Standard | **Extended Controls**<br>Advanced Single Sign-On parameters |
| **Document Editing** | - | **Included**<br>View, edit, and co-author in browser |

**Note**: We are committed to keeping the Open Source edition powerful and maintained. The Enterprise edition helps fund the development of the entire SFTPGo ecosystem. ## Sponsors If you rely on SFTPGo in your projects, consider becoming a [sponsor](https://github.com/sponsors/drakkan). Your sponsorship helps cover maintenance, security updates and ongoing development of the open-source edition. ### Thank you to our sponsors #### Platinum sponsors [Aledade logo](https://www.aledade.com/)

[Jump Trading logo](https://www.jumptrading.com/)

[WP Engine logo](https://wpengine.com/) #### Silver sponsors [IDCS logo](https://idcs.ip-paris.fr/) #### Bronze sponsors [7digital logo](https://www.7digital.com/)

[servinga logo](https://servinga.com/)

[ReUI logo](https://www.reui.io/) ## Documentation You can explore all supported features and configuration options at [docs.sftpgo.com](https://docs.sftpgo.com/latest/). **Note:** The link above refers to the **Community Edition**. For details on **Enterprise Edition**, please refer to the [Enterprise Documentation](https://docs.sftpgo.com/enterprise/). ## Support - **Community Support**: use [GitHub Discussions](https://github.com/drakkan/sftpgo/discussions) to ask questions, share feedback, and engage with other users. - **Commercial Support**: If you require guaranteed SLAs, expert guidance, or the advanced features listed above, check out [SFTPGo Enterprise](https://sftpgo.com). SFTPGo Enterprise is available as: - On-premises: Full control on your infrastructure. More details: [sftpgo.com/on-premises](https://sftpgo.com/on-premises) - Fully managed SaaS: We handle the infrastructure. More details: [sftpgo.com/saas](https://sftpgo.com/saas) ## Internationalization The translations are available via [Crowdin](https://crowdin.com/project/sftpgo), who have granted us an open source license. Before translating please take a look at our contribution [guidelines](https://docs.sftpgo.com/latest/web-interfaces/#internationalization). ## Release Cadence SFTPGo follows a feature-driven release cycle. - Enterprise Edition: Receives major new features first and follows a faster [release cadence](https://docs.sftpgo.com/enterprise/changelog/). - Community Edition: Remains maintained, receiving bug fixes, security updates, and updates to core features. ## Acknowledgements SFTPGo makes use of the third party libraries listed inside [go.mod](./go.mod). We are very grateful to all the people who contributed with ideas and/or pull requests. Thank you to [ysura](https://www.ysura.com/) for granting us stable access to a test AWS S3 account. 
Thank you to [KeenThemes](https://keenthemes.com/) for granting us a custom license to use their amazing [themes](https://keenthemes.com/bootstrap-templates) for the SFTPGo WebAdmin and WebClient user interfaces, across both the Open Source and Open Core versions. Thank you to [Crowdin](https://crowdin.com/) for granting us an Open Source License. Thank you to [Incode](https://www.incode.it/) for helping us to improve the UI/UX. ## License SFTPGo source code is licensed under the GNU AGPL-3.0-only with [additional terms](./NOTICE). The [theme](https://keenthemes.com/bootstrap-templates) used in WebAdmin and WebClient user interfaces is proprietary, this means: - KeenThemes HTML/CSS/JS components are allowed for use only within the SFTPGo product and restricted to be used in a resealable HTML template that can compete with KeenThemes products anyhow. - The SFTPGo WebAdmin and WebClient user interfaces (HTML, CSS and JS components) based on this theme are allowed for use only within the SFTPGo product and therefore cannot be used in derivative works/products without an explicit grant from the [SFTPGo Team](mailto:support@sftpgo.com). More information about [compliance](https://sftpgo.com/compliance.html). **Note:** We do not provide legal advice. If you have questions about license compliance or whether your use case is permitted under the license terms, please consult your legal team. ## Copyright Copyright (C) 2019 - 2026 Nicola Murino ================================================ FILE: SECURITY.md ================================================ # Security Policy ## Supported Versions We actively maintain the latest stable release of SFTPGo. While we strive to keep the Open Source version secure and up-to-date, maintenance is performed on a best-effort basis by the community and contributors. ## Scope and Dependency Policy Our security advisories focus on vulnerabilities found within the **SFTPGo codebase itself**. 
To ensure the long-term sustainability of the project, we handle upstream dependencies (like the Go standard library, external packages, or Docker base images) as follows: - Community Updates: For the Open Source version, vulnerabilities in upstream components (such as the Go standard library or third-party packages) are addressed during our **regular release cycles**. We generally do not provide immediate, out-of-band or ad-hoc releases to address dependency-only CVEs. - Empowering Users: One of the strengths of SFTPGo being open-source is that you have full control. If your security scanners require an immediate fix, you can always rebuild the project using the latest patched Go toolchain or updated dependencies. - Compatibility: We are committed to keeping SFTPGo compatible with the latest stable Go compiler. If an upstream fix breaks SFTPGo, fixing that becomes a priority for us. - Professional Needs: We understand that some organizations have strict compliance requirements or internal SLAs that require guaranteed, immediate response times and out-of-band patches. For these cases, we offer [SFTPGo Enterprise](https://sftpgo.com/on-premises) to cover the additional maintenance and support overhead. ## Reporting a Vulnerability To report (possible) security issues in SFTPGo, please either send an email to the [SFTPGo Team](mailto:support@sftpgo.com) or use GitHub's [private reporting feature](https://github.com/drakkan/sftpgo/security/advisories/new).
================================================ FILE: crowdin.yml ================================================ project_id_env: CROWDIN_PROJECT_ID api_token_env: CROWDIN_PERSONAL_TOKEN files: - source: /static/locales/en/translation.json translation: /static/locales/%two_letters_code%/%original_file_name% type: i18next_json ================================================ FILE: docker/scripts/download-plugins.sh ================================================
#!/usr/bin/env bash
# Download the latest release of each SFTPGo plugin listed in PLUGINS, for the
# CPU architecture reported by `uname -m`, into /usr/local/bin.
# Any failure (unsupported architecture, failed download, failed chmod) aborts
# the script with a non-zero exit status.
set -euo pipefail

ARCH=$(uname -m)

# Map the kernel architecture name to the suffix used by the plugin release
# assets. Unknown architectures are rejected explicitly: the previous catch-all
# silently fell back to the ppc64le binaries, which cannot run on other hosts.
case ${ARCH} in
    x86_64)
        SUFFIX=amd64
        ;;
    aarch64)
        SUFFIX=arm64
        ;;
    ppc64le)
        SUFFIX=ppc64le
        ;;
    *)
        echo "Error: unsupported architecture ${ARCH}" >&2
        exit 1
        ;;
esac

echo "Downloading plugins for arch ${SUFFIX}"

PLUGINS=(geoipfilter kms pubsub eventstore eventsearch auth)

for PLUGIN in "${PLUGINS[@]}"; do
    URL="https://github.com/sftpgo/sftpgo-plugin-${PLUGIN}/releases/latest/download/sftpgo-plugin-${PLUGIN}-linux-${SUFFIX}"
    DEST="/usr/local/bin/sftpgo-plugin-${PLUGIN}"

    echo "Downloading ${PLUGIN}..."

    # --fail makes curl return an error on HTTP 4xx/5xx instead of saving the
    # error page as the plugin binary.
    if curl --fail --silent --show-error -L "${URL}" --output "${DEST}"; then
        chmod 755 "${DEST}"
    else
        echo "Error: Failed to download ${PLUGIN}" >&2
        exit 1
    fi
done

echo "All plugins downloaded successfully"
================================================ FILE: examples/OTP/authy/README.md ================================================ # Authy These examples show how to integrate [Twilio Authy API](https://www.twilio.com/docs/authy/api) for One-Time-Password logins. The examples assume that the user has the free [Authy app](https://authy.com/) installed and uses it to generate offline [TOTP](https://en.wikipedia.org/wiki/Time-based_One-time_Password_algorithm) codes (soft tokens). You first need to [create an Authy Application in the Twilio Console](https://twilio.com/console/authy/applications?_ga=2.205553366.451688189.1597667213-1526360003.1597667213), then you can create a new Authy user and store a reference to the matching SFTPGo account.
Verify that your Authy application is successfully registered: ```bash export AUTHY_API_KEY=<your Authy API key> curl 'https://api.authy.com/protected/json/app/details' -H "X-Authy-API-Key: $AUTHY_API_KEY" ``` now create an Authy user: ```bash curl -XPOST "https://api.authy.com/protected/json/users/new" \ -H "X-Authy-API-Key: $AUTHY_API_KEY" \ --data-urlencode user[email]="user@domain.com" \ --data-urlencode user[cellphone]="317-338-9302" \ --data-urlencode user[country_code]="54" ``` The response is something like this: ```json {"message":"User created successfully.","user":{"id":xxxxxxxx},"success":true} ``` Save the user id somewhere and add a reference to the matching SFTPGo account. You could also store this ID in the `additional_info` SFTPGo user field. After this step you can use the Authy app installed on your phone to generate TOTP codes. Now you can verify the token using an HTTP GET request: ```bash export TOKEN=<TOTP token> export AUTHY_ID=<Authy user id> curl -i "https://api.authy.com/protected/json/verify/${TOKEN}/${AUTHY_ID}" \ -H "X-Authy-API-Key: $AUTHY_API_KEY" ``` So inside your hook you need to check: - the HTTP response code for the verify request, it must be `200` - the JSON response body, it must contain the key `success` with the value `true` (as a string) If these conditions are met the token is valid and you allow the user to log in. We provide the following examples: - [Keyboard interactive authentication](./keyint/README.md) for 2FA using password + Authy one time token. - [External authentication](./extauth/README.md) using Authy one time tokens as passwords. - [Check password hook](./checkpwd/README.md) for 2FA using a password consisting of a fixed string and a One Time Token. Please note that these are sample programs not intended for production use, you should write your own hook based on them and you should prefer HTTP based hooks if performance is a concern. :warning: SFTPGo also has built-in 2FA support.
================================================ FILE: examples/OTP/authy/checkpwd/README.md ================================================ # Authy 2FA via check password hook This example shows how to use 2FA via the check password hook using a password consisting of a fixed part and an Authy TOTP token. The hook will check the TOTP token using the Authy API and SFTPGo will check the fixed part. Please read the [sample code](./main.go), it should be self explanatory. ================================================ FILE: examples/OTP/authy/checkpwd/go.mod ================================================ module github.com/drakkan/sftpgo/authy/checkpwd go 1.22.2 ================================================ FILE: examples/OTP/authy/checkpwd/main.go ================================================ package main import ( "encoding/json" "fmt" "io" "log" "net/http" "os" "time" ) type userMapping struct { SFTPGoUsername string AuthyID int64 AuthyAPIKey string } type checkPasswordResponse struct { // 0 KO, 1 OK, 2 partial success Status int `json:"status"` // for status == 2 this is the password that SFTPGo will check against the one stored // inside the data provider ToVerify string `json:"to_verify"` } var ( mapping []userMapping ) func init() { // this is for demo only, you probably want to get this mapping dynamically, for example using a database query mapping = append(mapping, userMapping{ SFTPGoUsername: "", AuthyID: 1234567, AuthyAPIKey: "", }) } func printResponse(status int, toVerify string) { r := checkPasswordResponse{ Status: status, ToVerify: toVerify, } resp, _ := json.Marshal(r) fmt.Printf("%v\n", string(resp)) if status > 0 { os.Exit(0) } else { os.Exit(1) } } func main() { // get credentials from env vars username := os.Getenv("SFTPGO_AUTHD_USERNAME") password := os.Getenv("SFTPGO_AUTHD_PASSWORD") for _, m := range mapping { if m.SFTPGoUsername == username { // Authy token len is 7, we assume that we have the password followed by the token pwdLen := 
len(password) if pwdLen <= 7 { printResponse(0, "") } pwd := password[:pwdLen-7] authyToken := password[pwdLen-7:] // now verify the authy token and instruct SFTPGo to check the password if the token is OK url := fmt.Sprintf("https://api.authy.com/protected/json/verify/%v/%v", authyToken, m.AuthyID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { log.Fatal(err) } req.Header.Set("X-Authy-API-Key", m.AuthyAPIKey) httpClient := &http.Client{ Timeout: 10 * time.Second, } resp, err := httpClient.Do(req) if err != nil { printResponse(0, "") } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { // status code 200 is expected printResponse(0, "") } var authyResponse map[string]interface{} respBody, err := io.ReadAll(resp.Body) if err != nil { printResponse(0, "") } err = json.Unmarshal(respBody, &authyResponse) if err != nil { printResponse(0, "") } if authyResponse["success"].(string) == "true" { printResponse(2, pwd) } printResponse(0, "") break } } // no mapping found printResponse(0, "") } ================================================ FILE: examples/OTP/authy/extauth/README.md ================================================ # Authy external authentication This example shows how to use Authy TOTP token as password for SFTPGo users. Please read the [sample code](./main.go), it should be self explanatory. ================================================ FILE: examples/OTP/authy/extauth/go.mod ================================================ module github.com/drakkan/sftpgo/authy/extauth go 1.22.2 ================================================ FILE: examples/OTP/authy/extauth/main.go ================================================ package main import ( "encoding/json" "fmt" "io" "log" "net/http" "os" "path/filepath" "time" ) type userMapping struct { SFTPGoUsername string AuthyID int64 AuthyAPIKey string } // we assume that the SFTPGo already exists, we only check the one time token. 
// If you need to create the SFTPGo user more fields are needed here type minimalSFTPGoUser struct { Status int `json:"status,omitempty"` Username string `json:"username"` HomeDir string `json:"home_dir,omitempty"` Permissions map[string][]string `json:"permissions"` } var ( mapping []userMapping ) func init() { // this is for demo only, you probably want to get this mapping dynamically, for example using a database query mapping = append(mapping, userMapping{ SFTPGoUsername: "", AuthyID: 1234567, AuthyAPIKey: "", }) } func printResponse(username string) { u := minimalSFTPGoUser{ Username: username, Status: 1, HomeDir: filepath.Join(os.TempDir(), username), } u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{"*"} resp, _ := json.Marshal(u) fmt.Printf("%v\n", string(resp)) if len(username) > 0 { os.Exit(0) } else { os.Exit(1) } } func main() { // get credentials from env vars username := os.Getenv("SFTPGO_AUTHD_USERNAME") password := os.Getenv("SFTPGO_AUTHD_PASSWORD") if len(password) == 0 { // login method is not password printResponse("") return } for _, m := range mapping { if m.SFTPGoUsername == username { // mapping found we can now verify the token url := fmt.Sprintf("https://api.authy.com/protected/json/verify/%v/%v", password, m.AuthyID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { log.Fatal(err) } req.Header.Set("X-Authy-API-Key", m.AuthyAPIKey) httpClient := &http.Client{ Timeout: 10 * time.Second, } resp, err := httpClient.Do(req) if err != nil { printResponse("") } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { // status code 200 is expected printResponse("") } var authyResponse map[string]interface{} respBody, err := io.ReadAll(resp.Body) if err != nil { printResponse("") } err = json.Unmarshal(respBody, &authyResponse) if err != nil { printResponse("") } if authyResponse["success"].(string) == "true" { printResponse(username) } printResponse("") break } } // no mapping found 
printResponse("") } ================================================ FILE: examples/OTP/authy/keyint/README.md ================================================ # Authy 2FA using keyboard interactive authentication This example shows how to authenticate SFTP users using 2FA (password + Authy token). Please read the [sample code](./main.go), it should be self explanatory. ================================================ FILE: examples/OTP/authy/keyint/go.mod ================================================ module github.com/drakkan/sftpgo/authy/keyint go 1.22.2 ================================================ FILE: examples/OTP/authy/keyint/main.go ================================================ package main import ( "bufio" "encoding/json" "fmt" "io" "net/http" "os" "time" ) type userMapping struct { SFTPGoUsername string AuthyID int64 AuthyAPIKey string } type keyboardAuthHookResponse struct { Instruction string `json:"instruction,omitempty"` Questions []string `json:"questions,omitempty"` Echos []bool `json:"echos,omitempty"` AuthResult int `json:"auth_result"` CheckPwd int `json:"check_password,omitempty"` } var ( mapping []userMapping ) func init() { // this is for demo only, you probably want to get this mapping dynamically, for example using a database query mapping = append(mapping, userMapping{ SFTPGoUsername: "", AuthyID: 1234567, AuthyAPIKey: "", }) } func printAuthResponse(result int) { resp, _ := json.Marshal(keyboardAuthHookResponse{ AuthResult: result, }) fmt.Printf("%v\n", string(resp)) if result == 1 { os.Exit(0) } else { os.Exit(1) } } func main() { // get credentials from env vars username := os.Getenv("SFTPGO_AUTHD_USERNAME") var userMap userMapping for _, m := range mapping { if m.SFTPGoUsername == username { userMap = m break } } if userMap.SFTPGoUsername != username { // no mapping found os.Exit(1) } checkPwdQuestion := keyboardAuthHookResponse{ Instruction: "This is a sample keyboard authentication program that ask for your password + Authy 
token", Questions: []string{"Your password: "}, Echos: []bool{false}, CheckPwd: 1, AuthResult: 0, } q, _ := json.Marshal(checkPwdQuestion) fmt.Printf("%v\n", string(q)) // in a real world app you probably want to use a read timeout scanner := bufio.NewScanner(os.Stdin) scanner.Scan() if scanner.Err() != nil { printAuthResponse(-1) } response := scanner.Text() if response != "OK" { printAuthResponse(-1) } checkTokenQuestion := keyboardAuthHookResponse{ Instruction: "", Questions: []string{"Authy token: "}, Echos: []bool{false}, CheckPwd: 0, AuthResult: 0, } q, _ = json.Marshal(checkTokenQuestion) fmt.Printf("%v\n", string(q)) scanner.Scan() if scanner.Err() != nil { printAuthResponse(-1) } authyToken := scanner.Text() url := fmt.Sprintf("https://api.authy.com/protected/json/verify/%v/%v", authyToken, userMap.AuthyID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { printAuthResponse(-1) } req.Header.Set("X-Authy-API-Key", userMap.AuthyAPIKey) httpClient := &http.Client{ Timeout: 10 * time.Second, } resp, err := httpClient.Do(req) if err != nil { printAuthResponse(-1) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { // status code 200 is expected printAuthResponse(-1) } var authyResponse map[string]interface{} respBody, err := io.ReadAll(resp.Body) if err != nil { printAuthResponse(-1) } err = json.Unmarshal(respBody, &authyResponse) if err != nil { printAuthResponse(-1) } if authyResponse["success"].(string) == "true" { printAuthResponse(1) } printAuthResponse(-1) } ================================================ FILE: examples/backup/README.md ================================================ # Data Backup :warning: Since v2.4.0 you can use the [EventManager](https://docs.sftpgo.com/latest/eventmanager/) to schedule backups. The `backup` example script shows how to use the SFTPGo REST API to backup your data. 
The script is written in Python and has the following requirements: - python3 or python2 - python [Requests](https://requests.readthedocs.io/en/master/) module The provided example tries to connect to an SFTPGo instance running on `127.0.0.1:8080` using the following credentials: - username: `admin` - password: `password` and, if you execute it daily, it saves a different backup file for each day of the week. The backups will be saved within the configured `backups_path`. Please edit the script according to your needs. ================================================ FILE: examples/backup/backup ================================================ #!/usr/bin/env python from datetime import datetime import sys import requests try: import urllib.parse as urlparse except ImportError: import urlparse # change base_url to point to your SFTPGo installation base_url = "http://127.0.0.1:8080" # set to False if you want to skip TLS certificate validation verify_tls_cert = True # set the credentials for a valid admin here admin_user = "admin" admin_password = "password" # get a JWT token auth = requests.auth.HTTPBasicAuth(admin_user, admin_password) r = requests.get(urlparse.urljoin(base_url, "api/v2/token"), auth=auth, verify=verify_tls_cert, timeout=10) if r.status_code != 200: print("error getting access token: {}".format(r.text)) sys.exit(1) access_token = r.json()["access_token"] auth_header = {"Authorization": "Bearer " + access_token} r = requests.get(urlparse.urljoin(base_url, "api/v2/dumpdata"), params={"output-file":"backup_{}.json".format(datetime.today().strftime('%w'))}, headers=auth_header, verify=verify_tls_cert, timeout=10) if r.status_code == 200: print("backup OK") else: print("backup error, status {}, response: {}".format(r.status_code, r.text)) ================================================ FILE: examples/bulkupdate/README.md ================================================ # Bulk user update The `bulkuserupdate` example script shows how to use the SFTPGo 
REST API to easily update some common parameters for multiple users while preserving the others. The script is written in Python and has the following requirements: - python3 or python2 - python [Requests](https://requests.readthedocs.io/en/master/) module The provided example tries to connect to an SFTPGo instance running on `127.0.0.1:8080` using the following credentials: - username: `admin` - password: `password` and it updates some fields for `user1`, `user2` and `user3`. Please edit the script according to your needs. ================================================ FILE: examples/bulkupdate/bulkuserupdate ================================================ #!/usr/bin/env python import posixpath import sys import requests try: import urllib.parse as urlparse except ImportError: import urlparse # change base_url to point to your SFTPGo installation base_url = "http://127.0.0.1:8080" # set to False if you want to skip TLS certificate validation verify_tls_cert = True # set the credentials for a valid admin here admin_user = "admin" admin_password = "password" # insert here the users you want to update users_to_update = ["user1", "user2", "user3"] # set here the fields you need to update fields_to_update = {"status":0, "quota_files": 1000, "additional_info":"updated using the bulkuserupdate example script"} # get a JWT token auth = requests.auth.HTTPBasicAuth(admin_user, admin_password) r = requests.get(urlparse.urljoin(base_url, "api/v2/token"), auth=auth, verify=verify_tls_cert, timeout=10) if r.status_code != 200: print("error getting access token: {}".format(r.text)) sys.exit(1) access_token = r.json()["access_token"] auth_header = {"Authorization": "Bearer " + access_token} for username in users_to_update: r = requests.get(urlparse.urljoin(base_url, posixpath.join("api/v2/users", username)), headers=auth_header, verify=verify_tls_cert, timeout=10) if r.status_code != 200: print("error getting user {}: {}".format(username, r.text)) continue user = r.json() 
user.update(fields_to_update) r = requests.put(urlparse.urljoin(base_url, posixpath.join("api/v2/users", username)), headers=auth_header, verify=verify_tls_cert, json=user, timeout=10) if r.status_code == 200: print("user {} updated".format(username)) else: print("error updating user {}, response code: {} response text: {}".format(username, r.status_code, r.text)) ================================================ FILE: examples/convertusers/README.md ================================================ # Import users from other stores `convertusers` is a very simple command line client, written in python, to import users from other stores. It requires `python3` or `python2`. Here is the usage: ```console usage: convertusers [-h] [--min-uid MIN_UID] [--max-uid MAX_UID] [--usernames USERNAMES [USERNAMES ...]] [--force-uid FORCE_UID] [--force-gid FORCE_GID] input_file {unix-passwd,pure-ftpd,proftpd} output_file Convert users to a JSON format suitable to use with loadddata positional arguments: input_file {unix-passwd,pure-ftpd,proftpd} To import from unix-passwd format you need the permission to read /etc/shadow that is typically granted to the root user only output_file optional arguments: -h, --help show this help message and exit --min-uid MIN_UID if >= 0 only import users with UID greater or equal to this value. Default: -1 --max-uid MAX_UID if >= 0 only import users with UID lesser or equal to this value. Default: -1 --usernames USERNAMES [USERNAMES ...] Only import users with these usernames. Default: [] --force-uid FORCE_UID if >= 0 the imported users will have this UID in SFTPGo. Default: -1 --force-gid FORCE_GID if >= 0 the imported users will have this GID in SFTPGo. 
Default: -1 ``` Let's see some examples: ```console python convertusers "" unix-passwd unix_users.json --min-uid 500 --force-uid 1000 --force-gid 1000 ``` ```console python convertusers pureftpd.passwd pure-ftpd pure_users.json --usernames "user1" "user2" ``` ```console python convertusers proftpd.passwd proftpd pro_users.json ``` The generated json file can be used as input for the `loaddata` REST API. Please note that when importing Linux/Unix users the input file is not required: `/etc/passwd` and `/etc/shadow` are automatically parsed. `/etc/shadow` read permission is typically granted to the `root` user only, so you need to execute `convertusers` as `root`. :warning: SFTPGo does not currently support `yescrypt` hashed passwords. ================================================ FILE: examples/convertusers/convertusers ================================================ #!/usr/bin/env python import argparse import json import sys import time try: import pwd import spwd except ImportError: pwd = None class ConvertUsers: def __init__(self, input_file, users_format, output_file, min_uid, max_uid, usernames, force_uid, force_gid): self.input_file = input_file self.users_format = users_format self.output_file = output_file self.min_uid = min_uid self.max_uid = max_uid self.usernames = usernames self.force_uid = force_uid self.force_gid = force_gid self.SFTPGoUsers = [] def buildUserObject(self, username, password, home_dir, uid, gid, max_sessions, quota_size, quota_files, upload_bandwidth, download_bandwidth, status, expiration_date, allowed_ip=[], denied_ip=[]): return {'id':0, 'username':username, 'password':password, 'home_dir':home_dir, 'uid':uid, 'gid':gid, 'max_sessions':max_sessions, 'quota_size':quota_size, 'quota_files':quota_files, 'permissions':{'/':["*"]}, 'upload_bandwidth':upload_bandwidth, 'download_bandwidth':download_bandwidth, 'status':status, 'expiration_date':expiration_date, 'filters':{'allowed_ip':allowed_ip, 'denied_ip':denied_ip}} def 
addUser(self, user): user['id'] = len(self.SFTPGoUsers) + 1 print('') print('New user imported: {}'.format(user)) print('') self.SFTPGoUsers.append(user) def saveUsers(self): if self.SFTPGoUsers: data = {'users':self.SFTPGoUsers} jsonData = json.dumps(data) with open(self.output_file, 'w') as f: f.write(jsonData) print() print('Number of users saved to "{}": {}. You can import them using loaddata'.format(self.output_file, len(self.SFTPGoUsers))) print() sys.exit(0) else: print('No user imported') sys.exit(1) def convert(self): if self.users_format == 'unix-passwd': self.convertFromUnixPasswd() elif self.users_format == 'pure-ftpd': self.convertFromPureFTPD() else: self.convertFromProFTPD() self.saveUsers() def isUserValid(self, username, uid): if self.usernames and not username in self.usernames: return False if self.min_uid >= 0 and uid < self.min_uid: return False if self.max_uid >= 0 and uid > self.max_uid: return False return True def convertFromUnixPasswd(self): days_from_epoch_time = time.time() / 86400 for user in pwd.getpwall(): username = user.pw_name password = user.pw_passwd uid = user.pw_uid gid = user.pw_gid home_dir = user.pw_dir status = 1 expiration_date = 0 if not self.isUserValid(username, uid): continue if self.force_uid >= 0: uid = self.force_uid if self.force_gid >= 0: gid = self.force_gid # FIXME: if the passwords aren't in /etc/shadow they are probably DES encrypted and we don't support them if password == 'x' or password == '*': user_info = spwd.getspnam(username) password = user_info.sp_pwdp if not password or password == '!!' 
or password == '!*': print('cannot import user "{}" without a password'.format(username)) continue if user_info.sp_inact > 0: last_pwd_change_diff = days_from_epoch_time - user_info.sp_lstchg if last_pwd_change_diff > user_info.sp_inact: status = 0 if user_info.sp_expire > 0: expiration_date = user_info.sp_expire * 86400 self.addUser(self.buildUserObject(username, password, home_dir, uid, gid, 0, 0, 0, 0, 0, status, expiration_date)) def convertFromProFTPD(self): with open(self.input_file, 'r') as f: for line in f: fields = line.split(':') if len(fields) > 6: username = fields[0] password = fields[1] uid = int(fields[2]) gid = int(fields[3]) home_dir = fields[5] if not self.isUserValid(username, uid): continue if self.force_uid >= 0: uid = self.force_uid if self.force_gid >= 0: gid = self.force_gid self.addUser(self.buildUserObject(username, password, home_dir, uid, gid, 0, 0, 0, 0, 0, 1, 0)) def convertPureFTPDIP(self, fields): result = [] if not fields: return result for v in fields.split(','): ip_mask = v.strip() if not ip_mask: continue if ip_mask.count('.') < 3 and ip_mask.count(':') < 3: print('cannot import pure-ftpd IP: {}'.format(ip_mask)) continue if '/' not in ip_mask: ip_mask += '/32' result.append(ip_mask) return result def convertFromPureFTPD(self): with open(self.input_file, 'r') as f: for line in f: fields = line.split(':') if len(fields) > 16: username = fields[0] password = fields[1] uid = int(fields[2]) gid = int(fields[3]) home_dir = fields[5] upload_bandwidth = 0 if fields[6]: upload_bandwidth = int(int(fields[6]) / 1024) download_bandwidth = 0 if fields[7]: download_bandwidth = int(int(fields[7]) / 1024) max_sessions = 0 if fields[10]: max_sessions = int(fields[10]) quota_files = 0 if fields[11]: quota_files = int(fields[11]) quota_size = 0 if fields[12]: quota_size = int(fields[12]) allowed_ip = self.convertPureFTPDIP(fields[15]) denied_ip = self.convertPureFTPDIP(fields[16]) if not self.isUserValid(username, uid): continue if self.force_uid 
>= 0: uid = self.force_uid if self.force_gid >= 0: gid = self.force_gid self.addUser(self.buildUserObject(username, password, home_dir, uid, gid, max_sessions, quota_size, quota_files, upload_bandwidth, download_bandwidth, 1, 0, allowed_ip, denied_ip)) if __name__ == '__main__': parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description= 'Convert users to a JSON format suitable to use with loadddata') supportedUsersFormats = [] help_text = '' if pwd is not None: supportedUsersFormats.append('unix-passwd') help_text = 'To import from unix-passwd format you need the permission to read /etc/shadow that is typically granted to the root user only' supportedUsersFormats.append('pure-ftpd') supportedUsersFormats.append('proftpd') parser.add_argument('input_file', type=str) parser.add_argument('users_format', type=str, choices=supportedUsersFormats, help=help_text) parser.add_argument('output_file', type=str) parser.add_argument('--min-uid', type=int, default=-1, help='if >= 0 only import users with UID greater or equal ' + 'to this value. Default: %(default)s') parser.add_argument('--max-uid', type=int, default=-1, help='if >= 0 only import users with UID lesser or equal ' + 'to this value. Default: %(default)s') parser.add_argument('--usernames', type=str, nargs='+', default=[], help='Only import users with these usernames. ' + 'Default: %(default)s') parser.add_argument('--force-uid', type=int, default=-1, help='if >= 0 the imported users will have this UID in ' + 'SFTPGo. Default: %(default)s') parser.add_argument('--force-gid', type=int, default=-1, help='if >= 0 the imported users will have this GID in ' + 'SFTPGo. 
Default: %(default)s') args = parser.parse_args() convertUsers = ConvertUsers(args.input_file, args.users_format, args.output_file, args.min_uid, args.max_uid, args.usernames, args.force_uid, args.force_gid) convertUsers.convert() ================================================ FILE: examples/ldapauth/README.md ================================================ # LDAPAuth This is an example for an external authentication program. It performs authentication against an LDAP server. It is tested against [389ds](https://directory.fedoraproject.org/) and can be used as starting point to authenticate using any LDAP server including Active Directory. You need to change the LDAP connection parameters and the user search query to match your environment. You can build this example using the following command: ```console go build -ldflags "-s -w" -o ldapauth ``` This program assumes that the 389ds schema was extended to add support for public keys using the following ldif file placed in `/etc/dirsrv/schema/98openssh-ldap.ldif`: ```console dn: cn=schema changetype: modify add: attributetypes attributetypes: ( 1.3.6.1.4.1.24552.500.1.1.1.13 NAME 'sshPublicKey' DESC 'MANDATORY: OpenSSH Public key' EQUALITY octetStringMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.40 ) - add: objectclasses objectClasses: ( 1.3.6.1.4.1.24552.500.1.1.2.0 NAME 'ldapPublicKey' SUP top AUXILIARY DESC 'MANDATORY: OpenSSH LPK objectclass' MUST ( uid ) MAY ( sshPublicKey ) ) - dn: cn=sshpublickey,cn=default indexes,cn=config,cn=ldbm database,cn=plugins,cn=config changetype: add cn: sshpublickey nsIndexType: eq nsIndexType: pres nsSystemIndex: false objectClass: top objectClass: nsIndex dn: cn=sshpublickey_self_manage,ou=groups,dc=example,dc=com changetype: add objectClass: top objectClass: groupofuniquenames cn: sshpublickey_self_manage description: Members of this group gain the ability to edit their own sshPublicKey field dn: dc=example,dc=com changetype: modify add: aci aci: (targetattr = "sshPublicKey") 
(version 3.0; acl "Allow members of sshpublickey_self_manage to edit their keys"; allow(write) (groupdn = "ldap:///cn=sshpublickey_self_manage,ou=groups,dc=example,dc=com" and userdn="ldap:///self" ); ) - ``` :warning: A plugin for LDAP/Active Directory authentication is also [available](https://github.com/sftpgo/sftpgo-plugin-auth). ================================================ FILE: examples/ldapauth/go.mod ================================================ module github.com/drakkan/ldapauth go 1.25.0 require ( github.com/go-ldap/ldap/v3 v3.4.13 golang.org/x/crypto v0.49.0 ) require ( github.com/Azure/go-ntlmssp v0.1.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/google/uuid v1.6.0 // indirect golang.org/x/sys v0.42.0 // indirect ) ================================================ FILE: examples/ldapauth/go.sum ================================================ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ= github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= 
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: examples/ldapauth/main.go ================================================ package main import ( "bytes" "encoding/json" "fmt" "log" "log/syslog" "os" "strconv" "strings" "github.com/go-ldap/ldap/v3" "golang.org/x/crypto/ssh" ) const ( rootDN = "dc=example,dc=com" bindUsername = "cn=sftpgo," + rootDN bindURL = "ldap:///" // That is, the server on the default port of localhost. passwordFile = "/etc/sftpgo/admin-password.txt" // make this file readable only by the server publicDir = "/var/www/webdav/public" ) type userFilters struct { DeniedLoginMethods []string `json:"denied_login_methods,omitempty"` } type minimalSFTPGoUser struct { Status int `json:"status,omitempty"` Username string `json:"username"` HomeDir string `json:"home_dir,omitempty"` UID int `json:"uid,omitempty"` GID int `json:"gid,omitempty"` Permissions map[string][]string `json:"permissions"` Filters userFilters `json:"filters"` } func exitError() { log.Printf("exitError\n") u := minimalSFTPGoUser{ Username: "", } resp, _ := json.Marshal(u) fmt.Printf("%v\n", string(resp)) os.Exit(1) } func printSuccessResponse(username, homeDir string, uid, gid int, permissions []string) { u := minimalSFTPGoUser{ Username: username, HomeDir: homeDir, UID: uid, GID: gid, Status: 1, } u.Permissions = make(map[string][]string) u.Permissions["/"] = permissions // uncomment the next line to require publickey+password authentication //u.Filters.DeniedLoginMethods = []string{"publickey", "password", 
"keyboard-interactive", "publickey+keyboard-interactive"} resp, _ := json.Marshal(u) log.Printf("%v\n", string(resp)) fmt.Printf("%v\n", string(resp)) os.Exit(0) } func main() { logWriter, err := syslog.New(syslog.LOG_NOTICE, "sftpgo") if err == nil { log.SetOutput(logWriter) } // get credentials from env vars username := os.Getenv("SFTPGO_AUTHD_USERNAME") password := os.Getenv("SFTPGO_AUTHD_PASSWORD") publickey := os.Getenv("SFTPGO_AUTHD_PUBLIC_KEY") if strings.ToLower(username) == "anonymous" { printSuccessResponse("anonymous", publicDir, 0, 0, []string{"list", "download"}) return } l, err := ldap.DialURL(bindURL) if err != nil { log.Printf("DialURL: %s\n", err.Error()) exitError() } defer l.Close() // bind to the ldap server with an account that can read users bindPassword, err := os.ReadFile(passwordFile) if err != nil { log.Printf("ReadFile(%s): %s\n", passwordFile, err.Error()) exitError() } err = l.Bind(bindUsername, string(bindPassword)) if err != nil { log.Printf("Bind(%s): %s\n", bindUsername, err.Error()) exitError() } // search the user trying to login and fetch some attributes, this search string is tested against 389ds using the default configuration log.Printf("username=%s\n", username) searchFilter := fmt.Sprintf("(uid=%s)", ldap.EscapeFilter(username)) searchRequest := ldap.NewSearchRequest( "ou=people," + rootDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, searchFilter, []string{"dn", "uid", "homeDirectory", "uidNumber", "gidNumber", "nsSshPublicKey"}, nil, ) sr, err := l.Search(searchRequest) if err != nil { log.Printf("Search(%s): %s\n", searchFilter, err.Error()) exitError() } // we expect exactly one user if len(sr.Entries) != 1 { log.Printf("Search(%s): %d entries\n", searchFilter, len(sr.Entries)) exitError() } if len(publickey) > 0 { // check public key userKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publickey)) if err != nil { log.Printf("ParseAuthorizedKey(%s): %s\n", publickey, err.Error()) exitError() } authOk := 
false for _, k := range sr.Entries[0].GetAttributeValues("nsSshPublicKey") { key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) // we skip an invalid public key stored inside the LDAP server if err != nil { continue } if bytes.Equal(key.Marshal(), userKey.Marshal()) { authOk = true break } } if !authOk { log.Printf("publickey %s !authOk\n", publickey) exitError() } } else { // bind to the LDAP server with the user dn and the given password to check the password userdn := sr.Entries[0].DN // log.Printf("password=%s\n", password) err = l.Bind(userdn, password) if err != nil { log.Printf("Bind(%s): %s\n", userdn, err.Error()) exitError() } } // People in the LDAP directory aren't necessarily Linux users; // so they might not have a uidNumber or gidNumber. uidNumber := sr.Entries[0].GetAttributeValue("uidNumber") uid, err := strconv.Atoi(uidNumber) if err != nil { //log.Printf("uid Atoi(%s) = %s\n", uidNumber, err.Error()) uid = 0 } gidNumber := sr.Entries[0].GetAttributeValue("gidNumber") gid, err := strconv.Atoi(gidNumber) if err != nil { //log.Printf("gid Atoi(%s) = %s\n", gidNumber, err.Error()) gid = 0 } homeDir := sr.Entries[0].GetAttributeValue("homeDirectory") if (len(homeDir) <= 0) { homeDir = publicDir // homeDir is a required attribute. } // return the authenticated user printSuccessResponse(sr.Entries[0].GetAttributeValue("uid"), homeDir, uid, gid, []string{"*"}) } ================================================ FILE: examples/ldapauthserver/README.md ================================================ # LDAPAuthServer This is an example for an HTTP server to use as external authentication HTTP hook. It performs authentication against an LDAP server. It is tested against [389ds](https://directory.fedoraproject.org/) and can be used as starting point to authenticate using any LDAP server including Active Directory. You can configure the server using the [ldapauth.toml](./ldapauth.toml) configuration file. 
You can build this example using the following command: ```console go build -ldflags "-s -w" -o ldapauthserver ``` :warning: A plugin for LDAP/Active Directory authentication is also [available](https://github.com/sftpgo/sftpgo-plugin-auth). ================================================ FILE: examples/ldapauthserver/cmd/root.go ================================================ package cmd import ( "fmt" "os" "github.com/drakkan/sftpgo/ldapauthserver/config" "github.com/drakkan/sftpgo/ldapauthserver/utils" "github.com/spf13/cobra" "github.com/spf13/viper" ) const ( logSender = "cmd" configDirFlag = "config-dir" configDirKey = "config_dir" configFileFlag = "config-file" configFileKey = "config_file" logFilePathFlag = "log-file-path" logFilePathKey = "log_file_path" logMaxSizeFlag = "log-max-size" logMaxSizeKey = "log_max_size" logMaxBackupFlag = "log-max-backups" logMaxBackupKey = "log_max_backups" logMaxAgeFlag = "log-max-age" logMaxAgeKey = "log_max_age" logCompressFlag = "log-compress" logCompressKey = "log_compress" logVerboseFlag = "log-verbose" logVerboseKey = "log_verbose" profilerFlag = "profiler" profilerKey = "profiler" defaultConfigDir = "." defaultConfigName = config.DefaultConfigName defaultLogFile = "ldapauth.log" defaultLogMaxSize = 10 defaultLogMaxBackup = 5 defaultLogMaxAge = 28 defaultLogCompress = false defaultLogVerbose = true ) var ( configDir string configFile string logFilePath string logMaxSize int logMaxBackups int logMaxAge int logCompress bool logVerbose bool rootCmd = &cobra.Command{ Use: "ldapauthserver", Short: "LDAP Authentication Server for SFTPGo", } ) func init() { version := utils.GetAppVersion() rootCmd.Flags().BoolP("version", "v", false, "") rootCmd.Version = version.GetVersionAsString() rootCmd.SetVersionTemplate(`{{printf "LDAP Authentication Server version: "}}{{printf "%s" .Version}} `) } // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). 
It only needs to happen once to the rootCmd. func Execute() { if err := rootCmd.Execute(); err != nil { fmt.Println(err) os.Exit(1) } } func addConfigFlags(cmd *cobra.Command) { viper.SetDefault(configDirKey, defaultConfigDir) viper.BindEnv(configDirKey, "LDAPAUTH_CONFIG_DIR") cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey), `Location for the config dir. This directory should contain the "ldapauth" configuration file or the configured config-file. This flag can be set using LDAPAUTH_CONFIG_DIR env var too. `) viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) viper.SetDefault(configFileKey, defaultConfigName) viper.BindEnv(configFileKey, "LDAPAUTH_CONFIG_FILE") cmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey), `Name for the configuration file. It must be the name of a file stored in config-dir not the absolute path to the configuration file. The specified file name must have no extension we automatically load JSON, YAML, TOML, HCL and Java properties. Therefore if you set \"ldapauth\" then \"ldapauth.toml\", \"ldapauth.yaml\" and so on are searched. This flag can be set using LDAPAUTH_CONFIG_FILE env var too. `) viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) } func addServeFlags(cmd *cobra.Command) { addConfigFlags(cmd) viper.SetDefault(logFilePathKey, defaultLogFile) viper.BindEnv(logFilePathKey, "LDAPAUTH_LOG_FILE_PATH") cmd.Flags().StringVarP(&logFilePath, logFilePathFlag, "l", viper.GetString(logFilePathKey), `Location for the log file. Leave empty to write logs to the standard output. This flag can be set using LDAPAUTH_LOG_FILE_PATH env var too. 
`) viper.BindPFlag(logFilePathKey, cmd.Flags().Lookup(logFilePathFlag)) viper.SetDefault(logMaxSizeKey, defaultLogMaxSize) viper.BindEnv(logMaxSizeKey, "LDAPAUTH_LOG_MAX_SIZE") cmd.Flags().IntVarP(&logMaxSize, logMaxSizeFlag, "s", viper.GetInt(logMaxSizeKey), `Maximum size in megabytes of the log file before it gets rotated. This flag can be set using LDAPAUTH_LOG_MAX_SIZE env var too. It is unused if log-file-path is empty.`) viper.BindPFlag(logMaxSizeKey, cmd.Flags().Lookup(logMaxSizeFlag)) viper.SetDefault(logMaxBackupKey, defaultLogMaxBackup) viper.BindEnv(logMaxBackupKey, "LDAPAUTH_LOG_MAX_BACKUPS") cmd.Flags().IntVarP(&logMaxBackups, "log-max-backups", "b", viper.GetInt(logMaxBackupKey), `Maximum number of old log files to retain. This flag can be set using LDAPAUTH_LOG_MAX_BACKUPS env var too. It is unused if log-file-path is empty.`) viper.BindPFlag(logMaxBackupKey, cmd.Flags().Lookup(logMaxBackupFlag)) viper.SetDefault(logMaxAgeKey, defaultLogMaxAge) viper.BindEnv(logMaxAgeKey, "LDAPAUTH_LOG_MAX_AGE") cmd.Flags().IntVarP(&logMaxAge, "log-max-age", "a", viper.GetInt(logMaxAgeKey), `Maximum number of days to retain old log files. This flag can be set using LDAPAUTH_LOG_MAX_AGE env var too. It is unused if log-file-path is empty.`) viper.BindPFlag(logMaxAgeKey, cmd.Flags().Lookup(logMaxAgeFlag)) viper.SetDefault(logCompressKey, defaultLogCompress) viper.BindEnv(logCompressKey, "LDAPAUTH_LOG_COMPRESS") cmd.Flags().BoolVarP(&logCompress, logCompressFlag, "z", viper.GetBool(logCompressKey), `Determine if the rotated log files should be compressed using gzip. This flag can be set using LDAPAUTH_LOG_COMPRESS env var too. It is unused if log-file-path is empty.`) viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) viper.SetDefault(logVerboseKey, defaultLogVerbose) viper.BindEnv(logVerboseKey, "LDAPAUTH_LOG_VERBOSE") cmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey), `Enable verbose logs. 
This flag can be set using LDAPAUTH_LOG_VERBOSE env var too. `) viper.BindPFlag(logVerboseKey, cmd.Flags().Lookup(logVerboseFlag)) } ================================================ FILE: examples/ldapauthserver/cmd/serve.go ================================================ package cmd import ( "path/filepath" "github.com/drakkan/sftpgo/ldapauthserver/config" "github.com/drakkan/sftpgo/ldapauthserver/httpd" "github.com/drakkan/sftpgo/ldapauthserver/logger" "github.com/drakkan/sftpgo/ldapauthserver/utils" "github.com/rs/zerolog" "github.com/spf13/cobra" ) var ( serveCmd = &cobra.Command{ Use: "serve", Short: "Start the LDAP Authentication Server", Long: `To start the server with the default values for the command line flags simply use: ldapauthserver serve Please take a look at the usage below to customize the startup options`, Run: func(cmd *cobra.Command, args []string) { startServer() }, } ) func init() { rootCmd.AddCommand(serveCmd) addServeFlags(serveCmd) } func startServer() error { logLevel := zerolog.DebugLevel if !logVerbose { logLevel = zerolog.InfoLevel } if !filepath.IsAbs(logFilePath) && utils.IsFileInputValid(logFilePath) { logFilePath = filepath.Join(configDir, logFilePath) } logger.InitLogger(logFilePath, logMaxSize, logMaxBackups, logMaxAge, logCompress, logLevel) version := utils.GetAppVersion() logger.Info(logSender, "", "starting LDAP Auth Server %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ "log max age: %v log verbose: %v, log compress: %v", version.GetVersionAsString(), configDir, configFile, logMaxSize, logMaxBackups, logMaxAge, logVerbose, logCompress) config.LoadConfig(configDir, configFile) return httpd.StartHTTPServer(configDir, config.GetHTTPDConfig()) } ================================================ FILE: examples/ldapauthserver/config/config.go ================================================ package config import ( "strings" "github.com/drakkan/sftpgo/ldapauthserver/logger" "github.com/spf13/viper" ) 
const ( logSender = "config" // DefaultConfigName defines the name for the default config file. // This is the file name without extension, we use viper and so we // support all the config files format supported by viper DefaultConfigName = "ldapauth" // ConfigEnvPrefix defines a prefix that ENVIRONMENT variables will use configEnvPrefix = "ldapauth" ) // HTTPDConfig defines configuration for the HTTPD server type HTTPDConfig struct { BindAddress string `mapstructure:"bind_address"` BindPort int `mapstructure:"bind_port"` AuthUserFile string `mapstructure:"auth_user_file"` CertificateFile string `mapstructure:"certificate_file"` CertificateKeyFile string `mapstructure:"certificate_key_file"` } // LDAPConfig defines the configuration parameters for LDAP connections and searches type LDAPConfig struct { BaseDN string `mapstructure:"basedn"` BindURL string `mapstructure:"bind_url"` BindUsername string `mapstructure:"bind_username"` BindPassword string `mapstructure:"bind_password"` SearchFilter string `mapstructure:"search_filter"` SearchBaseAttrs []string `mapstructure:"search_base_attrs"` DefaultUID int `mapstructure:"default_uid"` DefaultGID int `mapstructure:"default_gid"` ForceDefaultUID bool `mapstructure:"force_default_uid"` ForceDefaultGID bool `mapstructure:"force_default_gid"` InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` CACertificates []string `mapstructure:"ca_certificates"` } type appConfig struct { HTTPD HTTPDConfig `mapstructure:"httpd"` LDAP LDAPConfig `mapstructure:"ldap"` } var conf appConfig func init() { conf = appConfig{ HTTPD: HTTPDConfig{ BindAddress: "", BindPort: 9000, AuthUserFile: "", CertificateFile: "", CertificateKeyFile: "", }, LDAP: LDAPConfig{ BaseDN: "dc=example,dc=com", BindURL: "ldap://192.168.1.103:389", BindUsername: "cn=Directory Manager", BindPassword: "YOUR_ADMIN_PASSWORD_HERE", SearchFilter: "(&(objectClass=nsPerson)(uid=%s))", SearchBaseAttrs: []string{ "dn", "homeDirectory", "uidNumber", "gidNumber", 
"nsSshPublicKey", }, DefaultUID: 0, DefaultGID: 0, ForceDefaultUID: true, ForceDefaultGID: true, InsecureSkipVerify: false, CACertificates: nil, }, } viper.SetEnvPrefix(configEnvPrefix) replacer := strings.NewReplacer(".", "__") viper.SetEnvKeyReplacer(replacer) viper.SetConfigName(DefaultConfigName) viper.AutomaticEnv() viper.AllowEmptyEnv(true) } // GetHomeDirectory returns the configured name for the LDAP field to use as home directory func (l *LDAPConfig) GetHomeDirectory() string { if len(l.SearchBaseAttrs) > 1 { return l.SearchBaseAttrs[1] } return "homeDirectory" } // GetUIDNumber returns the configured name for the LDAP field to use as UID func (l *LDAPConfig) GetUIDNumber() string { if len(l.SearchBaseAttrs) > 2 { return l.SearchBaseAttrs[2] } return "uidNumber" } // GetGIDNumber returns the configured name for the LDAP field to use as GID func (l *LDAPConfig) GetGIDNumber() string { if len(l.SearchBaseAttrs) > 3 { return l.SearchBaseAttrs[3] } return "gidNumber" } // GetPublicKey returns the configured name for the LDAP field to use as public keys func (l *LDAPConfig) GetPublicKey() string { if len(l.SearchBaseAttrs) > 4 { return l.SearchBaseAttrs[4] } return "nsSshPublicKey" } // GetHTTPDConfig returns the configuration for the HTTP server func GetHTTPDConfig() HTTPDConfig { return conf.HTTPD } // GetLDAPConfig returns LDAP related settings func GetLDAPConfig() LDAPConfig { return conf.LDAP } func getRedactedConf() appConfig { c := conf return c } // LoadConfig loads the configuration func LoadConfig(configDir, configName string) error { var err error viper.AddConfigPath(configDir) viper.AddConfigPath(".") viper.SetConfigName(configName) if err = viper.ReadInConfig(); err != nil { logger.Warn(logSender, "", "error loading configuration file: %v. Default configuration will be used: %+v", err, getRedactedConf()) logger.WarnToConsole("error loading configuration file: %v. 
Default configuration will be used.", err) return err } err = viper.Unmarshal(&conf) if err != nil { logger.Warn(logSender, "", "error parsing configuration file: %v. Default configuration will be used: %+v", err, getRedactedConf()) logger.WarnToConsole("error parsing configuration file: %v. Default configuration will be used.", err) return err } logger.Debug(logSender, "", "config file used: '%q', config loaded: %+v", viper.ConfigFileUsed(), getRedactedConf()) return err } ================================================ FILE: examples/ldapauthserver/go.mod ================================================ module github.com/drakkan/sftpgo/ldapauthserver go 1.25.0 require ( github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/render v1.0.3 github.com/go-ldap/ldap/v3 v3.4.13 github.com/nathanaelle/password/v2 v2.0.1 github.com/rs/zerolog v1.34.0 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 golang.org/x/crypto v0.49.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) require ( github.com/Azure/go-ntlmssp v0.1.0 // indirect github.com/ajg/form v1.7.1 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) ================================================ FILE: 
examples/ldapauthserver/go.sum ================================================ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajg/form v1.7.1 h1:OsnBDzTkrWdrxvEnO68I72ZVGJGNaMwPhoAm0V+llgc= github.com/ajg/form v1.7.1/go.mod h1:HL757PzLyNkj5AIfptT6L+iGNeXTlnrr/oDePGc/y7Q= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod 
h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ= github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= github.com/jcmturner/gokrb5/v8 v8.4.4 
h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/nathanaelle/password/v2 v2.0.1 h1:ItoCTdsuIWzilYmllQPa3DR3YoCXcpfxScWLqr8Ii2s= github.com/nathanaelle/password/v2 v2.0.1/go.mod h1:eaoT+ICQEPNtikBRIAatN8ThWwMhVG+r1jTw60BvPJk= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod 
h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: examples/ldapauthserver/httpd/auth.go ================================================ package httpd import ( "encoding/csv" "errors" "fmt" "net/http" "os" "sync" unixcrypt "github.com/nathanaelle/password/v2" "github.com/drakkan/sftpgo/ldapauthserver/logger" "github.com/drakkan/sftpgo/ldapauthserver/utils" "golang.org/x/crypto/bcrypt" ) const ( authenticationHeader = "WWW-Authenticate" authenticationRealm = "LDAP Auth Server" unauthResponse = "Unauthorized" ) var ( md5CryptPwdPrefixes = []string{"$1$", "$apr1$"} bcryptPwdPrefixes = []string{"$2a$", "$2$", "$2x$", "$2y$", "$2b$"} ) type httpAuthProvider interface { getHashedPassword(username string) (string, bool) isEnabled() bool } type basicAuthProvider struct { Path string sync.RWMutex Info os.FileInfo Users map[string]string } func newBasicAuthProvider(authUserFile string) (httpAuthProvider, error) { basicAuthProvider := basicAuthProvider{ Path: authUserFile, Info: nil, Users: make(map[string]string), } return &basicAuthProvider, basicAuthProvider.loadUsers() } func (p *basicAuthProvider) isEnabled() bool { return len(p.Path) > 0 } func (p *basicAuthProvider) isReloadNeeded(info os.FileInfo) bool { p.RLock() defer p.RUnlock() return p.Info == nil || p.Info.ModTime() != info.ModTime() || p.Info.Size() != info.Size() } func (p *basicAuthProvider) loadUsers() error { if !p.isEnabled() { return nil } info, err := os.Stat(p.Path) if err != nil { logger.Debug(logSender, "", "unable to stat basic auth users file: %v", err) return err } if p.isReloadNeeded(info) { r, err := os.Open(p.Path) 
if err != nil { logger.Debug(logSender, "", "unable to open basic auth users file: %v", err) return err } defer r.Close() reader := csv.NewReader(r) reader.Comma = ':' reader.Comment = '#' reader.TrimLeadingSpace = true records, err := reader.ReadAll() if err != nil { logger.Debug(logSender, "", "unable to parse basic auth users file: %v", err) return err } p.Lock() defer p.Unlock() p.Users = make(map[string]string) for _, record := range records { if len(record) == 2 { p.Users[record[0]] = record[1] } } logger.Debug(logSender, "", "number of users loaded for httpd basic auth: %v", len(p.Users)) p.Info = info } return nil } func (p *basicAuthProvider) getHashedPassword(username string) (string, bool) { err := p.loadUsers() if err != nil { return "", false } p.RLock() defer p.RUnlock() pwd, ok := p.Users[username] return pwd, ok } func checkAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !validateCredentials(r) { w.Header().Set(authenticationHeader, fmt.Sprintf("Basic realm=\"%v\"", authenticationRealm)) sendAPIResponse(w, r, errors.New(unauthResponse), "", http.StatusUnauthorized) return } next.ServeHTTP(w, r) }) } func validateCredentials(r *http.Request) bool { if !httpAuth.isEnabled() { return true } username, password, ok := r.BasicAuth() if !ok { return false } if hashedPwd, ok := httpAuth.getHashedPassword(username); ok { if utils.IsStringPrefixInSlice(hashedPwd, bcryptPwdPrefixes) { err := bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(password)) return err == nil } if utils.IsStringPrefixInSlice(hashedPwd, md5CryptPwdPrefixes) { crypter, ok := unixcrypt.MD5.CrypterFound(hashedPwd) if !ok { err := errors.New("cannot found matching MD5 crypter") logger.Debug(logSender, "", "error comparing password with MD5 crypt hash: %v", err) return false } return crypter.Verify([]byte(password)) } } return false } ================================================ FILE: 
examples/ldapauthserver/httpd/httpd.go ================================================ package httpd import ( "context" "crypto/tls" "crypto/x509" "fmt" "net/http" "os" "path/filepath" "time" "github.com/drakkan/sftpgo/ldapauthserver/config" "github.com/drakkan/sftpgo/ldapauthserver/logger" "github.com/drakkan/sftpgo/ldapauthserver/utils" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" ) const ( logSender = "httpd" versionPath = "/api/v1/version" checkAuthPath = "/api/v1/check_auth" maxRequestSize = 1 << 18 // 256KB ) var ( ldapConfig config.LDAPConfig httpAuth httpAuthProvider certMgr *certManager rootCAs *x509.CertPool ) // StartHTTPServer initializes and starts the HTTP Server func StartHTTPServer(configDir string, httpConfig config.HTTPDConfig) error { var err error authUserFile := getConfigPath(httpConfig.AuthUserFile, configDir) httpAuth, err = newBasicAuthProvider(authUserFile) if err != nil { return err } router := chi.NewRouter() router.Use(middleware.RequestID) router.Use(middleware.RealIP) router.Use(logger.NewStructuredLogger(logger.GetLogger())) router.Use(middleware.Recoverer) router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) })) router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed) })) router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, utils.GetAppVersion()) }) router.Group(func(router chi.Router) { router.Use(checkAuth) router.Post(checkAuthPath, checkSFTPGoUserAuth) }) ldapConfig = config.GetLDAPConfig() loadCACerts(configDir) certificateFile := getConfigPath(httpConfig.CertificateFile, configDir) certificateKeyFile := getConfigPath(httpConfig.CertificateKeyFile, configDir) httpServer := &http.Server{ Addr: fmt.Sprintf("%s:%d", httpConfig.BindAddress, 
httpConfig.BindPort), Handler: router, ReadTimeout: 70 * time.Second, WriteTimeout: 70 * time.Second, IdleTimeout: 120 * time.Second, MaxHeaderBytes: 1 << 16, // 64KB } if len(certificateFile) > 0 && len(certificateKeyFile) > 0 { certMgr, err = newCertManager(certificateFile, certificateKeyFile) if err != nil { return err } config := &tls.Config{ GetCertificate: certMgr.GetCertificateFunc(), MinVersion: tls.VersionTLS12, } httpServer.TLSConfig = config return httpServer.ListenAndServeTLS("", "") } return httpServer.ListenAndServe() } func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { var errorString string if err != nil { errorString = err.Error() } resp := apiResponse{ Error: errorString, Message: message, HTTPStatus: code, } ctx := context.WithValue(r.Context(), render.StatusCtxKey, code) render.JSON(w, r.WithContext(ctx), resp) } func loadCACerts(configDir string) error { var err error rootCAs, err = x509.SystemCertPool() if err != nil { rootCAs = x509.NewCertPool() } for _, ca := range ldapConfig.CACertificates { caPath := getConfigPath(ca, configDir) certs, err := os.ReadFile(caPath) if err != nil { logger.Warn(logSender, "", "error loading ca cert %q: %v", caPath, err) return err } if !rootCAs.AppendCertsFromPEM(certs) { logger.Warn(logSender, "", "unable to add ca cert %q", caPath) } else { logger.Debug(logSender, "", "ca cert %q added to the trusted certificates", caPath) } } return nil } // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths func ReloadTLSCertificate() { if certMgr != nil { certMgr.loadCertificate() } } func getConfigPath(name, configDir string) string { if !utils.IsFileInputValid(name) { return "" } if len(name) > 0 && !filepath.IsAbs(name) { return filepath.Join(configDir, name) } return name } ================================================ FILE: examples/ldapauthserver/httpd/ldapauth.go ================================================ package httpd import 
( "bytes" "crypto/tls" "fmt" "net/http" "strconv" "strings" "github.com/drakkan/sftpgo/ldapauthserver/logger" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/go-ldap/ldap/v3" "golang.org/x/crypto/ssh" ) func getSFTPGoUser(entry *ldap.Entry, username string) (SFTPGoUser, error) { var err error var user SFTPGoUser uid := ldapConfig.DefaultUID gid := ldapConfig.DefaultGID status := 1 if !ldapConfig.ForceDefaultUID { uid, err = strconv.Atoi(entry.GetAttributeValue(ldapConfig.GetUIDNumber())) if err != nil { return user, err } } if !ldapConfig.ForceDefaultGID { uid, err = strconv.Atoi(entry.GetAttributeValue(ldapConfig.GetGIDNumber())) if err != nil { return user, err } } sftpgoUser := SFTPGoUser{ Username: username, HomeDir: entry.GetAttributeValue(ldapConfig.GetHomeDirectory()), UID: uid, GID: gid, Status: status, } sftpgoUser.Permissions = make(map[string][]string) sftpgoUser.Permissions["/"] = []string{"*"} return sftpgoUser, nil } func checkSFTPGoUserAuth(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) var authReq externalAuthRequest err := render.DecodeJSON(r.Body, &authReq) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "error decoding auth request: %v", err) sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } l, err := ldap.DialURL(ldapConfig.BindURL, ldap.DialWithTLSConfig(&tls.Config{ InsecureSkipVerify: ldapConfig.InsecureSkipVerify, RootCAs: rootCAs, })) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "error connecting to the LDAP server: %v", err) sendAPIResponse(w, r, err, "Error connecting to the LDAP server", http.StatusInternalServerError) return } defer l.Close() err = l.Bind(ldapConfig.BindUsername, ldapConfig.BindPassword) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "error binding to the LDAP server: %v", err) sendAPIResponse(w, r, err, "Error binding to the LDAP server", 
http.StatusInternalServerError) return } searchRequest := ldap.NewSearchRequest( ldapConfig.BaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, strings.Replace(ldapConfig.SearchFilter, "%s", ldap.EscapeFilter(authReq.Username), 1), ldapConfig.SearchBaseAttrs, nil, ) sr, err := l.Search(searchRequest) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "error searching LDAP user %q: %v", authReq.Username, err) sendAPIResponse(w, r, err, "Error searching LDAP user", http.StatusInternalServerError) return } if len(sr.Entries) != 1 { logger.Warn(logSender, middleware.GetReqID(r.Context()), "expected one user, found: %v", len(sr.Entries)) sendAPIResponse(w, r, nil, fmt.Sprintf("Expected one user, found: %v", len(sr.Entries)), http.StatusNotFound) return } if len(authReq.PublicKey) > 0 { userKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(authReq.PublicKey)) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "invalid public key for user %q: %v", authReq.Username, err) sendAPIResponse(w, r, err, "Invalid public key", http.StatusBadRequest) return } authOk := false for _, k := range sr.Entries[0].GetAttributeValues(ldapConfig.GetPublicKey()) { key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) // we skip an invalid public key stored inside the LDAP server if err != nil { continue } if bytes.Equal(key.Marshal(), userKey.Marshal()) { authOk = true break } } if !authOk { logger.Warn(logSender, middleware.GetReqID(r.Context()), "public key authentication failed for user: %q", authReq.Username) sendAPIResponse(w, r, nil, "public key authentication failed", http.StatusForbidden) return } } else { // bind to the LDAP server with the user dn and the given password to check the password userdn := sr.Entries[0].DN err = l.Bind(userdn, authReq.Password) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "password authentication failed for user: %q", authReq.Username) sendAPIResponse(w, r, nil, 
"password authentication failed", http.StatusForbidden) return } } user, err := getSFTPGoUser(sr.Entries[0], authReq.Username) if err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "get user from LDAP entry failed for username %q: %v", authReq.Username, err) sendAPIResponse(w, r, err, "mapping LDAP user failed", http.StatusInternalServerError) return } render.JSON(w, r, user) } ================================================ FILE: examples/ldapauthserver/httpd/models.go ================================================ package httpd type apiResponse struct { Error string `json:"error"` Message string `json:"message"` HTTPStatus int `json:"status"` } type externalAuthRequest struct { Username string `json:"username"` Password string `json:"password"` PublicKey string `json:"public_key"` } // SFTPGoExtensionsFilter defines filters based on file extensions type SFTPGoExtensionsFilter struct { Path string `json:"path"` AllowedExtensions []string `json:"allowed_extensions,omitempty"` DeniedExtensions []string `json:"denied_extensions,omitempty"` } // SFTPGoUserFilters defines additional restrictions for an SFTPGo user type SFTPGoUserFilters struct { AllowedIP []string `json:"allowed_ip,omitempty"` DeniedIP []string `json:"denied_ip,omitempty"` DeniedLoginMethods []string `json:"denied_login_methods,omitempty"` FileExtensions []SFTPGoExtensionsFilter `json:"file_extensions,omitempty"` } // S3FsConfig defines the configuration for S3 based filesystem type S3FsConfig struct { Bucket string `json:"bucket,omitempty"` KeyPrefix string `json:"key_prefix,omitempty"` Region string `json:"region,omitempty"` AccessKey string `json:"access_key,omitempty"` AccessSecret string `json:"access_secret,omitempty"` Endpoint string `json:"endpoint,omitempty"` StorageClass string `json:"storage_class,omitempty"` UploadPartSize int64 `json:"upload_part_size,omitempty"` UploadConcurrency int `json:"upload_concurrency,omitempty"` } // GCSFsConfig defines the configuration for 
Google Cloud Storage based filesystem type GCSFsConfig struct { Bucket string `json:"bucket,omitempty"` KeyPrefix string `json:"key_prefix,omitempty"` Credentials string `json:"credentials,omitempty"` AutomaticCredentials int `json:"automatic_credentials,omitempty"` StorageClass string `json:"storage_class,omitempty"` } // SFTPGoFilesystem defines cloud storage filesystem details type SFTPGoFilesystem struct { // 0 local filesystem, 1 AWS S3 compatible, 2 Google Cloud Storage Provider int `json:"provider"` S3Config S3FsConfig `json:"s3config,omitempty"` GCSConfig GCSFsConfig `json:"gcsconfig,omitempty"` } type virtualFolder struct { VirtualPath string `json:"virtual_path"` MappedPath string `json:"mapped_path"` } // SFTPGoUser defines an SFTPGo user type SFTPGoUser struct { // Database unique identifier ID int64 `json:"id"` // 1 enabled, 0 disabled (login is not allowed) Status int `json:"status"` // Username Username string `json:"username"` // Account expiration date as unix timestamp in milliseconds. An expired account cannot login. // 0 means no expiration ExpirationDate int64 `json:"expiration_date"` Password string `json:"password,omitempty"` PublicKeys []string `json:"public_keys,omitempty"` HomeDir string `json:"home_dir"` // Mapping between virtual paths and filesystem paths outside the home directory. Supported for local filesystem only VirtualFolders []virtualFolder `json:"virtual_folders,omitempty"` // If sftpgo runs as root system user then the created files and directories will be assigned to this system UID UID int `json:"uid"` // If sftpgo runs as root system user then the created files and directories will be assigned to this system GID GID int `json:"gid"` // Maximum concurrent sessions. 0 means unlimited MaxSessions int `json:"max_sessions"` // Maximum size allowed as bytes. 0 means unlimited QuotaSize int64 `json:"quota_size"` // Maximum number of files allowed. 
0 means unlimited QuotaFiles int `json:"quota_files"` // List of the granted permissions Permissions map[string][]string `json:"permissions"` // Used quota as bytes UsedQuotaSize int64 `json:"used_quota_size"` // Used quota as number of files UsedQuotaFiles int `json:"used_quota_files"` // Last quota update as unix timestamp in milliseconds LastQuotaUpdate int64 `json:"last_quota_update"` // Maximum upload bandwidth as KB/s, 0 means unlimited UploadBandwidth int64 `json:"upload_bandwidth"` // Maximum download bandwidth as KB/s, 0 means unlimited DownloadBandwidth int64 `json:"download_bandwidth"` // Last login as unix timestamp in milliseconds LastLogin int64 `json:"last_login"` // Additional restrictions Filters SFTPGoUserFilters `json:"filters"` // Filesystem configuration details FsConfig SFTPGoFilesystem `json:"filesystem"` } ================================================ FILE: examples/ldapauthserver/httpd/tlsutils.go ================================================ package httpd import ( "crypto/tls" "sync" "github.com/drakkan/sftpgo/ldapauthserver/logger" ) type certManager struct { certPath string keyPath string sync.RWMutex cert *tls.Certificate } func (m *certManager) loadCertificate() error { newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath) if err != nil { logger.Warn(logSender, "", "unable to load https certificate: %v", err) return err } logger.Debug(logSender, "", "https certificate successfully loaded") m.Lock() defer m.Unlock() m.cert = &newCert return nil } func (m *certManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { m.RLock() defer m.RUnlock() return m.cert, nil } } func newCertManager(certificateFile, certificateKeyFile string) (*certManager, error) { manager := &certManager{ cert: nil, certPath: certificateFile, keyPath: certificateKeyFile, } err := manager.loadCertificate() if err != nil { return nil, err } return manager, 
nil } ================================================ FILE: examples/ldapauthserver/ldapauth.toml ================================================ [httpd] bind_address = "" bind_port = 9000 # Path to a file used to store usernames and passwords for basic authentication. It can be generated using the Apache htpasswd tool auth_user_file = "" # If both the certificate and the private key are provided, the server will expect HTTPS connections certificate_file = "" certificate_key_file = "" [ldap] basedn = "dc=example,dc=com" bind_url = "ldap://127.0.0.1:389" bind_username = "cn=Directory Manager" bind_password = "YOUR_ADMIN_PASSWORD_HERE" search_filter = "(&(objectClass=nsPerson)(uid=%s))" # you can change the name of the search base attributes to adapt them to your schema but the order must remain the same search_base_attrs = [ "dn", "homeDirectory", "uidNumber", "gidNumber", "nsSshPublicKey" ] default_uid = 0 default_gid = 0 force_default_uid = true force_default_gid = true # if true, ldaps accepts any certificate presented by the LDAP server and any host name in that certificate. 
# This should be used only for testing insecure_skip_verify = false # list of root CA to use for ldaps connections # If you use a self signed certificate is better to add the root CA to this list than set insecure_skip_verify to true ca_certificates = [] ================================================ FILE: examples/ldapauthserver/logger/logger.go ================================================ package logger import ( "fmt" "os" "path/filepath" "runtime" "github.com/rs/zerolog" lumberjack "gopkg.in/natefinch/lumberjack.v2" ) const ( dateFormat = "2006-01-02T15:04:05.000" // YYYY-MM-DDTHH:MM:SS.ZZZ ) var ( logger zerolog.Logger consoleLogger zerolog.Logger ) // GetLogger get the configured logger instance func GetLogger() *zerolog.Logger { return &logger } // InitLogger initialize loggers func InitLogger(logFilePath string, logMaxSize, logMaxBackups, logMaxAge int, logCompress bool, level zerolog.Level) { zerolog.TimeFieldFormat = dateFormat if isLogFilePathValid(logFilePath) { logger = zerolog.New(&lumberjack.Logger{ Filename: logFilePath, MaxSize: logMaxSize, MaxBackups: logMaxBackups, MaxAge: logMaxAge, Compress: logCompress, }) EnableConsoleLogger(level) } else { logger = zerolog.New(&logSyncWrapper{ output: os.Stdout, }) consoleLogger = zerolog.Nop() } logger.Level(level) } // DisableLogger disable the main logger. // ConsoleLogger will not be affected func DisableLogger() { logger = zerolog.Nop() } // EnableConsoleLogger enables the console logger func EnableConsoleLogger(level zerolog.Level) { consoleOutput := zerolog.ConsoleWriter{ Out: os.Stdout, TimeFormat: dateFormat, NoColor: runtime.GOOS == "windows", } consoleLogger = zerolog.New(consoleOutput).With().Timestamp().Logger().Level(level) } // Debug logs at debug level for the specified sender func Debug(prefix, requestID string, format string, v ...interface{}) { logger.Debug(). Timestamp(). Str("sender", prefix). Str("request_id", requestID). 
Msg(fmt.Sprintf(format, v...)) } // Info logs at info level for the specified sender func Info(prefix, requestID string, format string, v ...interface{}) { logger.Info(). Timestamp(). Str("sender", prefix). Str("request_id", requestID). Msg(fmt.Sprintf(format, v...)) } // Warn logs at warn level for the specified sender func Warn(prefix, requestID string, format string, v ...interface{}) { logger.Warn(). Timestamp(). Str("sender", prefix). Str("request_id", requestID). Msg(fmt.Sprintf(format, v...)) } // Error logs at error level for the specified sender func Error(prefix, requestID string, format string, v ...interface{}) { logger.Error(). Timestamp(). Str("sender", prefix). Str("request_id", requestID). Msg(fmt.Sprintf(format, v...)) } // DebugToConsole logs at debug level to stdout func DebugToConsole(format string, v ...interface{}) { consoleLogger.Debug().Msg(fmt.Sprintf(format, v...)) } // InfoToConsole logs at info level to stdout func InfoToConsole(format string, v ...interface{}) { consoleLogger.Info().Msg(fmt.Sprintf(format, v...)) } // WarnToConsole logs at info level to stdout func WarnToConsole(format string, v ...interface{}) { consoleLogger.Warn().Msg(fmt.Sprintf(format, v...)) } // ErrorToConsole logs at error level to stdout func ErrorToConsole(format string, v ...interface{}) { consoleLogger.Error().Msg(fmt.Sprintf(format, v...)) } func isLogFilePathValid(logFilePath string) bool { cleanInput := filepath.Clean(logFilePath) if cleanInput == "." || cleanInput == ".." { return false } return true } ================================================ FILE: examples/ldapauthserver/logger/request_logger.go ================================================ package logger import ( "fmt" "net/http" "time" "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" ) // StructuredLogger defines a simple wrapper around zerolog logger. 
// It implements chi.middleware.LogFormatter interface type StructuredLogger struct { Logger *zerolog.Logger } // StructuredLoggerEntry ... type StructuredLoggerEntry struct { Logger *zerolog.Logger fields map[string]interface{} } // NewStructuredLogger returns a chi.middleware.RequestLogger using our StructuredLogger. // This structured logger is called by the chi.middleware.Logger handler to log each HTTP request func NewStructuredLogger(logger *zerolog.Logger) func(next http.Handler) http.Handler { return middleware.RequestLogger(&StructuredLogger{logger}) } // NewLogEntry creates a new log entry for an HTTP request func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { scheme := "http" if r.TLS != nil { scheme = "https" } fields := map[string]interface{}{ "remote_addr": r.RemoteAddr, "proto": r.Proto, "method": r.Method, "user_agent": r.UserAgent(), "uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)} reqID := middleware.GetReqID(r.Context()) if reqID != "" { fields["request_id"] = reqID } return &StructuredLoggerEntry{Logger: l.Logger, fields: fields} } // Write logs a new entry at the end of the HTTP request func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { l.Logger.Info(). Timestamp(). Str("sender", "httpd"). Fields(l.fields). Int("resp_status", status). Int("resp_size", bytes). Int64("elapsed_ms", elapsed.Nanoseconds()/1000000). Send() } // Panic logs panics func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) { l.Logger.Error(). Timestamp(). Str("sender", "httpd"). Fields(l.fields). Str("stack", string(stack)). Str("panic", fmt.Sprintf("%+v", v)). 
Send() } ================================================ FILE: examples/ldapauthserver/logger/sync_wrapper.go ================================================ package logger import ( "os" "sync" ) type logSyncWrapper struct { sync.Mutex output *os.File } func (l *logSyncWrapper) Write(b []byte) (n int, err error) { l.Lock() defer l.Unlock() return l.output.Write(b) } ================================================ FILE: examples/ldapauthserver/main.go ================================================ package main import "github.com/drakkan/sftpgo/ldapauthserver/cmd" func main() { cmd.Execute() } ================================================ FILE: examples/ldapauthserver/utils/utils.go ================================================ package utils import ( "path/filepath" "strings" ) // IsFileInputValid returns true this is a valid file name. // This method must be used before joining a file name, generally provided as // user input, with a directory func IsFileInputValid(fileInput string) bool { cleanInput := filepath.Clean(fileInput) if cleanInput == "." || cleanInput == ".." 
{ return false } return true } // IsStringPrefixInSlice searches a string prefix in a slice and returns true // if a matching prefix is found func IsStringPrefixInSlice(obj string, list []string) bool { for _, v := range list { if strings.HasPrefix(obj, v) { return true } } return false } ================================================ FILE: examples/ldapauthserver/utils/version.go ================================================ package utils const version = "0.1.0-dev" var ( commit = "" date = "" versionInfo VersionInfo ) // VersionInfo defines version details type VersionInfo struct { Version string `json:"version"` BuildDate string `json:"build_date"` CommitHash string `json:"commit_hash"` } func init() { versionInfo = VersionInfo{ Version: version, CommitHash: commit, BuildDate: date, } } // GetVersionAsString returns the string representation of the VersionInfo struct func (v *VersionInfo) GetVersionAsString() string { versionString := v.Version if len(v.CommitHash) > 0 { versionString += "-" + v.CommitHash } if len(v.BuildDate) > 0 { versionString += "-" + v.BuildDate } return versionString } // GetAppVersion returns VersionInfo struct func GetAppVersion() VersionInfo { return versionInfo } ================================================ FILE: examples/php-activedirectory-http-server/README.md ================================================ # SFTPGo on Windows with Active Directory Integration + Caddy Static File Server Example [![SFTPGo on Windows with Active Directory Integration + Caddy Static File Server Example](https://img.youtube.com/vi/M5UcJI8t4AI/0.jpg)](https://www.youtube.com/watch?v=M5UcJI8t4AI) This is similar to the ldapauthserver example, but is more specific to using Active Directory along with using SFTPGo on a Windows Server. 
The YouTube walkthrough/tutorial video above goes into considerably more detail. In short, it walks through setting up SFTPGo on a new Windows Server and enabling SFTPGo's External Authentication feature together with my `sftpgo-ldap-http-server` project, so that users can authenticate to SFTPGo through one or more Active Directory connections. It also covers using the Caddy web server to serve static files, in case that is of interest to you.

To get started, download the latest release ZIP package from the [sftpgo-ldap-http-server repository](https://github.com/orware/sftpgo-ldap-http-server). The ZIP contains:

- the `sftpgo-ldap-http-server.exe` file;
- an `OpenLDAP` folder, mainly to help if you want to use TLS for your LDAP connections;
- a `Data` folder, which contains a logs folder, a configuration.example.php file, a functions.php file, and the LICENSE and README files.

The video above goes through the whole process. To get started, install SFTPGo on your server, then extract the `sftpgo-ldap-http-server` ZIP file into a separate folder on the same server. Next, copy the configuration.example.php file, name the copy `configuration.php`, and begin customizing the settings (e.g. add your own LDAP settings and choose how your folders should be created). At the very minimum, make sure the home directories are set to match how you want the folders to be created in your environment. You don't have to use the virtual folders, or really any of the other functionality, if you don't need it.
Once configured, open a command prompt window in the folder where you extracted the `sftpgo-ldap-http-server` ZIP and run `sftpgo-ldap-http-server.exe`. It should start a simple HTTP server on port 9001 on localhost (the port can also be changed in the `configuration.php` file). Then set SFTPGo's `external_auth_hook` option to `http://localhost:9001/`, and you should be able to run some authentication tests, assuming all of your settings are correct and there are no intermediate issues.

The video above also covers several troubleshooting situations you might run into. It is long (about 1 hour and 42 minutes), but reviewing it may help you avoid some issues and learn a bit more about SFTPGo and the integration described here.

## Example Virtual Folders Configuration (Allowing for Both a Public and Private Folder)

The following can be used if you'd like to assign your users both a Private Virtual Folder and a Public Virtual Folder. Keep in mind that the Public Virtual Folder is not public by itself. Only by combining it with the Caddy web server (and the example Caddyfile configuration further below) can you make the `F:\files\public` folder from the example publicly accessible.
```php $virtual_folders['example'] = [ [ //"id" => 0, "name" => "private-#USERNAME#", "mapped_path" => 'F:\files\private\#USERNAME#', //"used_quota_size" => 0, //"used_quota_files" => 0, //"last_quota_update" => 0, "virtual_path" => "/_private", "quota_size" => -1, "quota_files" => -1 ], [ //"id" => 0, "name" => "public-#USERNAME#", "mapped_path" => 'F:\files\public\#USERNAME#', //"used_quota_size" => 0, //"used_quota_files" => 0, //"last_quota_update" => 0, "virtual_path" => "/_public", "quota_size" => -1, "quota_files" => -1 ] ]; ``` ## Example Connection "Output Object" Allowing For No Files in the User's Home Directory ("Root Directory") but Allowing for Files in the Public/Private Virtual Folders The magic here happens in the "permissions" value, by limiting the root/home directory to just the list/download permissions, and then allowing all permissions on the Public/Private virtual folders. ```php $connection_output_objects['example'] = [ 'status' => 1, 'username' => '', 'expiration_date' => 0, 'home_dir' => '', 'uid' => 0, 'gid' => 0, 'max_sessions' => 0, 'quota_size' => 0, 'quota_files' => 100000, 'permissions' => [ "/" => ["list", "download"], "/_private" => ["*"], "/_public" => ["*"], ], 'upload_bandwidth' => 0, 'download_bandwidth' => 0, 'filters' => [ 'allowed_ip' => [], 'denied_ip' => [], ], 'public_keys' => [], ]; ``` ## Recommended Usage of Automatic Groups Mode (Limiting by Group Prefix) The `sftpgo-ldap-http-server` project is able to automatically create virtual folders for any groups your user is a memberof if the automatic mode is turned on. However, by having a specific set of allowed prefixes defined, you can limit things to just those groups that begin with the prefixes you've listed, which can be helpful. The prefix itself will be removed from the group name when added as a virtual folder for the user. 
```php // If automatic groups mode is disabled, then you have to manually add the allowed groups into $allowed_groups down below: // If enabled, then any groups you are a memberof will automatically be added in using the template below. $auto_groups_mode = true; $auto_groups_mode_virtual_folder_template = [ [ //"id" => 0, "name" => "groups-#GROUP#", "mapped_path" => 'F:\files\groups\#GROUP#', //"used_quota_size" => 0, //"used_quota_files" => 0, //"last_quota_update" => 0, "virtual_path" => "/groups/#GROUP#", "quota_size" => 0, "quota_files" => 100000 ] ]; // Used only when auto groups mode is enabled and will help prevent all your groups from being // added into SFTPGo since only groups with the prefixes defined here will be automatically added // with prefixes automatically removed when listed as a virtual folder (e.g. a group with name // "sftpgo-example" would simply become "example"). $allowed_group_prefixes = [ 'sftpgo-' ]; ``` ## Example Caddyfile Configuration You Can Adapt for Your Needs ```shell ### Re-usable snippets: (add_static_file_serving_features) { # Allow accessing files without requiring .html: try_files {path} {path}.html # Enable Static File Server and Directory Browsing: file_server browse # Enable templating functionality: templates # Enable Compression for Output: encode zstd gzip handle_errors { respond "
{http.error.status_code} {http.error.status_text}
" } } (add_hsts_headers) { header { # Enable HTTP Strict Transport Security (HSTS) to force clients to always # connect via HTTPS (do not use if only testing) Strict-Transport-Security "max-age=31536000; includeSubDomains" # Enable cross-site filter (XSS) and tell browser to block detected attacks X-XSS-Protection "1; mode=block" # Prevent some browsers from MIME-sniffing a response away from the declared Content-Type X-Content-Type-Options "nosniff" # Disallow the site to be rendered within a frame (clickjacking protection) X-Frame-Options "DENY" # keep referrer data off of HTTP connections Referrer-Policy no-referrer-when-downgrade } } (add_logging_with_path) { log { output file "{args.0}" { roll_size 100mb roll_keep 5 roll_keep_for 720h } format json #format console #format single_field common_log } } ### Site Definitions: public.example.com { # Site Root: root * F:\files\public import add_logging_with_path "F:\caddy\logs\public_example_com_access.log" import add_static_file_serving_features import add_hsts_headers } ### Reverse Proxy Definitions: webdav.example.com { reverse_proxy localhost:9000 import add_logging_with_path "F:\caddy\logs\webdav_example_com_access.log" } ``` ================================================ FILE: examples/quotascan/README.md ================================================ # Update user quota :warning: Since v2.4.0 you can use the [EventManager](https://docs.sftpgo.com/latest/eventmanager/) to schedule quota scans. The `scanuserquota` example script shows how to use the SFTPGo REST API to update the users' quota. The stored quota may be incorrect for several reasons, such as an unexpected shutdown while uploading files, temporary provider failures, files copied outside of SFTPGo, and so on. A quota scan updates the number of files and their total size for the specified user and the virtual folders, if any, included in his quota. If you want to track quotas, a scheduled quota scan is recommended. 
You can use this example as a starting point.

The script is written in Python and has the following requirements:

- python3 or python2
- python [Requests](https://requests.readthedocs.io/en/master/) module
- python [pytz](https://pypi.org/project/pytz/) module (imported by the script to handle token expiration times)

The provided example tries to connect to an SFTPGo instance running on `127.0.0.1:8080` using the following credentials:

- username: `admin`
- password: `password`

Please edit the script according to your needs.

================================================
FILE: examples/quotascan/scanuserquota
================================================
#!/usr/bin/env python
# Starts an SFTPGo quota scan, one user at a time, for every user selected by
# needQuotaUpdate, using the SFTPGo REST API (api/v2).
from datetime import datetime
import sys
import time

import pytz
import requests

try:
    import urllib.parse as urlparse
except ImportError:
    # python 2 fallback
    import urlparse

# change base_url to point to your SFTPGo installation
base_url = "http://127.0.0.1:8080"
# set to False if you want to skip TLS certificate validation
verify_tls_cert = True
# set the credentials for a valid admin here
admin_user = "admin"
admin_password = "password"


# set your update conditions here
def needQuotaUpdate(user):
    """Return True if a quota scan should be started for the given user.

    user is one element of the list returned by the api/v2/users endpoint.
    Inactive users (status 0) and users without quota restrictions
    (quota_size and quota_files both 0) are skipped.
    """
    if user["status"] == 0:
        # inactive user
        return False
    if user["quota_size"] == 0 and user["quota_files"] == 0:
        # no quota restrictions
        return False
    return True


class UpdateQuota:
    """Pages through the SFTPGo users (limit/offset) and runs the quota scans."""

    def __init__(self):
        # page size and current offset used when listing users
        self.limit = 100
        self.offset = 0
        # cached JWT access token and its (UTC-aware) expiration time
        self.access_token = ""
        self.access_token_expiration = None

    def printLog(self, message):
        """Print message to stdout prefixed with the current local time."""
        print("{} - {}".format(datetime.now(), message))

    def checkAccessToken(self):
        """Make sure self.access_token is usable.

        The cached token is reused while it expires in more than 180 seconds;
        otherwise a new one is requested from api/v2/token using HTTP basic
        auth. The process exits with status 1 if the request does not return
        HTTP 200.
        """
        if self.access_token != "" and self.access_token_expiration:
            expire_diff = self.access_token_expiration - datetime.now(tz=pytz.UTC)
            # we don't use total_seconds to be python 2 compatible
            seconds_to_expire = expire_diff.days * 86400 + expire_diff.seconds
            if seconds_to_expire > 180:
                return
        auth = requests.auth.HTTPBasicAuth(admin_user, admin_password)
        r = requests.get(urlparse.urljoin(base_url, "api/v2/token"), auth=auth, verify=verify_tls_cert, timeout=10)
        if r.status_code != 200:
            self.printLog("error getting access token: {}".format(r.text))
            sys.exit(1)
        self.access_token = r.json()["access_token"]
        # expires_at is parsed as a naive "%Y-%m-%dT%H:%M:%SZ" timestamp and marked as UTC
        self.access_token_expiration = pytz.timezone("UTC").localize(datetime.strptime(r.json()["expires_at"],
                                                                                      "%Y-%m-%dT%H:%M:%SZ"))

    def getAuthHeader(self):
        """Return the Bearer authorization header, refreshing the token if needed."""
        self.checkAccessToken()
        return {"Authorization": "Bearer " + self.access_token}

    def waitForQuotaUpdate(self, username):
        """Poll the active user quota scans every 2 seconds until none is running for username.

        Exits the process with status 1 if listing the scans fails.
        """
        while True:
            auth_header = self.getAuthHeader()
            r = requests.get(urlparse.urljoin(base_url, "api/v2/quotas/users/scans"), headers=auth_header,
                             verify=verify_tls_cert, timeout=10)
            if r.status_code != 200:
                self.printLog("error getting quota scans while waiting for {}: {}".format(username, r.text))
                sys.exit(1)

            scanning = False
            for scan in r.json():
                if scan["username"] == username:
                    scanning = True
            if not scanning:
                break
            self.printLog("waiting for the quota scan to complete for user {}".format(username))
            time.sleep(2)

        self.printLog("quota update for user {} finished".format(username))

    def updateUserQuota(self, username):
        """Start a quota scan for username and wait for it to finish.

        The scan request must be answered with HTTP 202, otherwise the
        process exits with status 1.
        """
        self.printLog("starting quota update for user {}".format(username))
        auth_header = self.getAuthHeader()
        r = requests.post(urlparse.urljoin(base_url, "api/v2/quotas/users/" + username + "/scan"), headers=auth_header,
                          verify=verify_tls_cert, timeout=10)
        if r.status_code != 202:
            self.printLog("error starting quota scan for user {}: {}".format(username, r.text))
            sys.exit(1)
        self.waitForQuotaUpdate(username)

    def updateUsersQuota(self):
        """Update the quota of every user selected by needQuotaUpdate.

        Users are fetched in pages of self.limit starting at self.offset; the
        loop stops when a page returns fewer than self.limit users. Exits the
        process with status 1 if listing the users fails.
        """
        while True:
            self.printLog("get users, limit {} offset {}".format(self.limit, self.offset))
            auth_header = self.getAuthHeader()
            payload = {"limit":self.limit, "offset":self.offset}
            r = requests.get(urlparse.urljoin(base_url, "api/v2/users"), headers=auth_header, params=payload,
                             verify=verify_tls_cert, timeout=10)
            if r.status_code != 200:
                self.printLog("error getting users: {}".format(r.text))
                sys.exit(1)
            users = r.json()
            for user in users:
                if needQuotaUpdate(user):
                    self.updateUserQuota(user["username"])
                else:
                    self.printLog("user {} does not need a quota update".format(user["username"]))

            self.offset += len(users)
            if len(users) < self.limit:
                break


if __name__ == '__main__':
    q = UpdateQuota()
    q.updateUsersQuota()

================================================
FILE: go.mod
================================================
module github.com/drakkan/sftpgo/v2

go 1.25.0

require (
	cloud.google.com/go/storage v1.60.0
	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0
	github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1
	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4
	github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5
	github.com/alexedwards/argon2id v1.0.0
	github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964
	github.com/aws/aws-sdk-go-v2 v1.41.4
	github.com/aws/aws-sdk-go-v2/config v1.32.12
	github.com/aws/aws-sdk-go-v2/credentials v1.19.12
	github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1
	github.com/aws/aws-sdk-go-v2/service/sts v1.41.9
	github.com/bmatcuk/doublestar/v4 v4.10.0
	github.com/cockroachdb/cockroach-go/v2 v2.4.3
	github.com/coreos/go-oidc/v3 v3.17.0
	github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b
	github.com/eikenb/pipeat v0.0.0-20251030185646-385cd3c3e07b
	github.com/fclairamb/ftpserverlib v0.30.0
	github.com/go-acme/lego/v4 v4.33.0
	github.com/go-chi/chi/v5 v5.2.5
	github.com/go-chi/render v1.0.3
	github.com/go-jose/go-jose/v4 v4.1.3
	github.com/go-sql-driver/mysql v1.9.3
	github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
	github.com/google/uuid v1.6.0
	github.com/hashicorp/go-hclog v1.6.3
	github.com/hashicorp/go-plugin v1.7.0
	github.com/hashicorp/go-retryablehttp v0.7.8
	github.com/jackc/pgx/v5 v5.8.0
	github.com/jlaffaye/ftp v0.2.0
	github.com/klauspost/compress v1.18.4
	github.com/lithammer/shortuuid/v4 v4.2.0
	github.com/mattn/go-sqlite3 v1.14.37
	github.com/mhale/smtpd v0.8.3
	github.com/minio/sio v0.4.3
	github.com/otiai10/copy v1.14.1
	github.com/pires/go-proxyproto v0.11.0
	github.com/pkg/sftp v1.13.10
	github.com/pquerna/otp v1.5.0
	github.com/prometheus/client_golang v1.23.2
	github.com/robfig/cron/v3
v3.0.1 github.com/rs/cors v1.11.1 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/sftpgo/sdk v0.1.9 github.com/shirou/gopsutil/v3 v3.24.5 github.com/spf13/afero v1.15.0 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/studio-b12/gowebdav v0.12.0 github.com/subosito/gotenv v1.6.0 github.com/unrolled/secure v1.17.0 github.com/wagslane/go-password-validator v0.3.0 github.com/wneessen/go-mail v0.7.2 github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a go.etcd.io/bbolt v1.4.3 gocloud.dev v0.45.0 golang.org/x/crypto v0.49.0 golang.org/x/net v0.52.0 golang.org/x/oauth2 v0.36.0 golang.org/x/sys v0.42.0 golang.org/x/term v0.41.0 golang.org/x/time v0.15.0 google.golang.org/api v0.272.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) require ( cel.dev/expr v0.25.1 // indirect cloud.google.com/go v0.123.0 // indirect cloud.google.com/go/auth v0.18.2 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect cloud.google.com/go/monitoring v1.24.3 // indirect filippo.io/edwards25519 v1.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.7.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/ajg/form v1.7.1 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect github.com/aws/aws-sdk-go-v2/internal/ini 
v1.8.6 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect github.com/aws/smithy-go v1.24.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.1.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect github.com/googleapis/gax-go/v2 v2.19.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/yamux v0.1.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect 
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kr/fs v0.1.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.72 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/oklog/run v1.2.0 // indirect github.com/otiai10/mint v1.6.3 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.20.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/shoenig/go-m1cpu v0.2.1 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/numcpus v0.11.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.42.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect go.opentelemetry.io/otel v1.42.0 // indirect go.opentelemetry.io/otel/metric v1.42.0 // indirect go.opentelemetry.io/otel/sdk 
v1.42.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect go.opentelemetry.io/otel/trace v1.42.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/mod v0.34.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/tools v0.43.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto v0.0.0-20260319171110-e3a33c96fb44 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260319171110-e3a33c96fb44 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260319171110-e3a33c96fb44 // indirect google.golang.org/grpc v1.79.3 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) replace ( github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 ) ================================================ FILE: go.sum ================================================ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM= cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/iam v1.5.3 
h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= cloud.google.com/go/kms v1.26.0 h1:cK9mN2cf+9V63D3H1f6koxTatWy39aTI/hCjz1I+adU= cloud.google.com/go/kms v1.26.0/go.mod h1:pHKOdFJm63hxBsiPkYtowZPltu9dW0MWvBa6IA4HM58= cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA= cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak= cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= cloud.google.com/go/storage v1.60.0 h1:oBfZrSOCimggVNz9Y/bXY35uUcts7OViubeddTTVzQ8= cloud.google.com/go/storage v1.60.0/go.mod h1:q+5196hXfejkctrnx+VYU8RKQr/L3c0cBIlrjmiAKE0= cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U= cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= 
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 h1:jWQK1GI+LeGGUKBADtcH2rRqPxYB1Ljwms5gFA2LqrM= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4/go.mod h1:8mwH4klAm9DUgR2EEHyEEAQlRDvLPyg5fQry3y+cDew= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.7.0 h1:4iB+IesclUXdP0ICgAabvq2FYLXrJWKx1fJQ+GxSo3Y= github.com/AzureAD/microsoft-authentication-library-for-go v1.7.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 h1:IEjq88XO4PuBDcvmjQJcQGg+w+UaafSy8G5Kcb5tBhI= github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5/go.mod h1:exZ0C/1emQJAw5tHOaUDyY1ycttqBAPcxuzf7QbY6ec= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8= 
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajg/form v1.7.1 h1:OsnBDzTkrWdrxvEnO68I72ZVGJGNaMwPhoAm0V+llgc= github.com/ajg/form v1.7.1/go.mod h1:HL757PzLyNkj5AIfptT6L+iGNeXTlnrr/oDePGc/y7Q= github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w= github.com/alexedwards/argon2id v1.0.0/go.mod h1:tYKkqIjzXvZdzPvADMWOEZ+l6+BD6CtBXMj5fnJppiw= github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 h1:I9YN9WMo3SUh7p/4wKeNvD/IQla3U3SUa61U7ul+xM4= github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964/go.mod h1:eFiR01PwTcpbzXtdMces7zxg6utvFM5puiWHpWB8D/k= github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0= github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g= github.com/aws/aws-sdk-go-v2/credentials v1.19.12 
h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8= github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 h1:SwGMTMLIlvDNyhMteQ6r8IJSBPlRdXX5d4idhIGbkXA= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21/go.mod h1:UUxgWxofmOdAMuqEsSppbDtGKLfR04HGsD0HXzvhI1k= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 h1:qtJZ70afD3ISKWnoX3xB0J2otEqu3LqicRcDBqsj0hQ= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12/go.mod h1:v2pNpJbRNl4vEUWEh5ytQok0zACAKfdmKS51Hotc3pQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 
h1:siU1A6xjUZ2N8zjTHSXFhB9L/2OY8Dqs0xXiLjF30jA= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20/go.mod h1:4TLZCmVJDM3FOu5P5TJP0zOlu9zWgDWU7aUxWbr+rcw= github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 h1:csi9NLpFZXb9fxY7rS1xVzgPRGMt7MSNWeQ6eo247kE= github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1/go.mod h1:qXVal5H0ChqXP63t6jze5LmFalc7+ZE7wOdLtZ0LCP0= github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow= github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE= github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o= github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA= github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU= github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bmatcuk/doublestar/v4 v4.10.0 h1:zU9WiOla1YA122oLM6i4EXvGW62DvKZVxIe6TYWexEs= github.com/bmatcuk/doublestar/v4 v4.10.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= github.com/boombuler/barcode v1.1.0/go.mod 
h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik= github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4= github.com/cockroachdb/cockroach-go/v2 v2.4.3 h1:LJO3K3jC5WXvMePRQSJE1NsIGoFGcEx1LW83W6RAlhw= github.com/cockroachdb/cockroach-go/v2 v2.4.3/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0= github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/drakkan/cron/v3 
v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U= github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb40YLyVlt0bcIFtYrvnanV3zc= github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE= github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b h1:Y1tLiQ8fnxM5f3wiBjAXsHzHNwiY9BR+mXZA75nZwrs= github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= github.com/eikenb/pipeat v0.0.0-20251030185646-385cd3c3e07b h1:G2Mm3YhlyjkFrNnvu5E6LtNcPJtggXL1i5ekDV4hDD4= github.com/eikenb/pipeat v0.0.0-20251030185646-385cd3c3e07b/go.mod h1:XccPiThW83W5pzeOCsJAylEUtWeH+3zQVwiO402FXXc= github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU= github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ= github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A= github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds= github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fclairamb/ftpserverlib v0.30.0 
h1:caB9sDn1Au//q0j2ev/icPn388qPuk4k1ajSvglDcMQ= github.com/fclairamb/ftpserverlib v0.30.0/go.mod h1:QmogtltTOgkihyKza0GNo37Mu4AEzbJ+sH6W9Y0MBIQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-acme/lego/v4 v4.33.0 h1:2KrRKieG+VczT4zvVXKAY7Tp/S3XLvh/QImofBALRAM= github.com/go-acme/lego/v4 v4.33.0/go.mod h1:lI2fZNdgeM/ymf9xQ9YKbgZm6MeDuf91UrohMQE4DhI= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod 
h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= 
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18= github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE= github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-plugin v1.7.0 h1:YghfQH/0QmPNc/AZMTFE3ac8fipZyZECHdDPshfk+mA= github.com/hashicorp/go-plugin v1.7.0/go.mod h1:BExt6KEaIYx804z8k4gRzRLEvxKVb+kn0NMcihqOqb8= github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod 
h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94= github.com/jhump/protoreflect v1.17.0/go.mod h1:h9+vUUL38jiBzck8ck+6G/aeMX8Z4QUY/NiJPwPNi+8= github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 
h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y= github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 h1:PTw+yKnXcOFCR6+8hHTyWBeQ/P4Nb7dd4/0ohEcWQuM= github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.37 h1:3DOZp4cXis1cUIpCfXLtmlGolNLp2VEqhiB/PARNBIg= github.com/mattn/go-sqlite3 v1.14.37/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mhale/smtpd v0.8.3 h1:8j8YNXajksoSLZja3HdwvYVZPuJSqAxFsib3adzRRt8= github.com/mhale/smtpd v0.8.3/go.mod h1:MQl+y2hwIEQCXtNhe5+55n0GZOjSmeqORDIXbqUL3x4= github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= github.com/miekg/dns v1.1.72/go.mod 
h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/minio/sio v0.4.3 h1:JqyID1XM86KwBZox5RAdLD4MLPIDoCY2cke2CXCJCkg= github.com/minio/sio v0.4.3/go.mod h1:4ANoe4CCXqnt1FCiLM0+vlBUhhWZzVOhYCz0069KtFc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/oklog/run v1.2.0 h1:O8x3yXwah4A73hJdlrwo/2X6J62gE5qTMusH0dvz60E= github.com/oklog/run v1.2.0/go.mod h1:mgDbKRSwPhJfesJ4PntqFUbKQRZ50NgmZTSPlFA0YFk= github.com/otiai10/copy v1.14.1 h1:5/7E6qsUMBaH5AnQ0sSLzzTg1oTECmcCmT6lvF45Na8= github.com/otiai10/copy v1.14.1/go.mod h1:oQwrEDDOci3IM8dJF0d8+jnbfPDllW6vUjNc3DoZm9I= github.com/otiai10/mint v1.6.3 h1:87qsV/aw1F5as1eH1zS/yqHY85ANKVMgkDrf9rcxbQs= github.com/otiai10/mint v1.6.3/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4= github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= 
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc= github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 
h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo= github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4/go.mod h1:MnkX001NG75g3p8bhFycnyIjeQoOjGL6CEIsdE/nKSY= github.com/sftpgo/sdk v0.1.9 h1:onBWfibCt34xHeKC2KFYPZ1DBqXGl9um/cAw+AVdgzY= github.com/sftpgo/sdk v0.1.9/go.mod h1:ehimvlTP+XTEiE3t1CPwWx9n7+6A6OGvMGlZ7ouvKFk= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= github.com/shoenig/go-m1cpu v0.2.1 h1:yqRB4fvOge2+FyRXFkXqsyMoqPazv14Yyy+iyccT2E4= github.com/shoenig/go-m1cpu v0.2.1/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w= github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk= github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/studio-b12/gowebdav v0.12.0 h1:kFRtQECt8jmVAvA6RHBz3geXUGJHUZA6/IKpOVUs5kM= github.com/studio-b12/gowebdav v0.12.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA= github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= github.com/unrolled/secure v1.17.0/go.mod 
h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I= github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ= github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8= github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k= github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a h1:XfF01GyP+0eWCaVp0y6rNN+kFp7pt9Da4UUYrJ5XPWA= github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a/go.mod h1:aXb8yZQEWo1XHGMf1qQfnb83GR/EJ2EBlwtUgAaNBoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/detectors/gcp v1.42.0 h1:kpt2PEJuOuqYkPcktfJqWWDjTEd/FNgrxcniL7kQrXQ= go.opentelemetry.io/contrib/detectors/gcp v1.42.0/go.mod h1:W9zQ439utxymRrXsUOzZbFX4JhLxXU4+ZnCt8GG7yA8= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= go.opentelemetry.io/otel v1.42.0 
h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 h1:5gn2urDL/FBnK8OkCfD1j3/ER79rUuTYmCvlXBKeYL8= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0/go.mod h1:0fBG6ZJxhqByfFZDwSwpZGzJU671HkwpWaNe2t4VUPI= go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ= go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= gocloud.dev v0.45.0 h1:WknIK8IbRdmynDvara3Q7G6wQhmEiOGwpgJufbM39sY= gocloud.dev v0.45.0/go.mod h1:0kXKmkCLG6d31N7NyLZWzt7jDSQura9zD/mWgiB6THI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0/go.mod 
h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term 
v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da 
h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/api v0.272.0 h1:eLUQZGnAS3OHn31URRf9sAmRk3w2JjMx37d2k8AjJmA= google.golang.org/api v0.272.0/go.mod h1:wKjowi5LNJc5qarNvDCvNQBn3rVK8nSy6jg2SwRwzIA= google.golang.org/genproto v0.0.0-20260319171110-e3a33c96fb44 h1:5F2rCQQSavqKBEvXvkLt9Lc61ldIzythnt3+ucqJWG8= google.golang.org/genproto v0.0.0-20260319171110-e3a33c96fb44/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= google.golang.org/genproto/googleapis/api v0.0.0-20260319171110-e3a33c96fb44 h1:3r0atjZtaSjEuUu8NNhKvvnFh8ad9PHWsN9VRotabFU= google.golang.org/genproto/googleapis/api v0.0.0-20260319171110-e3a33c96fb44/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= google.golang.org/genproto/googleapis/rpc v0.0.0-20260319171110-e3a33c96fb44 h1:sRy++txmErSjyVWlIgQB5nB+U75+Di+AH7eEZ002B/s= google.golang.org/genproto/googleapis/rpc v0.0.0-20260319171110-e3a33c96fb44/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 
v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: init/com.github.drakkan.sftpgo.plist ================================================ Label com.github.drakkan.sftpgo EnvironmentVariables SFTPGO_CONFIG_DIR /usr/local/opt/sftpgo/etc SFTPGO_LOG_FILE_PATH /usr/local/opt/sftpgo/var/log/sftpgo.log SFTPGO_HTTPD__TEMPLATES_PATH /usr/local/opt/sftpgo/usr/share/templates SFTPGO_HTTPD__STATIC_FILES_PATH /usr/local/opt/sftpgo/usr/share/static SFTPGO_HTTPD__OPENAPI_PATH /usr/local/opt/sftpgo/usr/share/openapi SFTPGO_HTTPD__BACKUPS_PATH /usr/local/opt/sftpgo/var/lib/backups SFTPGO_DATA_PROVIDER__CREDENTIALS_PATH /usr/local/opt/sftpgo/var/lib/credentials WorkingDirectory /usr/local/opt/sftpgo/etc ProgramArguments /usr/local/opt/sftpgo/bin/sftpgo serve KeepAlive ThrottleInterval 10 ================================================ FILE: init/sftpgo.service ================================================ [Unit] Description=SFTPGo Server After=network.target [Service] User=sftpgo Group=sftpgo Type=simple WorkingDirectory=/etc/sftpgo RuntimeDirectory=sftpgo Environment=SFTPGO_CONFIG_DIR=/etc/sftpgo/ Environment=SFTPGO_LOG_FILE_PATH= EnvironmentFile=-/etc/sftpgo/sftpgo.env ExecStart=/usr/bin/sftpgo serve ExecReload=/bin/kill -s HUP $MAINPID LimitNOFILE=8192 KillMode=mixed PrivateTmp=true Restart=always RestartSec=10s NoNewPrivileges=yes PrivateDevices=yes DevicePolicy=closed ProtectSystem=true RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX AmbientCapabilities=CAP_NET_BIND_SERVICE [Install] WantedBy=multi-user.target ================================================ FILE: 
internal/acme/account.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package acme import ( "crypto" "github.com/go-acme/lego/v4/registration" ) type account struct { Email string `json:"email"` Registration *registration.Resource `json:"registration"` key crypto.PrivateKey } /** Implementation of the registration.User interface **/ // GetEmail returns the email address for the account. func (a *account) GetEmail() string { return a.Email } // GetRegistration returns the server registration. func (a *account) GetRegistration() *registration.Resource { return a.Registration } // GetPrivateKey returns the private account key. func (a *account) GetPrivateKey() crypto.PrivateKey { return a.key } /** End **/ ================================================ FILE: internal/acme/acme.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package acme provides automatic access to certificates from Let's Encrypt and any other ACME-based CA // The code here is largely coiped from https://github.com/go-acme/lego/tree/master/cmd // This package is intended to provide basic functionality for obtaining and renewing certificates // and implements the "HTTP-01" and "TLSALPN-01" challenge types. // For more advanced features use external tools such as "lego" package acme import ( "crypto" "crypto/x509" "encoding/json" "encoding/pem" "errors" "fmt" "math/rand" "net/url" "os" "path/filepath" "slices" "strconv" "strings" "time" "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge/http01" "github.com/go-acme/lego/v4/challenge/tlsalpn01" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/providers/http/webroot" "github.com/go-acme/lego/v4/registration" "github.com/hashicorp/go-retryablehttp" "github.com/robfig/cron/v3" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/telemetry" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) const ( logSender = "acme" ) var ( config *Configuration initialConfig Configuration scheduler *cron.Cron logMode int supportedKeyTypes = []string{ string(certcrypto.EC256), string(certcrypto.EC384), string(certcrypto.RSA2048), string(certcrypto.RSA3072), string(certcrypto.RSA4096), string(certcrypto.RSA8192), } fnReloadHTTPDCerts func() error ) // SetReloadHTTPDCertsFn set the function to call to reload HTTPD certificates func SetReloadHTTPDCertsFn(fn 
func() error) { fnReloadHTTPDCerts = fn } // GetCertificates tries to obtain the certificates using the global configuration func GetCertificates() error { if config == nil { return errors.New("acme is disabled") } return config.getCertificates() } // GetCertificatesForConfig tries to obtain the certificates using the provided // configuration override. This is a NOOP if we already have certificates func GetCertificatesForConfig(c *dataprovider.ACMEConfigs, configDir string) error { if c.Domain == "" { acmeLog(logger.LevelDebug, "no domain configured, nothing to do") return nil } config := mergeConfig(getConfiguration(), c) if err := config.Initialize(configDir); err != nil { return err } hasCerts, err := config.hasCertificates(c.Domain) if err != nil { return fmt.Errorf("unable to check if we already have certificates for domain %q: %w", c.Domain, err) } if hasCerts { return nil } return config.getCertificates() } // GetHTTP01WebRoot returns the web root for HTTP-01 challenge func GetHTTP01WebRoot() string { return initialConfig.HTTP01Challenge.WebRoot } func mergeConfig(config Configuration, c *dataprovider.ACMEConfigs) Configuration { config.Domains = []string{c.Domain} config.Email = c.Email config.HTTP01Challenge.Port = c.HTTP01Challenge.Port config.TLSALPN01Challenge.Port = 0 return config } // getConfiguration returns the configuration set using config file and env vars func getConfiguration() Configuration { return initialConfig } func loadProviderConf(c Configuration) (Configuration, error) { configs, err := dataprovider.GetConfigs() if err != nil { return c, fmt.Errorf("unable to load config from provider: %w", err) } configs.SetNilsToEmpty() if configs.ACME.Domain == "" { return c, nil } return mergeConfig(c, configs.ACME), nil } // Initialize validates and set the configuration func Initialize(c Configuration, configDir string, checkRenew bool) error { config = nil initialConfig = c c, err := loadProviderConf(c) if err != nil { return err } 
util.CertsBasePath = "" setLogMode(checkRenew) if err := c.Initialize(configDir); err != nil { return err } if len(c.Domains) == 0 { return nil } util.CertsBasePath = c.CertsPath acmeLog(logger.LevelInfo, "configured domains: %+v, certs base path %q", c.Domains, c.CertsPath) config = &c if checkRenew { return startScheduler() } return nil } // HTTP01Challenge defines the configuration for HTTP-01 challenge type type HTTP01Challenge struct { Port int `json:"port" mapstructure:"port"` WebRoot string `json:"webroot" mapstructure:"webroot"` ProxyHeader string `json:"proxy_header" mapstructure:"proxy_header"` } func (c *HTTP01Challenge) isEnabled() bool { return c.Port > 0 || c.WebRoot != "" } func (c *HTTP01Challenge) validate() error { if !c.isEnabled() { return nil } if c.WebRoot != "" { if !filepath.IsAbs(c.WebRoot) { return fmt.Errorf("invalid HTTP-01 challenge web root, please set an absolute path") } _, err := os.Stat(c.WebRoot) if err != nil { return fmt.Errorf("invalid HTTP-01 challenge web root: %w", err) } } else { if c.Port > 65535 { return fmt.Errorf("invalid HTTP-01 challenge port: %d", c.Port) } } return nil } // TLSALPN01Challenge defines the configuration for TLSALPN-01 challenge type type TLSALPN01Challenge struct { Port int `json:"port" mapstructure:"port"` } func (c *TLSALPN01Challenge) isEnabled() bool { return c.Port > 0 } func (c *TLSALPN01Challenge) validate() error { if !c.isEnabled() { return nil } if c.Port > 65535 { return fmt.Errorf("invalid TLSALPN-01 challenge port: %d", c.Port) } return nil } // Configuration holds the ACME configuration type Configuration struct { Email string `json:"email" mapstructure:"email"` KeyType string `json:"key_type" mapstructure:"key_type"` CertsPath string `json:"certs_path" mapstructure:"certs_path"` CAEndpoint string `json:"ca_endpoint" mapstructure:"ca_endpoint"` // if a certificate is to be valid for multiple domains specify the names separated by commas, // for example: example.com,www.example.com 
Domains []string `json:"domains" mapstructure:"domains"` RenewDays int `json:"renew_days" mapstructure:"renew_days"` HTTP01Challenge HTTP01Challenge `json:"http01_challenge" mapstructure:"http01_challenge"` TLSALPN01Challenge TLSALPN01Challenge `json:"tls_alpn01_challenge" mapstructure:"tls_alpn01_challenge"` accountConfigPath string accountKeyPath string lockPath string tempDir string } // Initialize validates and initialize the configuration func (c *Configuration) Initialize(configDir string) error { c.checkDomains() if len(c.Domains) == 0 { acmeLog(logger.LevelInfo, "no domains configured, acme disabled") return nil } if c.Email == "" || !util.IsEmailValid(c.Email) { return util.NewI18nError( fmt.Errorf("invalid email address %q", c.Email), util.I18nErrorInvalidEmail, ) } if c.RenewDays < 1 { return fmt.Errorf("invalid number of days remaining before renewal: %d", c.RenewDays) } if !slices.Contains(supportedKeyTypes, c.KeyType) { return fmt.Errorf("invalid key type %q", c.KeyType) } caURL, err := url.Parse(c.CAEndpoint) if err != nil { return fmt.Errorf("invalid CA endopoint: %w", err) } if !util.IsFileInputValid(c.CertsPath) { return fmt.Errorf("invalid certs path %q", c.CertsPath) } if !filepath.IsAbs(c.CertsPath) { c.CertsPath = filepath.Join(configDir, c.CertsPath) } err = os.MkdirAll(c.CertsPath, 0700) if err != nil { return fmt.Errorf("unable to create certs path %q: %w", c.CertsPath, err) } c.tempDir = filepath.Join(c.CertsPath, "temp") err = os.MkdirAll(c.CertsPath, 0700) if err != nil { return fmt.Errorf("unable to create certs temp path %q: %w", c.tempDir, err) } serverPath := strings.NewReplacer(":", "_", "/", string(os.PathSeparator)).Replace(caURL.Host) accountPath := filepath.Join(c.CertsPath, serverPath) err = os.MkdirAll(accountPath, 0700) if err != nil { return fmt.Errorf("unable to create account path %q: %w", accountPath, err) } c.accountConfigPath = filepath.Join(accountPath, c.Email+".json") c.accountKeyPath = filepath.Join(accountPath, 
c.Email+".key") c.lockPath = filepath.Join(c.CertsPath, "lock") return c.validateChallenges() } func (c *Configuration) validateChallenges() error { if !c.HTTP01Challenge.isEnabled() && !c.TLSALPN01Challenge.isEnabled() { return fmt.Errorf("no challenge type defined") } if err := c.HTTP01Challenge.validate(); err != nil { return err } return c.TLSALPN01Challenge.validate() } func (c *Configuration) checkDomains() { var domains []string for _, domain := range c.Domains { domain = strings.TrimSpace(domain) if domain == "" { continue } if d, ok := isDomainValid(domain); ok { domains = append(domains, d) } } c.Domains = util.RemoveDuplicates(domains, true) } func (c *Configuration) setLockTime() error { lockTime := fmt.Sprintf("%v", util.GetTimeAsMsSinceEpoch(time.Now())) err := os.WriteFile(c.lockPath, []byte(lockTime), 0600) if err != nil { acmeLog(logger.LevelError, "unable to save lock time to %q: %v", c.lockPath, err) return fmt.Errorf("unable to save lock time: %w", err) } acmeLog(logger.LevelDebug, "lock time saved: %q", lockTime) return nil } func (c *Configuration) getLockTime() (time.Time, error) { content, err := os.ReadFile(c.lockPath) if err != nil { if os.IsNotExist(err) { acmeLog(logger.LevelDebug, "lock file %q not found", c.lockPath) return time.Time{}, nil } acmeLog(logger.LevelError, "unable to read lock file %q: %v", c.lockPath, err) return time.Time{}, err } msec, err := strconv.ParseInt(strings.TrimSpace(util.BytesToString(content)), 10, 64) if err != nil { acmeLog(logger.LevelError, "unable to parse lock time: %v", err) return time.Time{}, fmt.Errorf("unable to parse lock time: %w", err) } return util.GetTimeFromMsecSinceEpoch(msec), nil } func (c *Configuration) saveAccount(account *account) error { jsonBytes, err := json.MarshalIndent(account, "", "\t") if err != nil { return err } err = os.WriteFile(c.accountConfigPath, jsonBytes, 0600) if err != nil { acmeLog(logger.LevelError, "unable to save account to file %q: %v", c.accountConfigPath, 
err) return fmt.Errorf("unable to save account: %w", err) } return nil } func (c *Configuration) getAccount(privateKey crypto.PrivateKey) (account, error) { _, err := os.Stat(c.accountConfigPath) if err != nil && os.IsNotExist(err) { acmeLog(logger.LevelDebug, "account does not exist") return account{Email: c.Email, key: privateKey}, nil } var account account fileBytes, err := os.ReadFile(c.accountConfigPath) if err != nil { acmeLog(logger.LevelError, "unable to read account from file %q: %v", c.accountConfigPath, err) return account, fmt.Errorf("unable to read account from file: %w", err) } err = json.Unmarshal(fileBytes, &account) if err != nil { acmeLog(logger.LevelError, "invalid account file content: %v", err) return account, fmt.Errorf("unable to parse account file as JSON: %w", err) } account.key = privateKey if account.Registration == nil || account.Registration.Body.Status == "" { acmeLog(logger.LevelInfo, "couldn't load account but got a key. Try to look the account up") reg, err := c.tryRecoverRegistration(privateKey) if err != nil { acmeLog(logger.LevelError, "unable to look the account up: %v", err) return account, fmt.Errorf("unable to look the account up: %w", err) } account.Registration = reg err = c.saveAccount(&account) if err != nil { return account, err } } return account, nil } func (c *Configuration) loadPrivateKey() (crypto.PrivateKey, error) { keyBytes, err := os.ReadFile(c.accountKeyPath) if err != nil { acmeLog(logger.LevelError, "unable to read account key from file %q: %v", c.accountKeyPath, err) return nil, fmt.Errorf("unable to read account key: %w", err) } keyBlock, _ := pem.Decode(keyBytes) if keyBlock == nil { acmeLog(logger.LevelError, "unable to parse private key from file %q: pem decoding failed", c.accountKeyPath) return nil, errors.New("pem decoding failed") } var privateKey crypto.PrivateKey switch keyBlock.Type { case "RSA PRIVATE KEY": privateKey, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes) case "EC PRIVATE KEY": 
privateKey, err = x509.ParseECPrivateKey(keyBlock.Bytes) default: err = fmt.Errorf("unknown private key type %q", keyBlock.Type) } if err != nil { acmeLog(logger.LevelError, "unable to parse private key from file %q: %v", c.accountKeyPath, err) return privateKey, fmt.Errorf("unable to parse private key: %w", err) } return privateKey, nil } func (c *Configuration) generatePrivateKey() (crypto.PrivateKey, error) { privateKey, err := certcrypto.GeneratePrivateKey(certcrypto.KeyType(c.KeyType)) if err != nil { acmeLog(logger.LevelError, "unable to generate private key: %v", err) return nil, fmt.Errorf("unable to generate private key: %w", err) } certOut, err := os.Create(c.accountKeyPath) if err != nil { acmeLog(logger.LevelError, "unable to save private key to file %q: %v", c.accountKeyPath, err) return nil, fmt.Errorf("unable to save private key: %w", err) } defer certOut.Close() pemKey := certcrypto.PEMBlock(privateKey) err = pem.Encode(certOut, pemKey) if err != nil { acmeLog(logger.LevelError, "unable to encode private key: %v", err) return nil, fmt.Errorf("unable to encode private key: %w", err) } acmeLog(logger.LevelDebug, "new account private key generated") return privateKey, nil } func (c *Configuration) getPrivateKey() (crypto.PrivateKey, error) { _, err := os.Stat(c.accountKeyPath) if err != nil && os.IsNotExist(err) { acmeLog(logger.LevelDebug, "private key file %q does not exist, generating new private key", c.accountKeyPath) return c.generatePrivateKey() } acmeLog(logger.LevelDebug, "loading private key from file %q, stat error: %v", c.accountKeyPath, err) return c.loadPrivateKey() } func (c *Configuration) loadCertificatesForDomain(domain string) ([]*x509.Certificate, error) { domain = util.SanitizeDomain(domain) acmeLog(logger.LevelDebug, "loading certificates for domain %q", domain) content, err := os.ReadFile(filepath.Join(c.CertsPath, domain+".crt")) if err != nil { acmeLog(logger.LevelError, "unable to load certificates for domain %q: %v", domain, 
err) return nil, fmt.Errorf("unable to load certificates for domain %q: %w", domain, err) } certs, err := certcrypto.ParsePEMBundle(content) if err != nil { acmeLog(logger.LevelError, "unable to parse certificates for domain %q: %v", domain, err) return certs, fmt.Errorf("unable to parse certificates for domain %q: %w", domain, err) } return certs, nil } func (c *Configuration) needRenewal(x509Cert *x509.Certificate, domain string) bool { if x509Cert.IsCA { acmeLog(logger.LevelError, "certificate bundle starts with a CA certificate, cannot renew domain %v", domain) return false } notAfter := int(time.Until(x509Cert.NotAfter).Hours() / 24.0) if notAfter > c.RenewDays { acmeLog(logger.LevelDebug, "the certificate for domain %q expires in %d days, no renewal", domain, notAfter) return false } return true } func (c *Configuration) setup() (*account, *lego.Client, error) { privateKey, err := c.getPrivateKey() if err != nil { return nil, nil, err } account, err := c.getAccount(privateKey) if err != nil { return nil, nil, err } config := lego.NewConfig(&account) config.CADirURL = c.CAEndpoint config.Certificate.KeyType = certcrypto.KeyType(c.KeyType) config.Certificate.OverallRequestLimit = 6 config.UserAgent = version.GetServerVersion("/", false) retryClient := retryablehttp.NewClient() retryClient.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"} retryClient.RetryMax = 5 retryClient.HTTPClient = config.HTTPClient config.HTTPClient = retryClient.StandardClient() client, err := lego.NewClient(config) if err != nil { acmeLog(logger.LevelError, "unable to get ACME client: %v", err) return nil, nil, fmt.Errorf("unable to get ACME client: %w", err) } err = c.setupChalleges(client) if err != nil { return nil, nil, err } return &account, client, nil } func (c *Configuration) setupChalleges(client *lego.Client) error { client.Challenge.Remove(challenge.DNS01) if c.HTTP01Challenge.isEnabled() { if c.HTTP01Challenge.WebRoot != "" { acmeLog(logger.LevelDebug, 
"configuring HTTP-01 web root challenge, path %q", c.HTTP01Challenge.WebRoot) providerServer, err := webroot.NewHTTPProvider(c.HTTP01Challenge.WebRoot) if err != nil { acmeLog(logger.LevelError, "unable to create HTTP-01 web root challenge provider from path %q: %v", c.HTTP01Challenge.WebRoot, err) return fmt.Errorf("unable to create HTTP-01 web root challenge provider: %w", err) } err = client.Challenge.SetHTTP01Provider(providerServer) if err != nil { acmeLog(logger.LevelError, "unable to set HTTP-01 challenge provider: %v", err) return fmt.Errorf("unable to set HTTP-01 challenge provider: %w", err) } } else { acmeLog(logger.LevelDebug, "configuring HTTP-01 challenge, port %d", c.HTTP01Challenge.Port) providerServer := http01.NewProviderServer("", fmt.Sprintf("%d", c.HTTP01Challenge.Port)) if c.HTTP01Challenge.ProxyHeader != "" { acmeLog(logger.LevelDebug, "setting proxy header to \"%s\"", c.HTTP01Challenge.ProxyHeader) providerServer.SetProxyHeader(c.HTTP01Challenge.ProxyHeader) } err := client.Challenge.SetHTTP01Provider(providerServer) if err != nil { acmeLog(logger.LevelError, "unable to set HTTP-01 challenge provider: %v", err) return fmt.Errorf("unable to set HTTP-01 challenge provider: %w", err) } } } else { client.Challenge.Remove(challenge.HTTP01) } if c.TLSALPN01Challenge.isEnabled() { acmeLog(logger.LevelDebug, "configuring TLSALPN-01 challenge, port %d", c.TLSALPN01Challenge.Port) err := client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", fmt.Sprintf("%d", c.TLSALPN01Challenge.Port))) if err != nil { acmeLog(logger.LevelError, "unable to set TLSALPN-01 challenge provider: %v", err) return fmt.Errorf("unable to set TLSALPN-01 challenge provider: %w", err) } } else { client.Challenge.Remove(challenge.TLSALPN01) } return nil } func (c *Configuration) register(client *lego.Client) (*registration.Resource, error) { return client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) } func (c *Configuration) 
tryRecoverRegistration(privateKey crypto.PrivateKey) (*registration.Resource, error) { config := lego.NewConfig(&account{key: privateKey}) config.CADirURL = c.CAEndpoint config.UserAgent = version.GetServerVersion("/", false) retryClient := retryablehttp.NewClient() retryClient.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"} retryClient.RetryMax = 5 retryClient.HTTPClient = config.HTTPClient config.HTTPClient = retryClient.StandardClient() client, err := lego.NewClient(config) if err != nil { acmeLog(logger.LevelError, "unable to get the ACME client: %v", err) return nil, err } return client.Registration.ResolveAccountByKey() } func (c *Configuration) getCrtPath(domain string) string { return filepath.Join(c.CertsPath, domain+".crt") } func (c *Configuration) getKeyPath(domain string) string { return filepath.Join(c.CertsPath, domain+".key") } func (c *Configuration) getResourcePath(domain string) string { return filepath.Join(c.CertsPath, domain+".json") } func (c *Configuration) obtainAndSaveCertificate(client *lego.Client, domain string) error { domains := getDomains(domain) acmeLog(logger.LevelInfo, "requesting certificates for domains %+v", domains) request := certificate.ObtainRequest{ Domains: domains, Bundle: true, MustStaple: false, PreferredChain: "", AlwaysDeactivateAuthorizations: false, } cert, err := client.Certificate.Obtain(request) if err != nil { acmeLog(logger.LevelError, "unable to obtain certificates for domains %+v: %v", domains, err) return fmt.Errorf("unable to obtain certificates: %w", err) } domain = util.SanitizeDomain(domain) err = os.WriteFile(c.getCrtPath(domain), cert.Certificate, 0600) if err != nil { acmeLog(logger.LevelError, "unable to save certificate for domain %s: %v", domain, err) return fmt.Errorf("unable to save certificate: %w", err) } err = os.WriteFile(c.getKeyPath(domain), cert.PrivateKey, 0600) if err != nil { acmeLog(logger.LevelError, "unable to save private key for domain %s: %v", domain, err) return 
fmt.Errorf("unable to save private key: %w", err) } jsonBytes, err := json.MarshalIndent(cert, "", "\t") if err != nil { acmeLog(logger.LevelError, "unable to marshal certificate resources for domain %v: %v", domain, err) return err } err = os.WriteFile(c.getResourcePath(domain), jsonBytes, 0600) if err != nil { acmeLog(logger.LevelError, "unable to save certificate resources for domain %v: %v", domain, err) return fmt.Errorf("unable to save certificate resources: %w", err) } acmeLog(logger.LevelInfo, "certificates for domains %+v saved", domains) return nil } // hasCertificates returns true if certificates for the specified domain has already been issued func (c *Configuration) hasCertificates(domain string) (bool, error) { domain = util.SanitizeDomain(domain) if _, err := os.Stat(c.getCrtPath(domain)); err != nil { if os.IsNotExist(err) { return false, nil } return false, err } if _, err := os.Stat(c.getKeyPath(domain)); err != nil { if os.IsNotExist(err) { return false, nil } return false, err } return true, nil } // getCertificates tries to obtain the certificates for the configured domains func (c *Configuration) getCertificates() error { account, client, err := c.setup() if err != nil { return err } if account.Registration == nil { reg, err := c.register(client) if err != nil { acmeLog(logger.LevelError, "unable to register account: %v", err) return fmt.Errorf("unable to register account: %w", err) } account.Registration = reg err = c.saveAccount(account) if err != nil { return err } } for _, domain := range c.Domains { err = c.obtainAndSaveCertificate(client, domain) if err != nil { return err } } return nil } func (c *Configuration) notifyCertificateRenewal(domain string, err error) { if domain == "" { domain = strings.Join(c.Domains, ",") } params := common.EventParams{ Name: domain, Event: "Certificate renewal", Timestamp: time.Now(), } if err != nil { params.Status = 2 params.AddError(err) } else { params.Status = 1 } 
common.HandleCertificateEvent(params) } func (c *Configuration) renewCertificates() error { lockTime, err := c.getLockTime() if err != nil { return err } acmeLog(logger.LevelDebug, "certificate renew lock time %v", lockTime) if lockTime.Add(-30*time.Second).Before(time.Now()) && lockTime.Add(5*time.Minute).After(time.Now()) { acmeLog(logger.LevelInfo, "certificate renew skipped, lock time too close: %v", lockTime) return nil } err = c.setLockTime() if err != nil { c.notifyCertificateRenewal("", err) return err } account, client, err := c.setup() if err != nil { c.notifyCertificateRenewal("", err) return err } if account.Registration == nil { acmeLog(logger.LevelError, "cannot renew certificates, your account is not registered") err = errors.New("cannot renew certificates, your account is not registered") c.notifyCertificateRenewal("", err) return err } var errRenew error needReload := false for _, domain := range c.Domains { certificates, err := c.loadCertificatesForDomain(domain) if err != nil { c.notifyCertificateRenewal(domain, err) errRenew = err continue } cert := certificates[0] if !c.needRenewal(cert, domain) { continue } err = c.obtainAndSaveCertificate(client, domain) if err != nil { c.notifyCertificateRenewal(domain, err) errRenew = err } else { c.notifyCertificateRenewal(domain, nil) needReload = true } } if needReload { // at least one certificate has been renewed, sends a reload to all services that may be using certificates err = ftpd.ReloadCertificateMgr() acmeLog(logger.LevelInfo, "ftpd certificate manager reloaded , error: %v", err) if fnReloadHTTPDCerts != nil { err = fnReloadHTTPDCerts() acmeLog(logger.LevelInfo, "httpd certificates manager reloaded , error: %v", err) } err = webdavd.ReloadCertificateMgr() acmeLog(logger.LevelInfo, "webdav certificates manager reloaded , error: %v", err) err = telemetry.ReloadCertificateMgr() acmeLog(logger.LevelInfo, "telemetry certificates manager reloaded , error: %v", err) } return errRenew } func 
isDomainValid(domain string) (string, bool) { isValid := false for d := range strings.SplitSeq(domain, ",") { d = strings.TrimSpace(d) if d != "" { isValid = true break } } return domain, isValid } func getDomains(domain string) []string { var domains []string delimiter := "," if !strings.Contains(domain, ",") && strings.Contains(domain, " ") { delimiter = " " } for d := range strings.SplitSeq(domain, delimiter) { d = strings.TrimSpace(d) if d != "" { domains = append(domains, d) } } return util.RemoveDuplicates(domains, false) } func stopScheduler() { if scheduler != nil { scheduler.Stop() scheduler = nil } } func startScheduler() error { stopScheduler() randSecs := rand.Intn(59) scheduler = cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger)) _, err := scheduler.AddFunc(fmt.Sprintf("@every 12h0m%ds", randSecs), renewCertificates) if err != nil { return fmt.Errorf("unable to schedule certificates renewal: %w", err) } acmeLog(logger.LevelInfo, "starting scheduler, initial certificates check in %d seconds", randSecs) initialTimer := time.NewTimer(time.Duration(randSecs) * time.Second) go func() { <-initialTimer.C renewCertificates() }() scheduler.Start() return nil } func renewCertificates() { if config != nil { if err := config.renewCertificates(); err != nil { acmeLog(logger.LevelError, "unable to renew certificates: %v", err) } } } func setLogMode(checkRenew bool) { if checkRenew { logMode = 1 } else { logMode = 2 } log.Logger = &logger.LegoAdapter{ LogToConsole: logMode != 1, } } func acmeLog(level logger.LogLevel, format string, v ...any) { if logMode == 1 { logger.Log(level, logSender, "", format, v...) } else { switch level { case logger.LevelDebug: logger.DebugToConsole(format, v...) case logger.LevelInfo: logger.InfoToConsole(format, v...) case logger.LevelWarn: logger.WarnToConsole(format, v...) default: logger.ErrorToConsole(format, v...) 
} } } ================================================ FILE: internal/bundle/bundle.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build bundle package bundle import ( "embed" "fmt" "io/fs" "net/http" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("+bundle") } //go:embed templates/* var templatesFs embed.FS //go:embed static/* var staticFs embed.FS //go:embed openapi/* var openapiFs embed.FS // GetTemplatesFs returns the embedded filesystem with the SFTPGo templates func GetTemplatesFs() embed.FS { return templatesFs } // GetStaticFs return the http Filesystem with the embedded static files func GetStaticFs() http.FileSystem { fsys, err := fs.Sub(staticFs, "static") if err != nil { err = fmt.Errorf("unable to get embedded filesystem for static files: %w", err) panic(err) } return http.FS(fsys) } // GetOpenAPIFs return the http Filesystem with the embedded static files func GetOpenAPIFs() http.FileSystem { fsys, err := fs.Sub(openapiFs, "openapi") if err != nil { err = fmt.Errorf("unable to get embedded filesystem for OpenAPI files: %w", err) panic(err) } return http.FS(fsys) } ================================================ FILE: internal/cmd/acme.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under 
the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( acmeCmd = &cobra.Command{ Use: "acme", Short: "Obtain TLS certificates from ACME-based CAs like Let's Encrypt", } acmeRunCmd = &cobra.Command{ Use: "run", Short: "Register your account and obtain certificates", Long: `This command must be run to obtain TLS certificates the first time or every time you add a new domain to your configuration file. Certificates are saved in the configured "certs_path". 
After this initial step, the certificates are automatically checked and renewed by the SFTPGo service `, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.ErrorToConsole("Unable to initialize ACME, config load error: %v", err) return } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("unable to initialize KMS: %v", err) os.Exit(1) } if config.HasKMSPlugin() { if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil { logger.ErrorToConsole("unable to initialize plugin system: %v", err) os.Exit(1) } registerSignals() defer plugin.Handler.Cleanup() } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("Unable to initialize MFA: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, false) if err != nil { logger.ErrorToConsole("error initializing data provider: %v", err) os.Exit(1) } acmeConfig := config.GetACMEConfig() err = acme.Initialize(acmeConfig, configDir, false) if err != nil { logger.ErrorToConsole("Unable to initialize ACME configuration: %v", err) os.Exit(1) } if err = acme.GetCertificates(); err != nil { logger.ErrorToConsole("Cannot get certificates: %v", err) os.Exit(1) } }, } ) func init() { addConfigFlags(acmeRunCmd) acmeCmd.AddCommand(acmeRunCmd) rootCmd.AddCommand(acmeCmd) } ================================================ FILE: internal/cmd/gen.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import "github.com/spf13/cobra" var genCmd = &cobra.Command{ Use: "gen", Short: "A collection of useful generators", } func init() { rootCmd.AddCommand(genCmd) } ================================================ FILE: internal/cmd/gencompletion.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "os" "github.com/spf13/cobra" ) var genCompletionCmd = &cobra.Command{ Use: "completion [bash|zsh|fish|powershell]", Short: "Generate the autocompletion script for the specified shell", Long: `Generate the autocompletion script for sftpgo for the specified shell. See each sub-command's help for details on how to use the generated script. `, } var genCompletionBashCmd = &cobra.Command{ Use: "bash", Short: "Generate the autocompletion script for bash", Long: `Generate the autocompletion script for the bash shell. This script depends on the 'bash-completion' package. If it is not installed already, you can install it via your OS's package manager. 
To load completions in your current shell session: $ source <(sftpgo gen completion bash) To load completions for every new session, execute once: Linux: $ sudo sftpgo gen completion bash > /usr/share/bash-completion/completions/sftpgo MacOS: $ sudo sftpgo gen completion bash > /usr/local/etc/bash_completion.d/sftpgo You will need to start a new shell for this setup to take effect. `, DisableFlagsInUseLine: true, RunE: func(cmd *cobra.Command, _ []string) error { return cmd.Root().GenBashCompletionV2(os.Stdout, true) }, } var genCompletionZshCmd = &cobra.Command{ Use: "zsh", Short: "Generate the autocompletion script for zsh", Long: `Generate the autocompletion script for the zsh shell. If shell completion is not already enabled in your environment you will need to enable it. You can execute the following once: $ echo "autoload -U compinit; compinit" >> ~/.zshrc To load completions for every new session, execute once: Linux: $ sftpgo gen completion zsh > > "${fpath[1]}/_sftpgo" macOS: $ sudo sftpgo gen completion zsh > /usr/local/share/zsh/site-functions/_sftpgo You will need to start a new shell for this setup to take effect. `, DisableFlagsInUseLine: true, RunE: func(cmd *cobra.Command, _ []string) error { return cmd.Root().GenZshCompletion(os.Stdout) }, } var genCompletionFishCmd = &cobra.Command{ Use: "fish", Short: "Generate the autocompletion script for fish", Long: `Generate the autocompletion script for the fish shell. To load completions in your current shell session: $ sftpgo gen completion fish | source To load completions for every new session, execute once: $ sftpgo gen completion fish > ~/.config/fish/completions/sftpgo.fish You will need to start a new shell for this setup to take effect. 
`, DisableFlagsInUseLine: true, RunE: func(cmd *cobra.Command, _ []string) error { return cmd.Root().GenFishCompletion(os.Stdout, true) }, } var genCompletionPowerShellCmd = &cobra.Command{ Use: "powershell", Short: "Generate the autocompletion script for powershell", Long: `Generate the autocompletion script for powershell. To load completions in your current shell session: PS C:\> sftpgo gen completion powershell | Out-String | Invoke-Expression To load completions for every new session, add the output of the above command to your powershell profile. `, DisableFlagsInUseLine: true, RunE: func(cmd *cobra.Command, _ []string) error { return cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout) }, } func init() { genCompletionCmd.AddCommand(genCompletionBashCmd) genCompletionCmd.AddCommand(genCompletionZshCmd) genCompletionCmd.AddCommand(genCompletionFishCmd) genCompletionCmd.AddCommand(genCompletionPowerShellCmd) genCmd.AddCommand(genCompletionCmd) } ================================================ FILE: internal/cmd/genman.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package cmd import ( "errors" "fmt" "io/fs" "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/cobra/doc" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/version" ) var ( manDir string genManCmd = &cobra.Command{ Use: "man", Short: "Generate man pages for sftpgo", Long: `This command automatically generates up-to-date man pages of SFTPGo's command-line interface. By default, it creates the man page files in the "man" directory under the current directory. `, Run: func(cmd *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) if _, err := os.Stat(manDir); errors.Is(err, fs.ErrNotExist) { err = os.MkdirAll(manDir, os.ModePerm) if err != nil { logger.WarnToConsole("Unable to generate man page files: %v", err) os.Exit(1) } } header := &doc.GenManHeader{ Section: "1", Manual: "SFTPGo Manual", Source: fmt.Sprintf("SFTPGo %v", version.Get().Version), } cmd.Root().DisableAutoGenTag = true err := doc.GenManTree(cmd.Root(), header, manDir) if err != nil { logger.WarnToConsole("Unable to generate man page files: %v", err) os.Exit(1) } }, } ) func init() { genManCmd.Flags().StringVarP(&manDir, "dir", "d", "man", "The directory to write the man pages") genCmd.AddCommand(genManCmd) } ================================================ FILE: internal/cmd/initprovider.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/service" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( initProviderCmd = &cobra.Command{ Use: "initprovider", Short: "Initialize and/or updates the configured data provider", Long: `This command reads the data provider connection details from the specified configuration file and creates the initial structure or update the existing one, as needed. Some data providers such as bolt and memory does not require an initialization but they could require an update to the existing data after upgrading SFTPGo. For SQLite/bolt providers the database file will be auto-created if missing. For PostgreSQL and MySQL providers you need to create the configured database, this command will create/update the required tables as needed. To initialize/update the data provider from the configuration directory simply use: $ sftpgo initprovider Any defined action is ignored. 
Please take a look at the usage below to customize the options.`, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.ErrorToConsole("Unable to initialize data provider, config load error: %v", err) return } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("Unable to initialize KMS: %v", err) os.Exit(1) } if config.HasKMSPlugin() { if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil { logger.ErrorToConsole("unable to initialize plugin system: %v", err) os.Exit(1) } registerSignals() defer plugin.Handler.Cleanup() } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("Unable to initialize MFA: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() // ignore actions providerConf.Actions.Hook = "" providerConf.Actions.ExecuteFor = nil providerConf.Actions.ExecuteOn = nil logger.InfoToConsole("Initializing provider: %q config file: %q", providerConf.Driver, viper.ConfigFileUsed()) err = dataprovider.InitializeDatabase(providerConf, configDir) switch err { case nil: logger.InfoToConsole("Data provider successfully initialized/updated") case dataprovider.ErrNoInitRequired: logger.InfoToConsole("%v", err.Error()) default: logger.ErrorToConsole("Unable to initialize/update the data provider: %v", err) os.Exit(1) } if providerConf.Driver != dataprovider.MemoryDataProviderName && loadDataFrom != "" { if err := common.Initialize(config.GetCommonConfig(), providerConf.GetShared()); err != nil { logger.ErrorToConsole("%v", err) os.Exit(1) } service := service.Service{ LoadDataFrom: loadDataFrom, LoadDataMode: loadDataMode, LoadDataQuotaScan: loadDataQuotaScan, LoadDataClean: loadDataClean, } if err = service.LoadInitialData(); err != nil { logger.ErrorToConsole("Cannot load 
initial data: %v", err) os.Exit(1) } } }, } ) func init() { rootCmd.AddCommand(initProviderCmd) addConfigFlags(initProviderCmd) addBaseLoadDataFlags(initProviderCmd) } ================================================ FILE: internal/cmd/install_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "fmt" "os" "strconv" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/service" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( installCmd = &cobra.Command{ Use: "install", Short: "Install SFTPGo as Windows Service", Long: `To install the SFTPGo Windows Service with the default values for the command line flags simply use: sftpgo service install Please take a look at the usage below to customize the startup options`, Run: func(_ *cobra.Command, _ []string) { s := service.Service{ ConfigDir: util.CleanDirInput(configDir), ConfigFile: configFile, LogFilePath: logFilePath, LogMaxSize: logMaxSize, LogMaxBackups: logMaxBackups, LogMaxAge: logMaxAge, LogCompress: logCompress, LogLevel: logLevel, LogUTCTime: logUTCTime, Shutdown: make(chan bool), } winService := service.WindowsService{ Service: s, } serviceArgs := []string{"service", "start"} customFlags := getCustomServeFlags() if len(customFlags) > 0 { serviceArgs = append(serviceArgs, customFlags...) } err := winService.Install(serviceArgs...) 
if err != nil { fmt.Printf("Error installing service: %v\r\n", err) os.Exit(1) } else { fmt.Printf("Service installed!\r\n") } }, } ) func init() { serviceCmd.AddCommand(installCmd) addServeFlags(installCmd) } func getCustomServeFlags() []string { result := []string{} if configDir != defaultConfigDir { configDir = util.CleanDirInput(configDir) result = append(result, "--"+configDirFlag) result = append(result, configDir) } if configFile != defaultConfigFile { result = append(result, "--"+configFileFlag) result = append(result, configFile) } if logFilePath != defaultLogFile { result = append(result, "--"+logFilePathFlag) result = append(result, logFilePath) } if logMaxSize != defaultLogMaxSize { result = append(result, "--"+logMaxSizeFlag) result = append(result, strconv.Itoa(logMaxSize)) } if logMaxBackups != defaultLogMaxBackup { result = append(result, "--"+logMaxBackupFlag) result = append(result, strconv.Itoa(logMaxBackups)) } if logMaxAge != defaultLogMaxAge { result = append(result, "--"+logMaxAgeFlag) result = append(result, strconv.Itoa(logMaxAge)) } if logLevel != defaultLogLevel { result = append(result, "--"+logLevelFlag) result = append(result, logLevel) } if logUTCTime != defaultLogUTCTime { result = append(result, "--"+logUTCTimeFlag+"=true") } if logCompress != defaultLogCompress { result = append(result, "--"+logCompressFlag+"=true") } if graceTime != defaultGraceTime { result = append(result, "--"+graceTimeFlag) result = append(result, strconv.Itoa(graceTime)) } return result } ================================================ FILE: internal/cmd/ping.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "fmt" "net/http" "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) func getHealthzURLFromBindings(bindings []httpd.Binding) string { for _, b := range bindings { if b.Port > 0 && b.IsValid() { var url string if b.EnableHTTPS { url = "https://" } else { url = "http://" } if b.Address == "" { url += "127.0.0.1" } else { url += b.Address } url += fmt.Sprintf(":%d", b.Port) url += "/healthz" return url } } return "" } var ( pingCmd = &cobra.Command{ Use: "ping", Short: "Issues an health check to SFTPGo", Long: `This command is only useful in environments where system commands like "curl", "wget" and similar are not available. 
Checks over UNIX domain sockets are not supported`, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.WarnToConsole("Unable to load configuration: %v", err) os.Exit(1) } httpConfig := config.GetHTTPConfig() err = httpConfig.Initialize(configDir) if err != nil { logger.ErrorToConsole("error initializing http client: %v", err) os.Exit(1) } telemetryConfig := config.GetTelemetryConfig() var url string if telemetryConfig.BindPort > 0 { if telemetryConfig.CertificateFile != "" && telemetryConfig.CertificateKeyFile != "" { url += "https://" } else { url += "http://" } if telemetryConfig.BindAddress == "" { url += "127.0.0.1" } else { url += telemetryConfig.BindAddress } url += fmt.Sprintf(":%d", telemetryConfig.BindPort) url += "/healthz" } if url == "" { httpdConfig := config.GetHTTPDConfig() url = getHealthzURLFromBindings(httpdConfig.Bindings) } if url == "" { logger.ErrorToConsole("no suitable configuration found, please enable the telemetry server or REST API over HTTP/S") os.Exit(1) } logger.DebugToConsole("Health Check URL %q", url) resp, err := httpclient.RetryableGet(url) if err != nil { logger.ErrorToConsole("Unable to connect to SFTPGo: %v", err) os.Exit(1) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.ErrorToConsole("Unexpected status code %d", resp.StatusCode) os.Exit(1) } logger.InfoToConsole("OK") }, } ) func init() { addConfigFlags(pingCmd) rootCmd.AddCommand(pingCmd) } ================================================ FILE: internal/cmd/portable.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !noportable package cmd import ( "fmt" "os" "path" "path/filepath" "strings" "github.com/sftpgo/sdk" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/service" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( directoryToServe string portableSFTPDPort int portableUsername string portablePassword string portablePasswordFile string portableStartDir string portableLogFile string portableLogLevel string portableLogUTCTime bool portablePublicKeys []string portablePermissions []string portableSSHCommands []string portableAllowedPatterns []string portableDeniedPatterns []string portableFsProvider string portableS3Bucket string portableS3Region string portableS3AccessKey string portableS3AccessSecret string portableS3RoleARN string portableS3Endpoint string portableS3StorageClass string portableS3ACL string portableS3KeyPrefix string portableS3ULPartSize int portableS3ULConcurrency int portableS3ForcePathStyle bool portableS3SkipTLSVerify bool portableGCSBucket string portableGCSCredentialsFile string portableGCSAutoCredentials int portableGCSStorageClass string portableGCSKeyPrefix string portableFTPDPort int portableFTPSCert string portableFTPSKey string portableWebDAVPort int portableWebDAVCert string portableWebDAVKey string portableHTTPPort int portableHTTPSCert string portableHTTPSKey string 
portableAzContainer string portableAzAccountName string portableAzAccountKey string portableAzEndpoint string portableAzAccessTier string portableAzSASURL string portableAzKeyPrefix string portableAzULPartSize int portableAzULConcurrency int portableAzDLPartSize int portableAzDLConcurrency int portableAzUseEmulator bool portableCryptPassphrase string portableSFTPEndpoint string portableSFTPUsername string portableSFTPPassword string portableSFTPPrivateKeyPath string portableSFTPFingerprints []string portableSFTPPrefix string portableSFTPDisableConcurrentReads bool portableSFTPDBufferSize int64 portableCmd = &cobra.Command{ Use: "portable", Short: "Serve a single directory/account", Long: `To serve the current working directory with auto generated credentials simply use: $ sftpgo portable Please take a look at the usage below to customize the serving parameters`, Run: func(_ *cobra.Command, _ []string) { portableDir := directoryToServe fsProvider := dataprovider.GetProviderFromValue(convertFsProvider()) if !filepath.IsAbs(portableDir) { if fsProvider == sdk.LocalFilesystemProvider { portableDir, _ = filepath.Abs(portableDir) } else { portableDir = os.TempDir() } } permissions := make(map[string][]string) permissions["/"] = portablePermissions portableGCSCredentials := "" if fsProvider == sdk.GCSFilesystemProvider && portableGCSCredentialsFile != "" { contents, err := getFileContents(portableGCSCredentialsFile) if err != nil { fmt.Printf("Unable to get GCS credentials: %v\n", err) os.Exit(1) } portableGCSCredentials = contents portableGCSAutoCredentials = 0 } portableSFTPPrivateKey := "" if fsProvider == sdk.SFTPFilesystemProvider && portableSFTPPrivateKeyPath != "" { contents, err := getFileContents(portableSFTPPrivateKeyPath) if err != nil { fmt.Printf("Unable to get SFTP private key: %v\n", err) os.Exit(1) } portableSFTPPrivateKey = contents } if portableFTPDPort >= 0 && portableFTPSCert != "" && portableFTPSKey != "" { keyPairs := []common.TLSKeyPair{ { Cert: 
portableFTPSCert, Key: portableFTPSKey, ID: common.DefaultTLSKeyPaidID, }, } _, err := common.NewCertManager(keyPairs, filepath.Clean(defaultConfigDir), "FTP portable") if err != nil { fmt.Printf("Unable to load FTPS key pair, cert file %q key file %q error: %v\n", portableFTPSCert, portableFTPSKey, err) os.Exit(1) } } if portableWebDAVPort >= 0 && portableWebDAVCert != "" && portableWebDAVKey != "" { keyPairs := []common.TLSKeyPair{ { Cert: portableWebDAVCert, Key: portableWebDAVKey, ID: common.DefaultTLSKeyPaidID, }, } _, err := common.NewCertManager(keyPairs, filepath.Clean(defaultConfigDir), "WebDAV portable") if err != nil { fmt.Printf("Unable to load WebDAV key pair, cert file %q key file %q error: %v\n", portableWebDAVCert, portableWebDAVKey, err) os.Exit(1) } } if portableHTTPPort >= 0 && portableHTTPSCert != "" && portableHTTPSKey != "" { keyPairs := []common.TLSKeyPair{ { Cert: portableHTTPSCert, Key: portableHTTPSKey, ID: common.DefaultTLSKeyPaidID, }, } _, err := common.NewCertManager(keyPairs, filepath.Clean(defaultConfigDir), "HTTP portable") if err != nil { fmt.Printf("Unable to load HTTPS key pair, cert file %q key file %q error: %v\n", portableHTTPSCert, portableHTTPSKey, err) os.Exit(1) } } pwd := portablePassword if portablePasswordFile != "" { content, err := os.ReadFile(portablePasswordFile) if err != nil { fmt.Printf("Unable to read password file %q: %v", portablePasswordFile, err) os.Exit(1) } pwd = strings.TrimSpace(util.BytesToString(content)) } service.SetGraceTime(graceTime) service := service.Service{ ConfigDir: util.CleanDirInput(configDir), ConfigFile: configFile, LogFilePath: portableLogFile, LogMaxSize: defaultLogMaxSize, LogMaxBackups: defaultLogMaxBackup, LogMaxAge: defaultLogMaxAge, LogCompress: defaultLogCompress, LogLevel: portableLogLevel, LogUTCTime: portableLogUTCTime, Shutdown: make(chan bool), PortableMode: 1, PortableUser: dataprovider.User{ BaseUser: sdk.BaseUser{ Username: portableUsername, Password: pwd, PublicKeys: 
portablePublicKeys, Permissions: permissions, HomeDir: portableDir, Status: 1, }, Filters: dataprovider.UserFilters{ BaseUserFilters: sdk.BaseUserFilters{ FilePatterns: parsePatternsFilesFilters(), StartDirectory: portableStartDir, }, }, FsConfig: vfs.Filesystem{ Provider: fsProvider, S3Config: vfs.S3FsConfig{ BaseS3FsConfig: sdk.BaseS3FsConfig{ Bucket: portableS3Bucket, Region: portableS3Region, AccessKey: portableS3AccessKey, RoleARN: portableS3RoleARN, Endpoint: portableS3Endpoint, StorageClass: portableS3StorageClass, ACL: portableS3ACL, KeyPrefix: portableS3KeyPrefix, UploadPartSize: int64(portableS3ULPartSize), UploadConcurrency: portableS3ULConcurrency, ForcePathStyle: portableS3ForcePathStyle, SkipTLSVerify: portableS3SkipTLSVerify, }, AccessSecret: kms.NewPlainSecret(portableS3AccessSecret), }, GCSConfig: vfs.GCSFsConfig{ BaseGCSFsConfig: sdk.BaseGCSFsConfig{ Bucket: portableGCSBucket, AutomaticCredentials: portableGCSAutoCredentials, StorageClass: portableGCSStorageClass, KeyPrefix: portableGCSKeyPrefix, }, Credentials: kms.NewPlainSecret(portableGCSCredentials), }, AzBlobConfig: vfs.AzBlobFsConfig{ BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ Container: portableAzContainer, AccountName: portableAzAccountName, Endpoint: portableAzEndpoint, AccessTier: portableAzAccessTier, KeyPrefix: portableAzKeyPrefix, UseEmulator: portableAzUseEmulator, UploadPartSize: int64(portableAzULPartSize), UploadConcurrency: portableAzULConcurrency, DownloadPartSize: int64(portableAzDLPartSize), DownloadConcurrency: portableAzDLConcurrency, }, AccountKey: kms.NewPlainSecret(portableAzAccountKey), SASURL: kms.NewPlainSecret(portableAzSASURL), }, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(portableCryptPassphrase), }, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: portableSFTPEndpoint, Username: portableSFTPUsername, Fingerprints: portableSFTPFingerprints, Prefix: portableSFTPPrefix, DisableCouncurrentReads: 
portableSFTPDisableConcurrentReads, BufferSize: portableSFTPDBufferSize, }, Password: kms.NewPlainSecret(portableSFTPPassword), PrivateKey: kms.NewPlainSecret(portableSFTPPrivateKey), KeyPassphrase: kms.NewEmptySecret(), }, }, }, } err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableHTTPPort, portableSSHCommands, portableFTPSCert, portableFTPSKey, portableWebDAVCert, portableWebDAVKey, portableHTTPSCert, portableHTTPSKey) if err == nil { service.Wait() if service.Error == nil { os.Exit(0) } } os.Exit(1) }, } ) func init() { version.AddFeature("+portable") portableCmd.Flags().StringVarP(&directoryToServe, "directory", "d", ".", `Path to the directory to serve. This can be an absolute path or a path relative to the current directory `) portableCmd.Flags().StringVar(&portableStartDir, "start-directory", "/", `Alternate start directory. This is a virtual path not a filesystem path`) portableCmd.Flags().IntVarP(&portableSFTPDPort, "sftpd-port", "s", 0, `0 means a random unprivileged port, < 0 disabled`) portableCmd.Flags().IntVar(&portableFTPDPort, "ftpd-port", -1, `0 means a random unprivileged port, < 0 disabled`) portableCmd.Flags().IntVar(&portableWebDAVPort, "webdav-port", -1, `0 means a random unprivileged port, < 0 disabled`) portableCmd.Flags().IntVar(&portableHTTPPort, "httpd-port", -1, `0 means a random unprivileged port, < 0 disabled`) portableCmd.Flags().StringSliceVar(&portableSSHCommands, "ssh-commands", sftpd.GetDefaultSSHCommands(), `SSH commands to enable. "*" means any supported SSH command including scp `) portableCmd.Flags().StringVarP(&portableUsername, "username", "u", "", `Leave empty to use an auto generated value`) portableCmd.Flags().StringVarP(&portablePassword, "password", "p", "", `Leave empty to use an auto generated value`) portableCmd.Flags().StringVar(&portablePasswordFile, "password-file", "", `Read the password from the specified file path. 
Leave empty to use an auto generated value`) portableCmd.Flags().StringVarP(&portableLogFile, logFilePathFlag, "l", "", "Leave empty to disable logging") portableCmd.Flags().StringVar(&portableLogLevel, logLevelFlag, defaultLogLevel, `Set the log level. Supported values: debug, info, warn, error. `) portableCmd.Flags().BoolVar(&portableLogUTCTime, logUTCTimeFlag, false, "Use UTC time for logging") portableCmd.Flags().StringSliceVarP(&portablePublicKeys, "public-key", "k", []string{}, "") portableCmd.Flags().StringSliceVarP(&portablePermissions, "permissions", "g", []string{"list", "download"}, `User's permissions. "*" means any permission`) portableCmd.Flags().StringArrayVar(&portableAllowedPatterns, "allowed-patterns", []string{}, `Allowed file patterns case insensitive. The format is: /dir::pattern1,pattern2. For example: "/somedir::*.jpg,a*b?.png"`) portableCmd.Flags().StringArrayVar(&portableDeniedPatterns, "denied-patterns", []string{}, `Denied file patterns case insensitive. The format is: /dir::pattern1,pattern2. 
For example: "/somedir::*.jpg,a*b?.png"`) portableCmd.Flags().StringVarP(&portableFsProvider, "fs-provider", "f", "osfs", `osfs => local filesystem (legacy value: 0) s3fs => AWS S3 compatible (legacy: 1) gcsfs => Google Cloud Storage (legacy: 2) azblobfs => Azure Blob Storage (legacy: 3) cryptfs => Encrypted local filesystem (legacy: 4) sftpfs => SFTP (legacy: 5)`) portableCmd.Flags().StringVar(&portableS3Bucket, "s3-bucket", "", "") portableCmd.Flags().StringVar(&portableS3Region, "s3-region", "", "") portableCmd.Flags().StringVar(&portableS3AccessKey, "s3-access-key", "", "") portableCmd.Flags().StringVar(&portableS3AccessSecret, "s3-access-secret", "", "") portableCmd.Flags().StringVar(&portableS3RoleARN, "s3-role-arn", "", "") portableCmd.Flags().StringVar(&portableS3Endpoint, "s3-endpoint", "", "") portableCmd.Flags().StringVar(&portableS3StorageClass, "s3-storage-class", "", "") portableCmd.Flags().StringVar(&portableS3ACL, "s3-acl", "", "") portableCmd.Flags().StringVar(&portableS3KeyPrefix, "s3-key-prefix", "", `Allows to restrict access to the virtual folder identified by this prefix and its contents`) portableCmd.Flags().IntVar(&portableS3ULPartSize, "s3-upload-part-size", 5, `The buffer size for multipart uploads (MB)`) portableCmd.Flags().IntVar(&portableS3ULConcurrency, "s3-upload-concurrency", 2, `How many parts are uploaded in parallel`) portableCmd.Flags().BoolVar(&portableS3ForcePathStyle, "s3-force-path-style", false, `Force path style bucket URL`) portableCmd.Flags().BoolVar(&portableS3SkipTLSVerify, "s3-skip-tls-verify", false, `If enabled the S3 client accepts any TLS certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. 
`) portableCmd.Flags().StringVar(&portableGCSBucket, "gcs-bucket", "", "") portableCmd.Flags().StringVar(&portableGCSStorageClass, "gcs-storage-class", "", "") portableCmd.Flags().StringVar(&portableGCSKeyPrefix, "gcs-key-prefix", "", `Allows to restrict access to the virtual folder identified by this prefix and its contents`) portableCmd.Flags().StringVar(&portableGCSCredentialsFile, "gcs-credentials-file", "", `Google Cloud Storage JSON credentials file`) portableCmd.Flags().IntVar(&portableGCSAutoCredentials, "gcs-automatic-credentials", 1, `0 means explicit credentials using a JSON credentials file, 1 automatic `) portableCmd.Flags().StringVar(&portableFTPSCert, "ftpd-cert", "", "Path to the certificate file for FTPS") portableCmd.Flags().StringVar(&portableFTPSKey, "ftpd-key", "", "Path to the key file for FTPS") portableCmd.Flags().StringVar(&portableWebDAVCert, "webdav-cert", "", `Path to the certificate file for WebDAV over HTTPS`) portableCmd.Flags().StringVar(&portableWebDAVKey, "webdav-key", "", `Path to the key file for WebDAV over HTTPS`) portableCmd.Flags().StringVar(&portableHTTPSCert, "httpd-cert", "", `Path to the certificate file for WebClient over HTTPS`) portableCmd.Flags().StringVar(&portableHTTPSKey, "httpd-key", "", `Path to the key file for WebClient over HTTPS`) portableCmd.Flags().StringVar(&portableAzContainer, "az-container", "", "") portableCmd.Flags().StringVar(&portableAzAccountName, "az-account-name", "", "") portableCmd.Flags().StringVar(&portableAzAccountKey, "az-account-key", "", "") portableCmd.Flags().StringVar(&portableAzSASURL, "az-sas-url", "", `Shared access signature URL`) portableCmd.Flags().StringVar(&portableAzEndpoint, "az-endpoint", "", `Leave empty to use the default: "blob.core.windows.net"`) portableCmd.Flags().StringVar(&portableAzAccessTier, "az-access-tier", "", `Leave empty to use the default container setting`) portableCmd.Flags().StringVar(&portableAzKeyPrefix, "az-key-prefix", "", `Allows to restrict access 
to the virtual folder identified by this prefix and its contents`) portableCmd.Flags().IntVar(&portableAzULPartSize, "az-upload-part-size", 5, `The buffer size for multipart uploads (MB)`) portableCmd.Flags().IntVar(&portableAzULConcurrency, "az-upload-concurrency", 5, `How many parts are uploaded in parallel`) portableCmd.Flags().IntVar(&portableAzDLPartSize, "az-download-part-size", 5, `The buffer size for multipart downloads (MB)`) portableCmd.Flags().IntVar(&portableAzDLConcurrency, "az-download-concurrency", 5, `How many parts are downloaded in parallel`) portableCmd.Flags().BoolVar(&portableAzUseEmulator, "az-use-emulator", false, "") portableCmd.Flags().StringVar(&portableCryptPassphrase, "crypto-passphrase", "", `Passphrase for encryption/decryption`) portableCmd.Flags().StringVar(&portableSFTPEndpoint, "sftp-endpoint", "", `SFTP endpoint as host:port for SFTP provider`) portableCmd.Flags().StringVar(&portableSFTPUsername, "sftp-username", "", `SFTP user for SFTP provider`) portableCmd.Flags().StringVar(&portableSFTPPassword, "sftp-password", "", `SFTP password for SFTP provider`) portableCmd.Flags().StringVar(&portableSFTPPrivateKeyPath, "sftp-key-path", "", `SFTP private key path for SFTP provider`) portableCmd.Flags().StringSliceVar(&portableSFTPFingerprints, "sftp-fingerprints", []string{}, `SFTP fingerprints to verify remote host key for SFTP provider`) portableCmd.Flags().StringVar(&portableSFTPPrefix, "sftp-prefix", "", `SFTP prefix allows restrict all operations to a given path within the remote SFTP server`) portableCmd.Flags().BoolVar(&portableSFTPDisableConcurrentReads, "sftp-disable-concurrent-reads", false, `Concurrent reads are safe to use and disabling them will degrade performance. Disable for read once servers`) portableCmd.Flags().Int64Var(&portableSFTPDBufferSize, "sftp-buffer-size", 0, `The size of the buffer (in MB) to use for transfers. 
By enabling buffering, the reads and writes, from/to the remote SFTP server, are split in multiple concurrent requests and this allows data to be transferred at a faster rate, over high latency networks, by overlapping round-trip times`) portableCmd.Flags().IntVar(&graceTime, graceTimeFlag, 0, `This grace time defines the number of seconds allowed for existing transfers to get completed before shutting down. A graceful shutdown is triggered by an interrupt signal. `) addConfigFlags(portableCmd) rootCmd.AddCommand(portableCmd) } func parsePatternsFilesFilters() []sdk.PatternsFilter { var patterns []sdk.PatternsFilter for _, val := range portableAllowedPatterns { p, exts := getPatternsFilterValues(strings.TrimSpace(val)) if p != "" { patterns = append(patterns, sdk.PatternsFilter{ Path: path.Clean(p), AllowedPatterns: exts, DeniedPatterns: []string{}, }) } } for _, val := range portableDeniedPatterns { p, exts := getPatternsFilterValues(strings.TrimSpace(val)) if p != "" { found := false for index, e := range patterns { if path.Clean(e.Path) == path.Clean(p) { patterns[index].DeniedPatterns = append(patterns[index].DeniedPatterns, exts...) 
found = true break } } if !found { patterns = append(patterns, sdk.PatternsFilter{ Path: path.Clean(p), AllowedPatterns: []string{}, DeniedPatterns: exts, }) } } } return patterns } func getPatternsFilterValues(value string) (string, []string) { if strings.Contains(value, "::") { dirExts := strings.Split(value, "::") if len(dirExts) > 1 { dir := strings.TrimSpace(dirExts[0]) exts := []string{} for e := range strings.SplitSeq(dirExts[1], ",") { cleanedExt := strings.TrimSpace(e) if cleanedExt != "" { exts = append(exts, cleanedExt) } } if dir != "" && len(exts) > 0 { return dir, exts } } } return "", nil } func getFileContents(name string) (string, error) { fi, err := os.Stat(name) if err != nil { return "", err } if fi.Size() > 1048576 { return "", fmt.Errorf("%q is too big %v/1048576 bytes", name, fi.Size()) } contents, err := os.ReadFile(name) if err != nil { return "", err } return util.BytesToString(contents), nil } func convertFsProvider() string { switch portableFsProvider { case "osfs", "6": // httpfs (6) is not supported in portable mode, so return the default return "0" case "s3fs": return "1" case "gcsfs": return "2" case "azblobfs": return "3" case "cryptfs": return "4" case "sftpfs": return "5" default: return portableFsProvider } } ================================================ FILE: internal/cmd/portable_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. 
If not, see . //go:build noportable package cmd import "github.com/drakkan/sftpgo/v2/internal/version" func init() { version.AddFeature("-portable") } ================================================ FILE: internal/cmd/reload_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "fmt" "os" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/service" ) var ( reloadCmd = &cobra.Command{ Use: "reload", Short: "Reload the SFTPGo Windows Service sending a \"paramchange\" request", Run: func(_ *cobra.Command, _ []string) { s := service.WindowsService{ Service: service.Service{ Shutdown: make(chan bool), }, } err := s.Reload() if err != nil { fmt.Printf("Error sending reload signal: %v\r\n", err) os.Exit(1) } else { fmt.Printf("Reload signal sent!\r\n") } }, } ) func init() { serviceCmd.AddCommand(reloadCmd) } ================================================ FILE: internal/cmd/resetprovider.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "bufio" "os" "strings" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( resetProviderForce bool resetProviderCmd = &cobra.Command{ Use: "resetprovider", Short: "Reset the configured provider, any data will be lost", Long: `This command reads the data provider connection details from the specified configuration file and resets the provider by deleting all data and schemas. This command is not supported for the memory provider. Please take a look at the usage below to customize the options.`, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.WarnToConsole("Unable to load configuration: %v", err) os.Exit(1) } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("unable to initialize KMS: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() if !resetProviderForce { logger.WarnToConsole("You are about to delete all the SFTPGo data for provider %q, config file: %q", providerConf.Driver, viper.ConfigFileUsed()) logger.WarnToConsole("Are you sure? 
(Y/n)") reader := bufio.NewReader(os.Stdin) answer, err := reader.ReadString('\n') if err != nil { logger.ErrorToConsole("unable to read your answer: %v", err) os.Exit(1) } if strings.ToUpper(strings.TrimSpace(answer)) != "Y" { logger.InfoToConsole("command aborted") os.Exit(1) } } logger.InfoToConsole("Resetting provider: %q, config file: %q", providerConf.Driver, viper.ConfigFileUsed()) err = dataprovider.ResetDatabase(providerConf, configDir) if err != nil { logger.WarnToConsole("Error resetting provider: %v", err) os.Exit(1) } logger.InfoToConsole("Tha data provider was successfully reset") }, } ) func init() { addConfigFlags(resetProviderCmd) resetProviderCmd.Flags().BoolVar(&resetProviderForce, "force", false, `reset the provider without asking for confirmation`) rootCmd.AddCommand(resetProviderCmd) } ================================================ FILE: internal/cmd/resetpwd.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package cmd import ( "bytes" "fmt" "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/viper" "golang.org/x/term" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( resetPwdAdmin string resetPwdCmd = &cobra.Command{ Use: "resetpwd", Short: "Reset the password for the specified administrator", Long: `This command reads the data provider connection details from the specified configuration file and resets the password for the specified administrator. Two-factor authentication is also disabled. This command is not supported for the memory provider. For embedded providers like bolt and SQLite you should stop the running SFTPGo instance to avoid database corruption. Please take a look at the usage below to customize the options.`, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.WarnToConsole("Unable to load configuration: %v", err) os.Exit(1) } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("unable to initialize KMS: %v", err) os.Exit(1) } if config.HasKMSPlugin() { if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil { logger.ErrorToConsole("unable to initialize plugin system: %v", err) os.Exit(1) } registerSignals() defer plugin.Handler.Cleanup() } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("Unable to initialize MFA: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() if providerConf.Driver == dataprovider.MemoryDataProviderName { logger.ErrorToConsole("memory provider is not supported") os.Exit(1) } logger.InfoToConsole("Initializing 
provider: %q config file: %q", providerConf.Driver, viper.ConfigFileUsed()) err = dataprovider.Initialize(providerConf, configDir, false) if err != nil { logger.ErrorToConsole("Unable to initialize data provider: %v", err) os.Exit(1) } admin, err := dataprovider.AdminExists(resetPwdAdmin) if err != nil { logger.ErrorToConsole("Unable to get admin %q: %v", resetPwdAdmin, err) os.Exit(1) } fmt.Printf("Enter Password: ") pwd, err := term.ReadPassword(int(os.Stdin.Fd())) if err != nil { logger.ErrorToConsole("Unable to read the password: %v", err) os.Exit(1) } fmt.Println("") fmt.Printf("Confirm Password: ") confirmPwd, err := term.ReadPassword(int(os.Stdin.Fd())) if err != nil { logger.ErrorToConsole("Unable to read the password: %v", err) os.Exit(1) } fmt.Println("") if !bytes.Equal(pwd, confirmPwd) { logger.ErrorToConsole("Passwords do not match") os.Exit(1) } admin.Password = string(pwd) admin.Filters.TOTPConfig.Enabled = false if err := dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSystem, "", ""); err != nil { logger.ErrorToConsole("Unable to update password: %v", err) os.Exit(1) } logger.InfoToConsole("Password updated for admin %q", resetPwdAdmin) }, } ) func init() { addConfigFlags(resetPwdCmd) resetPwdCmd.Flags().StringVar(&resetPwdAdmin, "admin", "", `Administrator username whose password to reset`) resetPwdCmd.MarkFlagRequired("admin") //nolint:errcheck rootCmd.AddCommand(resetPwdCmd) } ================================================ FILE: internal/cmd/revertprovider.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( revertProviderTargetVersion int revertProviderCmd = &cobra.Command{ Use: "revertprovider", Short: "Revert the configured data provider to a previous version", Long: `This command reads the data provider connection details from the specified configuration file and restore the provider schema and/or data to a previous version. This command is not supported for the memory provider. Please take a look at the usage below to customize the options.`, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) if revertProviderTargetVersion != 33 { logger.WarnToConsole("Unsupported target version, 33 is the only supported one") os.Exit(1) } configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.WarnToConsole("Unable to load configuration: %v", err) os.Exit(1) } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("unable to initialize KMS: %v", err) os.Exit(1) } if config.HasKMSPlugin() { if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil { logger.ErrorToConsole("unable to initialize plugin system: %v", err) os.Exit(1) } registerSignals() defer plugin.Handler.Cleanup() } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("Unable to initialize MFA: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() 
logger.InfoToConsole("Reverting provider: %q config file: %q target version %d", providerConf.Driver, viper.ConfigFileUsed(), revertProviderTargetVersion) err = dataprovider.RevertDatabase(providerConf, configDir, revertProviderTargetVersion) if err != nil { logger.WarnToConsole("Error reverting provider: %v", err) os.Exit(1) } logger.InfoToConsole("Data provider successfully reverted") }, } ) func init() { addConfigFlags(revertProviderCmd) revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 33, `33 means the version supported in v2.7.x`) rootCmd.AddCommand(revertProviderCmd) } ================================================ FILE: internal/cmd/root.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package cmd provides Command Line Interface support package cmd import ( "fmt" "os" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( configDirFlag = "config-dir" configDirKey = "config_dir" configFileFlag = "config-file" configFileKey = "config_file" logFilePathFlag = "log-file-path" logFilePathKey = "log_file_path" logMaxSizeFlag = "log-max-size" logMaxSizeKey = "log_max_size" logMaxBackupFlag = "log-max-backups" logMaxBackupKey = "log_max_backups" logMaxAgeFlag = "log-max-age" logMaxAgeKey = "log_max_age" logCompressFlag = "log-compress" logCompressKey = "log_compress" logLevelFlag = "log-level" logLevelKey = "log_level" logUTCTimeFlag = "log-utc-time" logUTCTimeKey = "log_utc_time" loadDataFromFlag = "loaddata-from" loadDataFromKey = "loaddata_from" loadDataModeFlag = "loaddata-mode" loadDataModeKey = "loaddata_mode" loadDataQuotaScanFlag = "loaddata-scan" loadDataQuotaScanKey = "loaddata_scan" loadDataCleanFlag = "loaddata-clean" loadDataCleanKey = "loaddata_clean" graceTimeFlag = "grace-time" graceTimeKey = "grace_time" defaultConfigDir = "." 
defaultConfigFile = "" defaultLogFile = "sftpgo.log" defaultLogMaxSize = 10 defaultLogMaxBackup = 5 defaultLogMaxAge = 28 defaultLogCompress = false defaultLogLevel = "debug" defaultLogUTCTime = false defaultLoadDataFrom = "" defaultLoadDataMode = 1 defaultLoadDataQuotaScan = 0 defaultLoadDataClean = false defaultGraceTime = 0 ) var ( configDir string configFile string logFilePath string logMaxSize int logMaxBackups int logMaxAge int logCompress bool logLevel string logUTCTime bool loadDataFrom string loadDataMode int loadDataQuotaScan int loadDataClean bool graceTime int rootCmd = &cobra.Command{ Use: "sftpgo", Short: "Full-featured and highly configurable file transfer server", } ) func init() { rootCmd.CompletionOptions.DisableDefaultCmd = true rootCmd.Flags().BoolP("version", "v", false, "") rootCmd.Version = version.GetAsString() rootCmd.SetVersionTemplate(`{{printf "SFTPGo "}}{{printf "%s" .Version}} `) } // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { if err := rootCmd.Execute(); err != nil { fmt.Println(err) os.Exit(1) } } func addConfigFlags(cmd *cobra.Command) { viper.SetDefault(configDirKey, defaultConfigDir) viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint:errcheck // err is not nil only if the key to bind is missing cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey), `Location of the config dir. This directory is used as the base for files with a relative path, e.g. the private keys for the SFTP server or the database file if you use a file-based data provider. The configuration file, if not explicitly set, is looked for in this dir. We support reading from JSON, TOML, YAML, HCL, envfile and Java properties config files. The default config file name is "sftpgo" and therefore "sftpgo.json", "sftpgo.yaml" and so on are searched. 
This flag can be set using SFTPGO_CONFIG_DIR env var too.`) viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) //nolint:errcheck viper.SetDefault(configFileKey, defaultConfigFile) viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint:errcheck cmd.Flags().StringVar(&configFile, configFileFlag, viper.GetString(configFileKey), `Path to SFTPGo configuration file. This flag explicitly defines the path, name and extension of the config file. If must be an absolute path or a path relative to the configuration directory. The specified file name must have a supported extension (JSON, YAML, TOML, HCL or Java properties). This flag can be set using SFTPGO_CONFIG_FILE env var too.`) viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) //nolint:errcheck } func addBaseLoadDataFlags(cmd *cobra.Command) { viper.SetDefault(loadDataFromKey, defaultLoadDataFrom) viper.BindEnv(loadDataFromKey, "SFTPGO_LOADDATA_FROM") //nolint:errcheck cmd.Flags().StringVar(&loadDataFrom, loadDataFromFlag, viper.GetString(loadDataFromKey), `Load users and folders from this file. The file must be specified as absolute path and it must contain a backup obtained using the "dumpdata" REST API or compatible content. This flag can be set using SFTPGO_LOADDATA_FROM env var too. `) viper.BindPFlag(loadDataFromKey, cmd.Flags().Lookup(loadDataFromFlag)) //nolint:errcheck viper.SetDefault(loadDataModeKey, defaultLoadDataMode) viper.BindEnv(loadDataModeKey, "SFTPGO_LOADDATA_MODE") //nolint:errcheck cmd.Flags().IntVar(&loadDataMode, loadDataModeFlag, viper.GetInt(loadDataModeKey), `Restore mode for data to load: 0 - new users are added, existing users are updated 1 - New users are added, existing users are not modified This flag can be set using SFTPGO_LOADDATA_MODE env var too. 
`) viper.BindPFlag(loadDataModeKey, cmd.Flags().Lookup(loadDataModeFlag)) //nolint:errcheck viper.SetDefault(loadDataCleanKey, defaultLoadDataClean) viper.BindEnv(loadDataCleanKey, "SFTPGO_LOADDATA_CLEAN") //nolint:errcheck cmd.Flags().BoolVar(&loadDataClean, loadDataCleanFlag, viper.GetBool(loadDataCleanKey), `Determine if the loaddata-from file should be removed after a successful load. This flag can be set using SFTPGO_LOADDATA_CLEAN env var too. (default "false") `) viper.BindPFlag(loadDataCleanKey, cmd.Flags().Lookup(loadDataCleanFlag)) //nolint:errcheck } func addServeFlags(cmd *cobra.Command) { addConfigFlags(cmd) viper.SetDefault(logFilePathKey, defaultLogFile) viper.BindEnv(logFilePathKey, "SFTPGO_LOG_FILE_PATH") //nolint:errcheck cmd.Flags().StringVarP(&logFilePath, logFilePathFlag, "l", viper.GetString(logFilePathKey), `Location for the log file. Leave empty to write logs to the standard output. This flag can be set using SFTPGO_LOG_FILE_PATH env var too. `) viper.BindPFlag(logFilePathKey, cmd.Flags().Lookup(logFilePathFlag)) //nolint:errcheck viper.SetDefault(logMaxSizeKey, defaultLogMaxSize) viper.BindEnv(logMaxSizeKey, "SFTPGO_LOG_MAX_SIZE") //nolint:errcheck cmd.Flags().IntVarP(&logMaxSize, logMaxSizeFlag, "s", viper.GetInt(logMaxSizeKey), `Maximum size in megabytes of the log file before it gets rotated. This flag can be set using SFTPGO_LOG_MAX_SIZE env var too. It is unused if log-file-path is empty. `) viper.BindPFlag(logMaxSizeKey, cmd.Flags().Lookup(logMaxSizeFlag)) //nolint:errcheck viper.SetDefault(logMaxBackupKey, defaultLogMaxBackup) viper.BindEnv(logMaxBackupKey, "SFTPGO_LOG_MAX_BACKUPS") //nolint:errcheck cmd.Flags().IntVarP(&logMaxBackups, "log-max-backups", "b", viper.GetInt(logMaxBackupKey), `Maximum number of old log files to retain. This flag can be set using SFTPGO_LOG_MAX_BACKUPS env var too. 
It is unused if log-file-path is empty.`) viper.BindPFlag(logMaxBackupKey, cmd.Flags().Lookup(logMaxBackupFlag)) //nolint:errcheck viper.SetDefault(logMaxAgeKey, defaultLogMaxAge) viper.BindEnv(logMaxAgeKey, "SFTPGO_LOG_MAX_AGE") //nolint:errcheck cmd.Flags().IntVarP(&logMaxAge, "log-max-age", "a", viper.GetInt(logMaxAgeKey), `Maximum number of days to retain old log files. This flag can be set using SFTPGO_LOG_MAX_AGE env var too. It is unused if log-file-path is empty. `) viper.BindPFlag(logMaxAgeKey, cmd.Flags().Lookup(logMaxAgeFlag)) //nolint:errcheck viper.SetDefault(logCompressKey, defaultLogCompress) viper.BindEnv(logCompressKey, "SFTPGO_LOG_COMPRESS") //nolint:errcheck cmd.Flags().BoolVarP(&logCompress, logCompressFlag, "z", viper.GetBool(logCompressKey), `Determine if the rotated log files should be compressed using gzip. This flag can be set using SFTPGO_LOG_COMPRESS env var too. It is unused if log-file-path is empty. `) viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) //nolint:errcheck viper.SetDefault(logLevelKey, defaultLogLevel) viper.BindEnv(logLevelKey, "SFTPGO_LOG_LEVEL") //nolint:errcheck cmd.Flags().StringVar(&logLevel, logLevelFlag, viper.GetString(logLevelKey), `Set the log level. Supported values: debug, info, warn, error. This flag can be set using SFTPGO_LOG_LEVEL env var too. `) viper.BindPFlag(logLevelKey, cmd.Flags().Lookup(logLevelFlag)) //nolint:errcheck viper.SetDefault(logUTCTimeKey, defaultLogUTCTime) viper.BindEnv(logUTCTimeKey, "SFTPGO_LOG_UTC_TIME") //nolint:errcheck cmd.Flags().BoolVar(&logUTCTime, logUTCTimeFlag, viper.GetBool(logUTCTimeKey), `Use UTC time for logging. This flag can be set using SFTPGO_LOG_UTC_TIME env var too. 
`) viper.BindPFlag(logUTCTimeKey, cmd.Flags().Lookup(logUTCTimeFlag)) //nolint:errcheck addBaseLoadDataFlags(cmd) viper.SetDefault(loadDataQuotaScanKey, defaultLoadDataQuotaScan) viper.BindEnv(loadDataQuotaScanKey, "SFTPGO_LOADDATA_QUOTA_SCAN") //nolint:errcheck cmd.Flags().IntVar(&loadDataQuotaScan, loadDataQuotaScanFlag, viper.GetInt(loadDataQuotaScanKey), `Quota scan mode after data load: 0 - no quota scan 1 - scan quota 2 - scan quota if the user has quota restrictions This flag can be set using SFTPGO_LOADDATA_QUOTA_SCAN env var too. (default 0)`) viper.BindPFlag(loadDataQuotaScanKey, cmd.Flags().Lookup(loadDataQuotaScanFlag)) //nolint:errcheck viper.SetDefault(graceTimeKey, defaultGraceTime) viper.BindEnv(graceTimeKey, "SFTPGO_GRACE_TIME") //nolint:errcheck cmd.Flags().IntVar(&graceTime, graceTimeFlag, viper.GetInt(graceTimeKey), `Graceful shutdown is an option to initiate a shutdown without abrupt cancellation of the currently ongoing client-initiated transfer sessions. This grace time defines the number of seconds allowed for existing transfers to get completed before shutting down. A graceful shutdown is triggered by an interrupt signal. This flag can be set using SFTPGO_GRACE_TIME env var too. 0 means disabled. (default 0)`) viper.BindPFlag(graceTimeKey, cmd.Flags().Lookup(graceTimeFlag)) //nolint:errcheck } ================================================ FILE: internal/cmd/rotatelogs_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "fmt" "os" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/service" ) var ( rotateLogCmd = &cobra.Command{ Use: "rotatelogs", Short: "Signal to the running service to rotate the logs", Run: func(_ *cobra.Command, _ []string) { s := service.WindowsService{ Service: service.Service{ Shutdown: make(chan bool), }, } err := s.RotateLogFile() if err != nil { fmt.Printf("Error sending rotate log file signal to the service: %v\r\n", err) os.Exit(1) } else { fmt.Printf("Rotate log file signal sent!\r\n") } }, } ) func init() { serviceCmd.AddCommand(rotateLogCmd) } ================================================ FILE: internal/cmd/serve.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package cmd

import (
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/spf13/cobra"
	"github.com/subosito/gotenv"

	"github.com/drakkan/sftpgo/v2/internal/service"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

const (
	// envFileMaxSize is the maximum size, in bytes, of a file in the env.d
	// directory; larger files are skipped.
	envFileMaxSize = 1048576
)

var (
	// serveCmd starts SFTPGo in the foreground and exits with status 0 only
	// if the service starts and later stops without error.
	serveCmd = &cobra.Command{
		Use:   "serve",
		Short: "Start the SFTPGo service",
		Long: `To start the SFTPGo with the default values for the command line flags simply use: $ sftpgo serve Please take a look at the usage below to customize the startup options`,
		Run: func(_ *cobra.Command, _ []string) {
			// local copy: the package-level configDir is left untouched
			configDir := util.CleanDirInput(configDir)
			checkServeParamsFromEnvFiles(configDir)
			service.SetGraceTime(graceTime)
			// named s so that the service package is not shadowed
			s := service.Service{
				ConfigDir:         configDir,
				ConfigFile:        configFile,
				LogFilePath:       logFilePath,
				LogMaxSize:        logMaxSize,
				LogMaxBackups:     logMaxBackups,
				LogMaxAge:         logMaxAge,
				LogCompress:       logCompress,
				LogLevel:          logLevel,
				LogUTCTime:        logUTCTime,
				LoadDataFrom:      loadDataFrom,
				LoadDataMode:      loadDataMode,
				LoadDataQuotaScan: loadDataQuotaScan,
				LoadDataClean:     loadDataClean,
				Shutdown:          make(chan bool),
			}
			if err := s.Start(); err == nil {
				s.Wait()
				if s.Error == nil {
					os.Exit(0)
				}
			}
			os.Exit(1)
		},
	}
)

// setIntFromEnv parses val, ignoring surrounding whitespace, as a base-10
// integer and stores it in receiver. Invalid values leave receiver unchanged.
func setIntFromEnv(receiver *int, val string) {
	converted, err := strconv.Atoi(strings.TrimSpace(val))
	if err == nil {
		*receiver = converted
	}
}

// setBoolFromEnv parses val, ignoring surrounding whitespace, with
// strconv.ParseBool and stores it in receiver. Invalid values leave receiver unchanged.
func setBoolFromEnv(receiver *bool, val string) {
	converted, err := strconv.ParseBool(strings.TrimSpace(val))
	if err == nil {
		*receiver = converted
	}
}

// checkServeParamsFromEnvFiles reads the regular files in <configDir>/env.d
// (skipping those larger than envFileMaxSize or that cannot be parsed) and
// applies the known SFTPGO_* serve settings they define. Variables already set
// in the process environment are never overridden.
func checkServeParamsFromEnvFiles(configDir string) { //nolint:gocyclo
	// The logger is not yet initialized here, we have no way to report errors.
	envd := filepath.Join(configDir, "env.d")
	entries, err := os.ReadDir(envd)
	if err != nil {
		return
	}

	for _, entry := range entries {
		info, err := entry.Info()
		if err != nil || !info.Mode().IsRegular() {
			continue
		}
		if info.Size() > envFileMaxSize {
			continue
		}
		envFile := filepath.Join(envd, entry.Name())
		envVars, err := gotenv.Read(envFile)
		if err != nil {
			// skip only this malformed file, the remaining ones must still be applied
			continue
		}
		for k, v := range envVars {
			if _, isSet := os.LookupEnv(k); isSet {
				continue
			}
			switch k {
			case "SFTPGO_LOG_FILE_PATH":
				logFilePath = v
			case "SFTPGO_LOG_MAX_SIZE":
				setIntFromEnv(&logMaxSize, v)
			case "SFTPGO_LOG_MAX_BACKUPS":
				setIntFromEnv(&logMaxBackups, v)
			case "SFTPGO_LOG_MAX_AGE":
				setIntFromEnv(&logMaxAge, v)
			case "SFTPGO_LOG_COMPRESS":
				setBoolFromEnv(&logCompress, v)
			case "SFTPGO_LOG_LEVEL":
				logLevel = v
			case "SFTPGO_LOG_UTC_TIME":
				setBoolFromEnv(&logUTCTime, v)
			case "SFTPGO_CONFIG_FILE":
				configFile = v
			case "SFTPGO_LOADDATA_FROM":
				loadDataFrom = v
			case "SFTPGO_LOADDATA_MODE":
				setIntFromEnv(&loadDataMode, v)
			case "SFTPGO_LOADDATA_CLEAN":
				setBoolFromEnv(&loadDataClean, v)
			case "SFTPGO_LOADDATA_QUOTA_SCAN":
				setIntFromEnv(&loadDataQuotaScan, v)
			case "SFTPGO_GRACE_TIME":
				setIntFromEnv(&graceTime, v)
			}
		}
	}
}

// init registers the serve command and its flags on the root command.
func init() {
	rootCmd.AddCommand(serveCmd)
	addServeFlags(serveCmd)
}

================================================
FILE: internal/cmd/service_windows.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package cmd

import (
	"github.com/spf13/cobra"
)

// serviceCmd is the parent of the Windows service management subcommands,
// which attach themselves to it from their own init functions.
var serviceCmd = &cobra.Command{
	Use:   "service",
	Short: "Manage the SFTPGo Windows Service",
}

// init registers the service command on the root command.
func init() {
	rootCmd.AddCommand(serviceCmd)
}

================================================
FILE: internal/cmd/signals_unix.go
================================================
// Copyright (C) 2025 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

//go:build !windows

package cmd

import (
	"os"
	"os/signal"
	"syscall"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
)

// registerSignals starts a goroutine that, on SIGINT or SIGTERM, logs the
// request, cleans up the plugin system and exits the process with status 0.
func registerSignals() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		for received := range sigCh {
			if received != syscall.SIGINT && received != syscall.SIGTERM {
				continue
			}
			logger.DebugToConsole("Received interrupt request")
			plugin.Handler.Cleanup()
			os.Exit(0)
		}
	}()
}

================================================
FILE: internal/cmd/signals_windows.go
================================================
// Copyright (C) 2025 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "os" "os/signal" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" ) func registerSignals() { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { for range c { logger.DebugToConsole("Received interrupt request") plugin.Handler.Cleanup() os.Exit(0) } }() } ================================================ FILE: internal/cmd/smtptest.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "os" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( smtpTestRecipient string smtpTestCmd = &cobra.Command{ Use: "smtptest", Short: "Test the SMTP configuration", Long: `SFTPGo will try to send a test email to the specified recipient. 
If the SMTP configuration is correct you should receive this email.`, Run: func(_ *cobra.Command, _ []string) { logger.DisableLogger() logger.EnableConsoleLogger(zerolog.DebugLevel) configDir = util.CleanDirInput(configDir) err := config.LoadConfig(configDir, configFile) if err != nil { logger.ErrorToConsole("Unable to load configuration: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, false) if err != nil { logger.ErrorToConsole("error initializing data provider: %v", err) os.Exit(1) } smtpConfig := config.GetSMTPConfig() smtpConfig.Debug = 1 err = smtpConfig.Initialize(configDir, false) if err != nil { logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err) os.Exit(1) } err = smtp.SendEmail([]string{smtpTestRecipient}, nil, "SFTPGo - Testing Email Settings", "It appears your SFTPGo email is setup correctly!", smtp.EmailContentTypeTextPlain) if err != nil { logger.WarnToConsole("Error sending email: %v", err) os.Exit(1) } logger.InfoToConsole("No errors were reported while sending the test email. Please check your inbox to make sure.") }, } ) func init() { addConfigFlags(smtpTestCmd) smtpTestCmd.Flags().StringVar(&smtpTestRecipient, "recipient", "", `email address to send the test e-mail to`) smtpTestCmd.MarkFlagRequired("recipient") //nolint:errcheck rootCmd.AddCommand(smtpTestCmd) } ================================================ FILE: internal/cmd/start_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package cmd import ( "fmt" "os" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/service" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( startCmd = &cobra.Command{ Use: "start", Short: "Start the SFTPGo Windows Service", Run: func(_ *cobra.Command, _ []string) { configDir = util.CleanDirInput(configDir) checkServeParamsFromEnvFiles(configDir) service.SetGraceTime(graceTime) s := service.Service{ ConfigDir: configDir, ConfigFile: configFile, LogFilePath: logFilePath, LogMaxSize: logMaxSize, LogMaxBackups: logMaxBackups, LogMaxAge: logMaxAge, LogCompress: logCompress, LogLevel: logLevel, LogUTCTime: logUTCTime, LoadDataFrom: loadDataFrom, LoadDataMode: loadDataMode, LoadDataQuotaScan: loadDataQuotaScan, LoadDataClean: loadDataClean, Shutdown: make(chan bool), } winService := service.WindowsService{ Service: s, } err := winService.RunService() if err != nil { fmt.Printf("Error starting service: %v\r\n", err) os.Exit(1) } else { fmt.Printf("Service started!\r\n") } }, } ) func init() { serviceCmd.AddCommand(startCmd) addServeFlags(startCmd) } ================================================ FILE: internal/cmd/status_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. 
If not, see . package cmd import ( "fmt" "os" "github.com/spf13/cobra" "github.com/drakkan/sftpgo/v2/internal/service" ) var ( statusCmd = &cobra.Command{ Use: "status", Short: "Retrieve the status for the SFTPGo Windows Service", Run: func(_ *cobra.Command, _ []string) { s := service.WindowsService{ Service: service.Service{ Shutdown: make(chan bool), }, } status, err := s.Status() if err != nil { fmt.Printf("Error querying service status: %v\r\n", err) os.Exit(1) } else { fmt.Printf("Service status: %q\r\n", status.String()) } }, } ) func init() { serviceCmd.AddCommand(statusCmd) } ================================================ FILE: internal/cmd/stop_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package cmd

import (
	"fmt"
	"os"

	"github.com/spf13/cobra"

	"github.com/drakkan/sftpgo/v2/internal/service"
)

var (
	// stopCmd stops the SFTPGo Windows service, exiting with status 1 on failure.
	stopCmd = &cobra.Command{
		Use:   "stop",
		Short: "Stop the SFTPGo Windows Service",
		Run: func(_ *cobra.Command, _ []string) {
			winService := service.WindowsService{
				Service: service.Service{Shutdown: make(chan bool)},
			}
			if err := winService.Stop(); err != nil {
				fmt.Printf("Error stopping service: %v\r\n", err)
				os.Exit(1)
			}
			fmt.Printf("Service stopped!\r\n")
		},
	}
)

// init attaches the stop command under the service command.
func init() {
	serviceCmd.AddCommand(stopCmd)
}

================================================
FILE: internal/cmd/uninstall_windows.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package cmd

import (
	"fmt"
	"os"

	"github.com/spf13/cobra"

	"github.com/drakkan/sftpgo/v2/internal/service"
)

var (
	// uninstallCmd removes the SFTPGo Windows service, exiting with status 1 on failure.
	uninstallCmd = &cobra.Command{
		Use:   "uninstall",
		Short: "Uninstall the SFTPGo Windows Service",
		Run: func(_ *cobra.Command, _ []string) {
			winService := service.WindowsService{
				Service: service.Service{Shutdown: make(chan bool)},
			}
			if err := winService.Uninstall(); err != nil {
				fmt.Printf("Error removing service: %v\r\n", err)
				os.Exit(1)
			}
			fmt.Printf("Service uninstalled\r\n")
		},
	}
)

// init attaches the uninstall command under the service command.
func init() {
	serviceCmd.AddCommand(uninstallCmd)
}

================================================
FILE: internal/command/command.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
// Package command provides command configuration for SFTPGo hooks package command import ( "fmt" "slices" "strings" "time" ) const ( minTimeout = 1 maxTimeout = 300 defaultTimeout = 30 ) // Supported hook names const ( HookFsActions = "fs_actions" HookProviderActions = "provider_actions" HookStartup = "startup" HookPostConnect = "post_connect" HookPostDisconnect = "post_disconnect" HookCheckPassword = "check_password" HookPreLogin = "pre_login" HookPostLogin = "post_login" HookExternalAuth = "external_auth" HookKeyboardInteractive = "keyboard_interactive" ) var ( config Config supportedHooks = []string{HookFsActions, HookProviderActions, HookStartup, HookPostConnect, HookPostDisconnect, HookCheckPassword, HookPreLogin, HookPostLogin, HookExternalAuth, HookKeyboardInteractive} ) // Command define the configuration for a specific commands type Command struct { // Path is the command path as defined in the hook configuration Path string `json:"path" mapstructure:"path"` // Timeout specifies a time limit, in seconds, for the command execution. // This value overrides the global timeout if set. // Do not use variables with the SFTPGO_ prefix to avoid conflicts with env // vars that SFTPGo sets Timeout int `json:"timeout" mapstructure:"timeout"` // Env defines environment variable for the command. // Each entry is of the form "key=value". // These values are added to the global environment variables if any Env []string `json:"env" mapstructure:"env"` // Args defines arguments to pass to the specified command Args []string `json:"args" mapstructure:"args"` // if not empty both command path and hook name must match Hook string `json:"hook" mapstructure:"hook"` } // Config defines the configuration for external commands such as // program based hooks type Config struct { // Timeout specifies a global time limit, in seconds, for the external commands execution Timeout int `json:"timeout" mapstructure:"timeout"` // Env defines environment variable for the commands. 
// Each entry is of the form "key=value". // Do not use variables with the SFTPGO_ prefix to avoid conflicts with env // vars that SFTPGo sets Env []string `json:"env" mapstructure:"env"` // Commands defines configuration for specific commands Commands []Command `json:"commands" mapstructure:"commands"` } func init() { config = Config{ Timeout: defaultTimeout, } } // Initialize configures commands func (c Config) Initialize() error { if c.Timeout < minTimeout || c.Timeout > maxTimeout { return fmt.Errorf("invalid timeout %v", c.Timeout) } for _, env := range c.Env { if len(strings.SplitN(env, "=", 2)) != 2 { return fmt.Errorf("invalid env var %q", env) } } for idx, cmd := range c.Commands { if cmd.Path == "" { return fmt.Errorf("invalid path %q", cmd.Path) } if cmd.Timeout == 0 { c.Commands[idx].Timeout = c.Timeout } else { if cmd.Timeout < minTimeout || cmd.Timeout > maxTimeout { return fmt.Errorf("invalid timeout %v for command %q", cmd.Timeout, cmd.Path) } } for _, env := range cmd.Env { if len(strings.SplitN(env, "=", 2)) != 2 { return fmt.Errorf("invalid env var %q for command %q", env, cmd.Path) } } // don't validate args, we allow to pass empty arguments if cmd.Hook != "" { if !slices.Contains(supportedHooks, cmd.Hook) { return fmt.Errorf("invalid hook name %q, supported values: %+v", cmd.Hook, supportedHooks) } } } config = c return nil } // GetConfig returns the configuration for the specified command func GetConfig(command, hook string) (time.Duration, []string, []string) { env := []string{} var args []string timeout := time.Duration(config.Timeout) * time.Second env = append(env, config.Env...) for _, cmd := range config.Commands { if cmd.Path == command { if cmd.Hook == "" || cmd.Hook == hook { timeout = time.Duration(cmd.Timeout) * time.Second env = append(env, cmd.Env...) 
args = cmd.Args break } } } return timeout, env, args } ================================================ FILE: internal/command/command_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package command import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCommandConfig(t *testing.T) { require.Equal(t, defaultTimeout, config.Timeout) cfg := Config{ Timeout: 10, Env: []string{"a=b"}, } err := cfg.Initialize() require.NoError(t, err) assert.Equal(t, cfg.Timeout, config.Timeout) assert.Equal(t, cfg.Env, config.Env) assert.Len(t, cfg.Commands, 0) timeout, env, args := GetConfig("cmd", "") assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout) assert.Contains(t, env, "a=b") assert.Len(t, args, 0) cfg.Commands = []Command{ { Path: "cmd1", Timeout: 30, Env: []string{"c=d"}, Args: []string{"1", "", "2"}, }, { Path: "cmd2", Timeout: 0, Env: []string{"e=f"}, }, } err = cfg.Initialize() require.NoError(t, err) assert.Equal(t, cfg.Timeout, config.Timeout) assert.Equal(t, cfg.Env, config.Env) if assert.Len(t, config.Commands, 2) { assert.Equal(t, cfg.Commands[0].Path, config.Commands[0].Path) assert.Equal(t, cfg.Commands[0].Timeout, config.Commands[0].Timeout) assert.Equal(t, cfg.Commands[0].Env, config.Commands[0].Env) assert.Equal(t, cfg.Commands[0].Args, config.Commands[0].Args) assert.Equal(t, 
cfg.Commands[1].Path, config.Commands[1].Path) assert.Equal(t, cfg.Timeout, config.Commands[1].Timeout) assert.Equal(t, cfg.Commands[1].Env, config.Commands[1].Env) assert.Equal(t, cfg.Commands[1].Args, config.Commands[1].Args) } timeout, env, args = GetConfig("cmd1", "") assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout) assert.Contains(t, env, "a=b") assert.Contains(t, env, "c=d") assert.NotContains(t, env, "e=f") if assert.Len(t, args, 3) { assert.Equal(t, "1", args[0]) assert.Empty(t, args[1]) assert.Equal(t, "2", args[2]) } timeout, env, args = GetConfig("cmd2", "") assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout) assert.Contains(t, env, "a=b") assert.NotContains(t, env, "c=d") assert.Contains(t, env, "e=f") assert.Len(t, args, 0) cfg.Commands = []Command{ { Path: "cmd1", Timeout: 30, Env: []string{"c=d"}, Args: []string{"1", "", "2"}, Hook: HookCheckPassword, }, { Path: "cmd1", Timeout: 0, Env: []string{"e=f"}, Hook: HookExternalAuth, }, } err = cfg.Initialize() require.NoError(t, err) timeout, env, args = GetConfig("cmd1", "") assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout) assert.Contains(t, env, "a=b") assert.NotContains(t, env, "c=d") assert.NotContains(t, env, "e=f") assert.Len(t, args, 0) timeout, env, args = GetConfig("cmd1", HookCheckPassword) assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout) assert.Contains(t, env, "a=b") assert.Contains(t, env, "c=d") assert.NotContains(t, env, "e=f") if assert.Len(t, args, 3) { assert.Equal(t, "1", args[0]) assert.Empty(t, args[1]) assert.Equal(t, "2", args[2]) } timeout, env, args = GetConfig("cmd1", HookExternalAuth) assert.Equal(t, time.Duration(cfg.Timeout)*time.Second, timeout) assert.Contains(t, env, "a=b") assert.NotContains(t, env, "c=d") assert.Contains(t, env, "e=f") assert.Len(t, args, 0) } func TestConfigErrors(t *testing.T) { c := Config{} err := c.Initialize() if assert.Error(t, err) { assert.Contains(t, 
err.Error(), "invalid timeout") } c.Timeout = 10 c.Env = []string{"a"} err = c.Initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid env var") } c.Env = nil c.Commands = []Command{ { Path: "", }, } err = c.Initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid path") } c.Commands = []Command{ { Path: "path", Timeout: 10000, }, } err = c.Initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid timeout") } c.Commands = []Command{ { Path: "path", Timeout: 30, Env: []string{"b"}, }, } err = c.Initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid env var") } c.Commands = []Command{ { Path: "path", Timeout: 30, Env: []string{"a=b"}, Hook: "invali", }, } err = c.Initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid hook name") } } ================================================ FILE: internal/common/actions.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "bytes" "context" "encoding/json" "errors" "fmt" "net/http" "net/url" "os/exec" "path" "path/filepath" "slices" "strings" "sync/atomic" "time" "github.com/sftpgo/sdk" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" ) var ( errUnexpectedHTTResponse = errors.New("unexpected HTTP hook response code") hooksConcurrencyGuard = make(chan struct{}, 150) activeHooks atomic.Int32 ) func startNewHook() { activeHooks.Add(1) hooksConcurrencyGuard <- struct{}{} } func hookEnded() { activeHooks.Add(-1) <-hooksConcurrencyGuard } // ProtocolActions defines the action to execute on file operations and SSH commands type ProtocolActions struct { // Valid values are download, upload, pre-delete, delete, rename, ssh_cmd. Empty slice to disable ExecuteOn []string `json:"execute_on" mapstructure:"execute_on"` // Actions to be performed synchronously. // The pre-delete action is always executed synchronously while the other ones are asynchronous. // Executing an action synchronously means that SFTPGo will not return a result code to the client // (which is waiting for it) until your hook have completed its execution. ExecuteSync []string `json:"execute_sync" mapstructure:"execute_sync"` // Absolute path to an external program or an HTTP URL Hook string `json:"hook" mapstructure:"hook"` } var actionHandler ActionHandler = &defaultActionHandler{} // InitializeActionHandler lets the user choose an action handler implementation. // // Do NOT call this function after application initialization. func InitializeActionHandler(handler ActionHandler) { actionHandler = handler } // ExecutePreAction executes a pre-* action and returns the result. 
// The returned status has the following meaning: // - 0 not executed // - 1 executed using an external hook // - 2 executed using the event manager func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath string, fileSize int64, openFlags int) (int, error) { var event *notifier.FsEvent hasNotifiersPlugin := plugin.Handler.HasNotifiers() hasHook := slices.Contains(Config.Actions.ExecuteOn, operation) hasRules := eventManager.hasFsRules() if !hasHook && !hasNotifiersPlugin && !hasRules { return 0, nil } dateTime := time.Now() event = newActionNotification(&conn.User, operation, filePath, virtualPath, "", "", "", conn.protocol, conn.GetRemoteIP(), conn.ID, fileSize, openFlags, conn.getNotificationStatus(nil), 0, dateTime, nil) if hasNotifiersPlugin { plugin.Handler.NotifyFsEvent(event) } if hasRules { params := EventParams{ Name: event.Username, Groups: conn.User.Groups, Event: event.Action, Status: event.Status, VirtualPath: event.VirtualPath, FsPath: event.Path, VirtualTargetPath: event.VirtualTargetPath, FsTargetPath: event.TargetPath, ObjectName: path.Base(event.VirtualPath), Extension: path.Ext(event.VirtualPath), FileSize: event.FileSize, Protocol: event.Protocol, IP: event.IP, Role: event.Role, Timestamp: dateTime, Email: conn.User.Email, Object: nil, } executedSync, err := eventManager.handleFsEvent(params) if executedSync { return 2, err } } if !hasHook { return 0, nil } return actionHandler.Handle(event) } // ExecuteActionNotification executes the defined hook, if any, for the specified action func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtualPath, target, virtualTarget, sshCmd string, fileSize int64, err error, elapsed int64, metadata map[string]string, ) error { hasNotifiersPlugin := plugin.Handler.HasNotifiers() hasHook := slices.Contains(Config.Actions.ExecuteOn, operation) hasRules := eventManager.hasFsRules() if !hasHook && !hasNotifiersPlugin && !hasRules { return nil } dateTime := time.Now() 
notification := newActionNotification(&conn.User, operation, filePath, virtualPath, target, virtualTarget, sshCmd, conn.protocol, conn.GetRemoteIP(), conn.ID, fileSize, 0, conn.getNotificationStatus(err), elapsed, dateTime, metadata) if hasNotifiersPlugin { plugin.Handler.NotifyFsEvent(notification) } if hasRules { params := EventParams{ Name: notification.Username, Groups: conn.User.Groups, Event: notification.Action, Status: notification.Status, VirtualPath: notification.VirtualPath, FsPath: notification.Path, VirtualTargetPath: notification.VirtualTargetPath, FsTargetPath: notification.TargetPath, ObjectName: path.Base(notification.VirtualPath), Extension: path.Ext(notification.VirtualPath), FileSize: notification.FileSize, Elapsed: notification.Elapsed, Protocol: notification.Protocol, IP: notification.IP, Role: notification.Role, Timestamp: dateTime, Email: conn.User.Email, Object: nil, Metadata: metadata, } if err != nil { params.AddError(fmt.Errorf("%q failed: %w", params.Event, err)) } executedSync, err := eventManager.handleFsEvent(params) if executedSync { return err } } if hasHook { if slices.Contains(Config.Actions.ExecuteSync, operation) { _, err := actionHandler.Handle(notification) return err } go func() { startNewHook() defer hookEnded() actionHandler.Handle(notification) //nolint:errcheck }() } return nil } // ActionHandler handles a notification for a Protocol Action. 
type ActionHandler interface { Handle(notification *notifier.FsEvent) (int, error) } func newActionNotification( user *dataprovider.User, operation, filePath, virtualPath, target, virtualTarget, sshCmd, protocol, ip, sessionID string, fileSize int64, openFlags, status int, elapsed int64, datetime time.Time, metadata map[string]string, ) *notifier.FsEvent { var bucket, endpoint string fsConfig := user.GetFsConfigForPath(virtualPath) switch fsConfig.Provider { case sdk.S3FilesystemProvider: bucket = fsConfig.S3Config.Bucket endpoint = fsConfig.S3Config.Endpoint case sdk.GCSFilesystemProvider: bucket = fsConfig.GCSConfig.Bucket case sdk.AzureBlobFilesystemProvider: bucket = fsConfig.AzBlobConfig.Container if fsConfig.AzBlobConfig.Endpoint != "" { endpoint = fsConfig.AzBlobConfig.Endpoint } case sdk.SFTPFilesystemProvider: endpoint = fsConfig.SFTPConfig.Endpoint case sdk.HTTPFilesystemProvider: endpoint = fsConfig.HTTPConfig.Endpoint } return ¬ifier.FsEvent{ Action: operation, Username: user.Username, Path: filePath, TargetPath: target, VirtualPath: virtualPath, VirtualTargetPath: virtualTarget, SSHCmd: sshCmd, FileSize: fileSize, FsProvider: int(fsConfig.Provider), Bucket: bucket, Endpoint: endpoint, Status: status, Protocol: protocol, IP: ip, SessionID: sessionID, OpenFlags: openFlags, Role: user.Role, Timestamp: datetime.UnixNano(), Elapsed: elapsed, Metadata: metadata, } } type defaultActionHandler struct{} func (h *defaultActionHandler) Handle(event *notifier.FsEvent) (int, error) { if !slices.Contains(Config.Actions.ExecuteOn, event.Action) { return 0, nil } if Config.Actions.Hook == "" { logger.Warn(event.Protocol, "", "Unable to send notification, no hook is defined") return 0, nil } if strings.HasPrefix(Config.Actions.Hook, "http") { err := h.handleHTTP(event) return 1, err } err := h.handleCommand(event) return 1, err } func (h *defaultActionHandler) handleHTTP(event *notifier.FsEvent) error { u, err := url.Parse(Config.Actions.Hook) if err != nil { 
logger.Error(event.Protocol, "", "Invalid hook %q for operation %q: %v", Config.Actions.Hook, event.Action, err) return err } startTime := time.Now() respCode := 0 var b bytes.Buffer _ = json.NewEncoder(&b).Encode(event) resp, err := httpclient.RetryablePost(Config.Actions.Hook, "application/json", &b) if err == nil { respCode = resp.StatusCode resp.Body.Close() if respCode != http.StatusOK { err = errUnexpectedHTTResponse } } logger.Debug(event.Protocol, "", "notified operation %q to URL: %s status code: %d, elapsed: %s err: %v", event.Action, u.Redacted(), respCode, time.Since(startTime), err) return err } func (h *defaultActionHandler) handleCommand(event *notifier.FsEvent) error { if !filepath.IsAbs(Config.Actions.Hook) { err := fmt.Errorf("invalid notification command %q", Config.Actions.Hook) logger.Warn(event.Protocol, "", "unable to execute notification command: %v", err) return err } timeout, env, args := command.GetConfig(Config.Actions.Hook, command.HookFsActions) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() cmd := exec.CommandContext(ctx, Config.Actions.Hook, args...) cmd.Env = append(env, notificationAsEnvVars(event)...) 
startTime := time.Now() err := cmd.Run() logger.Debug(event.Protocol, "", "executed command %q, elapsed: %s, error: %v", Config.Actions.Hook, time.Since(startTime), err) return err } func notificationAsEnvVars(event *notifier.FsEvent) []string { result := []string{ fmt.Sprintf("SFTPGO_ACTION=%s", event.Action), fmt.Sprintf("SFTPGO_ACTION_USERNAME=%s", event.Username), fmt.Sprintf("SFTPGO_ACTION_PATH=%s", event.Path), fmt.Sprintf("SFTPGO_ACTION_TARGET=%s", event.TargetPath), fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_PATH=%s", event.VirtualPath), fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_TARGET=%s", event.VirtualTargetPath), fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%s", event.SSHCmd), fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%d", event.FileSize), fmt.Sprintf("SFTPGO_ACTION_ELAPSED=%d", event.Elapsed), fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%d", event.FsProvider), fmt.Sprintf("SFTPGO_ACTION_BUCKET=%s", event.Bucket), fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%s", event.Endpoint), fmt.Sprintf("SFTPGO_ACTION_STATUS=%d", event.Status), fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%s", event.Protocol), fmt.Sprintf("SFTPGO_ACTION_IP=%s", event.IP), fmt.Sprintf("SFTPGO_ACTION_SESSION_ID=%s", event.SessionID), fmt.Sprintf("SFTPGO_ACTION_OPEN_FLAGS=%d", event.OpenFlags), fmt.Sprintf("SFTPGO_ACTION_TIMESTAMP=%d", event.Timestamp), fmt.Sprintf("SFTPGO_ACTION_ROLE=%s", event.Role), } if len(event.Metadata) > 0 { data, err := json.Marshal(event.Metadata) if err == nil { result = append(result, fmt.Sprintf("SFTPGO_ACTION_METADATA=%s", data)) } } return result } ================================================ FILE: internal/common/actions_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"testing"
	"time"

	"github.com/lithammer/shortuuid/v4"
	"github.com/rs/xid"
	"github.com/sftpgo/sdk"
	"github.com/sftpgo/sdk/plugin/notifier"
	"github.com/stretchr/testify/assert"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// TestNewActionNotification checks that Bucket and Endpoint are filled according
// to the user's filesystem provider and that Status reflects the error passed to
// getNotificationStatus (generic error, nil, quota error, wrapped quota error).
func TestNewActionNotification(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "username",
		},
	}
	user.FsConfig.Provider = sdk.LocalFilesystemProvider
	user.FsConfig.S3Config = vfs.S3FsConfig{
		BaseS3FsConfig: sdk.BaseS3FsConfig{
			Bucket:   "s3bucket",
			Endpoint: "endpoint",
		},
	}
	user.FsConfig.GCSConfig = vfs.GCSFsConfig{
		BaseGCSFsConfig: sdk.BaseGCSFsConfig{
			Bucket: "gcsbucket",
		},
	}
	user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{
		BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
			Container: "azcontainer",
			Endpoint:  "azendpoint",
		},
	}
	user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: "sftpendpoint",
		},
	}
	user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint: "httpendpoint",
		},
	}
	c := NewBaseConnection("id", ProtocolSSH, "", "", user)
	sessionID := xid.New().String()
	// local filesystem: no bucket/endpoint
	a := newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "",
		sessionID, 123, 0, c.getNotificationStatus(errors.New("fake error")), 0, time.Now(), nil)
	assert.Equal(t, user.Username, a.Username)
	assert.Equal(t, 0, len(a.Bucket))
	assert.Equal(t, 0, len(a.Endpoint))
	assert.Equal(t, 2, a.Status)

	user.FsConfig.Provider = sdk.S3FilesystemProvider
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "",
		sessionID, 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil)
	assert.Equal(t, "s3bucket", a.Bucket)
	assert.Equal(t, "endpoint", a.Endpoint)
	assert.Equal(t, 1, a.Status)

	user.FsConfig.Provider = sdk.GCSFilesystemProvider
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "",
		sessionID, 123, 0, c.getNotificationStatus(ErrQuotaExceeded), 0, time.Now(), nil)
	assert.Equal(t, "gcsbucket", a.Bucket)
	assert.Equal(t, 0, len(a.Endpoint))
	assert.Equal(t, 3, a.Status)
	// a wrapped quota error must produce the same status
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "",
		sessionID, 123, 0, c.getNotificationStatus(fmt.Errorf("wrapper quota error: %w", ErrQuotaExceeded)), 0, time.Now(), nil)
	assert.Equal(t, "gcsbucket", a.Bucket)
	assert.Equal(t, 0, len(a.Endpoint))
	assert.Equal(t, 3, a.Status)

	user.FsConfig.Provider = sdk.HTTPFilesystemProvider
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "",
		sessionID, 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil)
	assert.Equal(t, "httpendpoint", a.Endpoint)
	assert.Equal(t, 1, a.Status)

	user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "",
		sessionID, 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil)
	assert.Equal(t, "azcontainer", a.Bucket)
	assert.Equal(t, "azendpoint", a.Endpoint)
	assert.Equal(t, 1, a.Status)

	// open flags are copied as-is
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "",
		sessionID, 123, os.O_APPEND, c.getNotificationStatus(nil), 0, time.Now(), nil)
	assert.Equal(t, "azcontainer", a.Bucket)
	assert.Equal(t, "azendpoint", a.Endpoint)
	assert.Equal(t, 1, a.Status)
	assert.Equal(t, os.O_APPEND, a.OpenFlags)

	user.FsConfig.Provider = sdk.SFTPFilesystemProvider
	a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "",
		sessionID, 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil)
	assert.Equal(t, "sftpendpoint", a.Endpoint)
}

// TestActionHTTP exercises the HTTP hook: success, unreachable host and a
// non-200 response (httpAddr is the test HTTP server address, defined elsewhere
// in the package tests).
func TestActionHTTP(t *testing.T) {
	actionsCopy := Config.Actions

	Config.Actions = ProtocolActions{
		ExecuteOn: []string{operationDownload},
		Hook:      fmt.Sprintf("http://%v", httpAddr),
	}
	user := &dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "username",
		},
	}
	a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "",
		xid.New().String(), 123, 0, 1, 0, time.Now(), nil)
	status, err := actionHandler.Handle(a)
	assert.NoError(t, err)
	assert.Equal(t, 1, status)

	Config.Actions.Hook = "http://invalid:1234"
	status, err = actionHandler.Handle(a)
	assert.Error(t, err)
	assert.Equal(t, 1, status)

	Config.Actions.Hook = fmt.Sprintf("http://%v/404", httpAddr)
	status, err = actionHandler.Handle(a)
	if assert.Error(t, err) {
		assert.EqualError(t, err, errUnexpectedHTTResponse.Error())
	}
	assert.Equal(t, 1, status)

	// restore the global configuration for the other tests
	Config.Actions = actionsCopy
}

// TestActionCMD exercises the command hook using the system "true" binary.
func TestActionCMD(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	actionsCopy := Config.Actions

	hookCmd, err := exec.LookPath("true")
	assert.NoError(t, err)

	Config.Actions = ProtocolActions{
		ExecuteOn: []string{operationDownload},
		Hook:      hookCmd,
	}
	user := &dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "username",
		},
	}
	sessionID := shortuuid.New()
	a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "",
		sessionID, 123, 0, 1, 0, time.Now(), map[string]string{"key": "value"})
	status, err := actionHandler.Handle(a)
	assert.NoError(t, err)
	assert.Equal(t, 1, status)

	c := NewBaseConnection("id", ProtocolSFTP, "", "", *user)
	err = ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil, 0, nil)
	assert.NoError(t, err)

	err = ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil, 0, nil)
	assert.NoError(t, err)

	Config.Actions = actionsCopy
}

// TestWrongActions covers the failure paths of the default action handler:
// missing program, action not enabled, invalid URL, empty hook and relative
// command path.
func TestWrongActions(t *testing.T) {
	actionsCopy := Config.Actions

	badCommand := "/bad/command"
	if runtime.GOOS == osWindows {
		badCommand = "C:\\bad\\command"
	}

	Config.Actions = ProtocolActions{
		ExecuteOn: []string{operationUpload},
		Hook:      badCommand,
	}
	user := &dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "username",
		},
	}

	a := newActionNotification(user, operationUpload, "", "", "", "", "", ProtocolSFTP, "", xid.New().String(),
		123, 0, 1, 0, time.Now(), nil)
	status, err := actionHandler.Handle(a)
	assert.Error(t, err, "action with bad command must fail")
	assert.Equal(t, 1, status)

	// delete is not in ExecuteOn, so nothing is executed
	a.Action = operationDelete
	status, err = actionHandler.Handle(a)
	assert.NoError(t, err)
	assert.Equal(t, 0, status)

	Config.Actions.Hook = "http://foo\x7f.com/"
	a.Action = operationUpload
	status, err = actionHandler.Handle(a)
	assert.Error(t, err, "action with bad url must fail")
	assert.Equal(t, 1, status)

	Config.Actions.Hook = ""
	status, err = actionHandler.Handle(a)
	assert.NoError(t, err)
	assert.Equal(t, 0, status)

	Config.Actions.Hook = "relative path"
	status, err = actionHandler.Handle(a)
	if assert.Error(t, err) {
		assert.EqualError(t, err, fmt.Sprintf("invalid notification command %q", Config.Actions.Hook))
	}
	assert.Equal(t, 1, status)

	Config.Actions = actionsCopy
}

// TestPreDeleteAction checks that a failing pre-delete hook denies the removal
// and a succeeding one allows it.
func TestPreDeleteAction(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	actionsCopy := Config.Actions

	hookCmd, err := exec.LookPath("true")
	assert.NoError(t, err)
	Config.Actions = ProtocolActions{
		ExecuteOn: []string{operationPreDelete},
		Hook:      "missing hook",
	}
	homeDir := filepath.Join(os.TempDir(), "test_user")
	err = os.MkdirAll(homeDir, os.ModePerm)
	assert.NoError(t, err)
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "username",
			HomeDir:  homeDir,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	fs := vfs.NewOsFs("id", homeDir, "", nil)
	c := NewBaseConnection("id", ProtocolSFTP, "", "", user)

	testfile := filepath.Join(user.HomeDir, "testfile")
	err = os.WriteFile(testfile, []byte("test"), os.ModePerm)
	assert.NoError(t, err)
	info, err := os.Stat(testfile)
	assert.NoError(t, err)
	// "missing hook" is not an absolute path, so the hook fails and the removal is denied
	err = c.RemoveFile(fs, testfile, "testfile", info)
	assert.ErrorIs(t, err, c.GetPermissionDeniedError())
	assert.FileExists(t, testfile)
	Config.Actions.Hook = hookCmd
	err = c.RemoveFile(fs, testfile, "testfile", info)
	assert.NoError(t, err)
	assert.NoFileExists(t, testfile)

	os.RemoveAll(homeDir)

	Config.Actions = actionsCopy
}

// TestUnconfiguredHook checks that enabled actions without a hook are not
// executed, even when notifier plugins are configured.
func TestUnconfiguredHook(t *testing.T) {
	actionsCopy := Config.Actions

	Config.Actions = ProtocolActions{
		ExecuteOn: []string{operationDownload},
		Hook:      "",
	}
	pluginsConfig := []plugin.Config{
		{
			Type: "notifier",
		},
	}
	err := plugin.Initialize(pluginsConfig, "debug")
	assert.Error(t, err)
	assert.True(t, plugin.Handler.HasNotifiers())

	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
	status, err := ExecutePreAction(c, OperationPreDownload, "", "", 0, 0)
	assert.NoError(t, err)
	assert.Equal(t, status, 0)
	status, err = ExecutePreAction(c, operationPreDelete, "", "", 0, 0)
	assert.NoError(t, err)
	assert.Equal(t, status, 0)

	err = ExecuteActionNotification(c, operationDownload, "", "", "", "", "", 0, nil, 0, nil)
	assert.NoError(t, err)

	err = plugin.Initialize(nil, "debug")
	assert.NoError(t, err)
	assert.False(t, plugin.Handler.HasNotifiers())

	Config.Actions = actionsCopy
}

// actionHandlerStub is an ActionHandler that only records that it was called.
type actionHandlerStub struct {
	called bool
}

func (h *actionHandlerStub) Handle(_ *notifier.FsEvent) (int, error) {
	h.called = true

	return 1, nil
}

// TestInitializeActionHandler checks that a custom handler replaces the
// default one; the default handler is restored on cleanup.
func TestInitializeActionHandler(t *testing.T) {
	handler := &actionHandlerStub{}

	InitializeActionHandler(handler)
	t.Cleanup(func() {
		InitializeActionHandler(&defaultActionHandler{})
	})

	status, err := actionHandler.Handle(&notifier.FsEvent{})

	assert.NoError(t, err)
	assert.True(t, handler.called)
	assert.Equal(t, 1, status)
}
================================================
FILE: internal/common/clientsmap.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"sync"
	"sync/atomic"

	"github.com/drakkan/sftpgo/v2/internal/logger"
)

// clientsMap is a struct containing the map of the connected clients.
// clients maps a source to its counter, totalConnections is the sum over all
// the sources.
type clientsMap struct {
	totalConnections atomic.Int32
	mu               sync.RWMutex
	clients          map[string]int
}

// add increments the total counter and the counter for source.
// The total is updated before taking the mutex, so getTotal may briefly
// disagree with the per-source counters.
func (c *clientsMap) add(source string) {
	c.totalConnections.Add(1)

	c.mu.Lock()
	defer c.mu.Unlock()

	c.clients[source]++
}

// remove decrements the counters for source and deletes its map entry when
// the counter drops to zero. Unknown sources only produce a warning and leave
// the total unchanged.
func (c *clientsMap) remove(source string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if val, ok := c.clients[source]; ok {
		c.totalConnections.Add(-1)
		c.clients[source]--
		// val is the counter before the decrement: keep the entry while it is still positive
		if val > 1 {
			return
		}
		delete(c.clients, source)
	} else {
		logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source)
	}
}

// getTotal returns the total counter over all the sources.
func (c *clientsMap) getTotal() int32 {
	return c.totalConnections.Load()
}

// getTotalFrom returns the counter for source, 0 if it is not mapped.
func (c *clientsMap) getTotalFrom(source string) int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return c.clients[source]
}
================================================
FILE: internal/common/clientsmap_test.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestClientsMap checks the per-source and total counters across a sequence
// of add/remove calls, including removing an unmapped source.
func TestClientsMap(t *testing.T) {
	m := clientsMap{
		clients: make(map[string]int),
	}
	ip1 := "192.168.1.1"
	ip2 := "192.168.1.2"
	m.add(ip1)
	assert.Equal(t, int32(1), m.getTotal())
	assert.Equal(t, 1, m.getTotalFrom(ip1))
	assert.Equal(t, 0, m.getTotalFrom(ip2))

	m.add(ip1)
	m.add(ip2)
	assert.Equal(t, int32(3), m.getTotal())
	assert.Equal(t, 2, m.getTotalFrom(ip1))
	assert.Equal(t, 1, m.getTotalFrom(ip2))

	m.add(ip1)
	m.add(ip1)
	m.add(ip2)
	assert.Equal(t, int32(6), m.getTotal())
	assert.Equal(t, 4, m.getTotalFrom(ip1))
	assert.Equal(t, 2, m.getTotalFrom(ip2))

	m.remove(ip2)
	assert.Equal(t, int32(5), m.getTotal())
	assert.Equal(t, 4, m.getTotalFrom(ip1))
	assert.Equal(t, 1, m.getTotalFrom(ip2))

	// removing an unmapped source must not change any counter
	m.remove("unknown")
	assert.Equal(t, int32(5), m.getTotal())
	assert.Equal(t, 4, m.getTotalFrom(ip1))
	assert.Equal(t, 1, m.getTotalFrom(ip2))

	m.remove(ip2)
	assert.Equal(t, int32(4), m.getTotal())
	assert.Equal(t, 4, m.getTotalFrom(ip1))
	assert.Equal(t, 0, m.getTotalFrom(ip2))

	m.remove(ip1)
	m.remove(ip1)
	m.remove(ip1)
	assert.Equal(t, int32(1), m.getTotal())
	assert.Equal(t, 1, m.getTotalFrom(ip1))
	assert.Equal(t, 0, m.getTotalFrom(ip2))

	m.remove(ip1)
	assert.Equal(t, int32(0), m.getTotal())
	assert.Equal(t, 0, m.getTotalFrom(ip1))
	assert.Equal(t, 0, m.getTotalFrom(ip2))
}
================================================
FILE: internal/common/common.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under
the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package common defines code shared among file transfer packages and protocols package common import ( "context" "errors" "fmt" "io" "net" "net/http" "net/url" "os" "os/exec" "path/filepath" "slices" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/pires/go-proxyproto" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // constants const ( logSender = "common" uploadLogSender = "Upload" downloadLogSender = "Download" renameLogSender = "Rename" rmdirLogSender = "Rmdir" mkdirLogSender = "Mkdir" symlinkLogSender = "Symlink" removeLogSender = "Remove" chownLogSender = "Chown" chmodLogSender = "Chmod" chtimesLogSender = "Chtimes" copyLogSender = "Copy" truncateLogSender = "Truncate" operationDownload = "download" operationUpload = "upload" operationFirstDownload = "first-download" operationFirstUpload = "first-upload" operationDelete = "delete" operationCopy = "copy" // Pre-download action name OperationPreDownload = "pre-download" // Pre-upload action name OperationPreUpload = "pre-upload" operationPreDelete = 
"pre-delete" operationRename = "rename" operationMkdir = "mkdir" operationRmdir = "rmdir" // SSH command action name OperationSSHCmd = "ssh_cmd" chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS idleTimeoutCheckInterval = 3 * time.Minute periodicTimeoutCheckInterval = 1 * time.Minute ) // Stat flags const ( StatAttrUIDGID = 1 StatAttrPerms = 2 StatAttrTimes = 4 StatAttrSize = 8 ) // Transfer types const ( TransferUpload = iota TransferDownload ) // Supported protocols const ( ProtocolSFTP = "SFTP" ProtocolSCP = "SCP" ProtocolSSH = "SSH" ProtocolFTP = "FTP" ProtocolWebDAV = "DAV" ProtocolHTTP = "HTTP" ProtocolHTTPShare = "HTTPShare" ProtocolDataRetention = "DataRetention" ProtocolOIDC = "OIDC" protocolEventAction = "EventAction" ) // Upload modes const ( UploadModeStandard = 0 UploadModeAtomic = 1 UploadModeAtomicWithResume = 2 UploadModeS3StoreOnError = 4 UploadModeGCSStoreOnError = 8 UploadModeAzureBlobStoreOnError = 16 ) func init() { Connections.clients = clientsMap{ clients: make(map[string]int), } Connections.transfers = clientsMap{ clients: make(map[string]int), } Connections.perUserConns = make(map[string]int) Connections.mapping = make(map[string]int) Connections.sshMapping = make(map[string]int) } // errors definitions var ( ErrPermissionDenied = errors.New("permission denied") ErrNotExist = errors.New("no such file or directory") ErrOpUnsupported = errors.New("operation unsupported") ErrGenericFailure = errors.New("failure") ErrQuotaExceeded = errors.New("denying write due to space limit") ErrReadQuotaExceeded = errors.New("denying read due to quota limit") ErrConnectionDenied = errors.New("you are not allowed to connect") ErrNoBinding = errors.New("no binding configured") ErrCrtRevoked = errors.New("your certificate has been revoked") ErrNoCredentials = errors.New("no credential provided") ErrInternalFailure = errors.New("internal failure") ErrTransferAborted = errors.New("transfer aborted") ErrShuttingDown = errors.New("the service is 
shutting down") errNoTransfer = errors.New("requested transfer not found") errTransferMismatch = errors.New("transfer mismatch") ) var ( // Config is the configuration for the supported protocols Config Configuration // Connections is the list of active connections Connections ActiveConnections // QuotaScans is the list of active quota scans QuotaScans ActiveScans transfersChecker TransfersChecker supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC} disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} // the map key is the protocol, for each protocol we can have multiple rate limiters rateLimiters map[string][]*rateLimiter isShuttingDown atomic.Bool ftpLoginCommands = []string{"PASS", "USER"} fnUpdateBranding func(*dataprovider.BrandingConfigs) ) // SetUpdateBrandingFn sets the function to call to update branding configs. func SetUpdateBrandingFn(fn func(*dataprovider.BrandingConfigs)) { fnUpdateBranding = fn } // Initialize sets the common configuration func Initialize(c Configuration, isShared int) error { isShuttingDown.Store(false) util.SetUmask(c.Umask) version.SetConfig(c.ServerVersion) dataprovider.SetTZ(c.TZ) Config = c Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true) Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true) Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true) Config.idleLoginTimeout = 2 * time.Minute Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute startPeriodicChecks(periodicTimeoutCheckInterval, isShared) Config.defender = nil Config.allowList = nil Config.rateLimitersList = nil rateLimiters = make(map[string][]*rateLimiter) for _, rlCfg := range c.RateLimitersConfig { if rlCfg.isEnabled() { if err := rlCfg.validate(); err != nil { return fmt.Errorf("rate limiters initialization error: %w", err) } rateLimiter 
:= rlCfg.getLimiter() for _, protocol := range rlCfg.Protocols { rateLimiters[protocol] = append(rateLimiters[protocol], rateLimiter) } } } if len(rateLimiters) > 0 { rateLimitersList, err := dataprovider.NewIPList(dataprovider.IPListTypeRateLimiterSafeList) if err != nil { return fmt.Errorf("unable to initialize ratelimiters list: %w", err) } Config.rateLimitersList = rateLimitersList } if c.DefenderConfig.Enabled { if !slices.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) { return fmt.Errorf("unsupported defender driver %q", c.DefenderConfig.Driver) } var defender Defender var err error switch c.DefenderConfig.Driver { case DefenderDriverProvider: defender, err = newDBDefender(&c.DefenderConfig) default: defender, err = newInMemoryDefender(&c.DefenderConfig) } if err != nil { return fmt.Errorf("defender initialization error: %v", err) } logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig) Config.defender = defender } if c.AllowListStatus > 0 { allowList, err := dataprovider.NewIPList(dataprovider.IPListTypeAllowList) if err != nil { return fmt.Errorf("unable to initialize the allow list: %w", err) } logger.Info(logSender, "", "allow list initialized") Config.allowList = allowList } if err := c.initializeProxyProtocol(); err != nil { return err } if err := c.EventManager.validate(); err != nil { return err } vfs.SetTempPath(c.TempPath) dataprovider.SetTempPath(c.TempPath) vfs.SetAllowSelfConnections(c.AllowSelfConnections) vfs.SetRenameMode(c.RenameMode) vfs.SetReadMetadataMode(c.Metadata.Read) vfs.SetResumeMaxSize(c.ResumeMaxSize) vfs.SetUploadMode(c.UploadMode) dataprovider.SetAllowSelfConnections(c.AllowSelfConnections) dataprovider.EnabledActionCommands = c.EventManager.EnabledCommands transfersChecker = getTransfersChecker(isShared) return nil } // CheckClosing returns an error if the service is closing func CheckClosing() error { if isShuttingDown.Load() { return ErrShuttingDown } return nil } // WaitForTransfers 
waits, for the specified grace time, for currently ongoing
// client-initiated transfer sessions to complete.
// A zero graceTime means no wait. The function also returns immediately when a
// shutdown is already in progress, or when there are no running hooks and no
// connections with active transfers. Otherwise it re-checks every 3 seconds
// until both counters reach zero or graceTime seconds have elapsed.
func WaitForTransfers(graceTime int) {
	if graceTime == 0 {
		return
	}
	if alreadyClosing := isShuttingDown.Swap(true); alreadyClosing {
		return
	}
	if activeHooks.Load() == 0 && getActiveConnections() == 0 {
		return
	}

	deadline := time.NewTimer(time.Duration(graceTime) * time.Second)
	defer deadline.Stop()
	poll := time.NewTicker(3 * time.Second)
	defer poll.Stop()

	for {
		select {
		case <-deadline.C:
			logger.Info(logSender, "", "grace time expired, hard shutdown")
			return
		case <-poll.C:
			pendingHooks := activeHooks.Load()
			logger.Info(logSender, "", "active hooks: %d", pendingHooks)
			if pendingHooks == 0 && getActiveConnections() == 0 {
				logger.Info(logSender, "", "no more active connections, graceful shutdown")
				return
			}
		}
	}
}

// getActiveConnections counts, under the registry read lock, the connections
// that currently report at least one transfer, logs the count and returns it.
func getActiveConnections() int {
	count := 0
	Connections.RLock()
	for _, conn := range Connections.connections {
		if transfers := conn.GetTransfers(); len(transfers) > 0 {
			count++
		}
	}
	Connections.RUnlock()
	logger.Info(logSender, "", "number of connections with active transfers: %d", count)
	return count
}

// LimitRate blocks until all the configured rate limiters
// allow one event to happen.
// It returns an error if the time to wait exceeds the max // allowed delay func LimitRate(protocol, ip string) (time.Duration, error) { if Config.rateLimitersList != nil { isListed, _, err := Config.rateLimitersList.IsListed(ip, protocol) if err == nil && isListed { return 0, nil } } for _, limiter := range rateLimiters[protocol] { if delay, err := limiter.Wait(ip, protocol); err != nil { logger.Debug(logSender, "", "protocol %s ip %s: %v", protocol, ip, err) return delay, err } } return 0, nil } // Reload reloads the whitelist, the IP filter plugin and the defender's block and safe lists func Reload() error { plugin.Handler.ReloadFilter() return nil } // DelayLogin applies the configured login delay func DelayLogin(err error) { if Config.defender != nil { Config.defender.DelayLogin(err) } } // IsBanned returns true if the specified IP address is banned func IsBanned(ip, protocol string) bool { if plugin.Handler.IsIPBanned(ip, protocol) { return true } if Config.defender == nil { return false } return Config.defender.IsBanned(ip, protocol) } // GetDefenderBanTime returns the ban time for the given IP // or nil if the IP is not banned or the defender is disabled func GetDefenderBanTime(ip string) (*time.Time, error) { if Config.defender == nil { return nil, nil } return Config.defender.GetBanTime(ip) } // GetDefenderHosts returns hosts that are banned or for which some violations have been detected func GetDefenderHosts() ([]dataprovider.DefenderEntry, error) { if Config.defender == nil { return nil, nil } return Config.defender.GetHosts() } // GetDefenderHost returns a defender host by ip, if any func GetDefenderHost(ip string) (dataprovider.DefenderEntry, error) { if Config.defender == nil { return dataprovider.DefenderEntry{}, errors.New("defender is disabled") } return Config.defender.GetHost(ip) } // DeleteDefenderHost removes the specified IP address from the defender lists func DeleteDefenderHost(ip string) bool { if Config.defender == nil { return false } 
return Config.defender.DeleteHost(ip) } // GetDefenderScore returns the score for the given IP func GetDefenderScore(ip string) (int, error) { if Config.defender == nil { return 0, nil } return Config.defender.GetScore(ip) } // AddDefenderEvent adds the specified defender event for the given IP. // Returns true if the IP is in the defender's safe list. func AddDefenderEvent(ip, protocol string, event HostEvent) bool { if Config.defender == nil { return false } return Config.defender.AddEvent(ip, protocol, event) } func reloadProviderConfigs() { configs, err := dataprovider.GetConfigs() if err != nil { logger.Error(logSender, "", "unable to load config from provider: %v", err) return } configs.SetNilsToEmpty() if fnUpdateBranding != nil { fnUpdateBranding(configs.Branding) } if err := configs.SMTP.TryDecrypt(); err != nil { logger.Error(logSender, "", "unable to decrypt smtp config: %v", err) return } smtp.Activate(configs.SMTP) } func startPeriodicChecks(duration time.Duration, isShared int) { startEventScheduler() spec := fmt.Sprintf("@every %s", duration) _, err := eventScheduler.AddFunc(spec, Connections.checkTransfers) util.PanicOnError(err) logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec) if isShared == 1 { logger.Info(logSender, "", "add reload configs task") _, err := eventScheduler.AddFunc("@every 10m", reloadProviderConfigs) util.PanicOnError(err) } if Config.IdleTimeout > 0 { ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval spec = fmt.Sprintf("@every %s", duration*ratio) _, err = eventScheduler.AddFunc(spec, Connections.checkIdles) util.PanicOnError(err) logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec) } } // ActiveTransfer defines the interface for the current active transfers type ActiveTransfer interface { GetID() int64 GetType() int GetSize() int64 GetDownloadedSize() int64 GetUploadedSize() int64 GetVirtualPath() string GetFsPath() string GetStartTime() time.Time 
SignalClose(err error) Truncate(fsPath string, size int64) (int64, error) GetRealFsPath(fsPath string) string SetTimes(fsPath string, atime time.Time, mtime time.Time) bool GetTruncatedSize() int64 HasSizeLimit() bool } // ActiveConnection defines the interface for the current active connections type ActiveConnection interface { GetID() string GetUsername() string GetRole() string GetMaxSessions() int GetLocalAddress() string GetRemoteAddress() string GetClientVersion() string GetProtocol() string GetConnectionTime() time.Time GetLastActivity() time.Time GetCommand() string Disconnect() error AddTransfer(t ActiveTransfer) RemoveTransfer(t ActiveTransfer) GetTransfers() []ConnectionTransfer SignalTransferClose(transferID int64, err error) CloseFS() error isAccessAllowed() bool } // StatAttributes defines the attributes for set stat commands type StatAttributes struct { Mode os.FileMode Atime time.Time Mtime time.Time UID int GID int Flags int Size int64 } // ConnectionTransfer defines the trasfer details type ConnectionTransfer struct { ID int64 `json:"-"` OperationType string `json:"operation_type"` StartTime int64 `json:"start_time"` Size int64 `json:"size"` VirtualPath string `json:"path"` HasSizeLimit bool `json:"-"` ULSize int64 `json:"-"` DLSize int64 `json:"-"` } // EventManagerConfig defines the configuration for the EventManager type EventManagerConfig struct { // EnabledCommands defines the system commands that can be executed via EventManager, // an empty list means that any command is allowed to be executed. 
// Commands must be set as an absolute path EnabledCommands []string `json:"enabled_commands" mapstructure:"enabled_commands"` } func (c *EventManagerConfig) validate() error { for _, c := range c.EnabledCommands { if !filepath.IsAbs(c) { return fmt.Errorf("invalid command %q: it must be an absolute path", c) } } return nil } // MetadataConfig defines how to handle metadata for cloud storage backends type MetadataConfig struct { // If not zero the metadata will be read before downloads and will be // available in notifications Read int `json:"read" mapstructure:"read"` } // Configuration defines configuration parameters common to all supported protocols type Configuration struct { // Maximum idle timeout as minutes. If a client is idle for a time that exceeds this setting it will be disconnected. // 0 means disabled IdleTimeout int `json:"idle_timeout" mapstructure:"idle_timeout"` // UploadMode 0 means standard, the files are uploaded directly to the requested path. // 1 means atomic: the files are uploaded to a temporary path and renamed to the requested path // when the client ends the upload. Atomic mode avoid problems such as a web server that // serves partial files when the files are being uploaded. // In atomic mode if there is an upload error the temporary file is deleted and so the requested // upload path will not contain a partial file. // 2 means atomic with resume support: as atomic but if there is an upload error the temporary // file is renamed to the requested path and not deleted, this way a client can reconnect and resume // the upload. // 4 means files for S3 backend are stored even if a client-side upload error is detected. // 8 means files for Google Cloud Storage backend are stored even if a client-side upload error is detected. // 16 means files for Azure Blob backend are stored even if a client-side upload error is detected. 
UploadMode int `json:"upload_mode" mapstructure:"upload_mode"` // Actions to execute for SFTP file operations and SSH commands Actions ProtocolActions `json:"actions" mapstructure:"actions"` // SetstatMode 0 means "normal mode": requests for changing permissions and owner/group are executed. // 1 means "ignore mode": requests for changing permissions and owner/group are silently ignored. // 2 means "ignore mode for cloud fs": requests for changing permissions and owner/group are // silently ignored for cloud based filesystem such as S3, GCS, Azure Blob. Requests for changing // modification times are ignored for cloud based filesystem if they are not supported. SetstatMode int `json:"setstat_mode" mapstructure:"setstat_mode"` // RenameMode defines how to handle directory renames. By default, renaming of non-empty directories // is not allowed for cloud storage providers (S3, GCS, Azure Blob). Set to 1 to enable recursive // renames for these providers, they may be slow, there is no atomic rename API like for local // filesystem, so SFTPGo will recursively list the directory contents and do a rename for each entry RenameMode int `json:"rename_mode" mapstructure:"rename_mode"` // ResumeMaxSize defines the maximum size allowed, in bytes, to resume uploads on storage backends // with immutable objects. By default, resuming uploads is not allowed for cloud storage providers // (S3, GCS, Azure Blob) because SFTPGo must rewrite the entire file. // Set to a value greater than 0 to allow resuming uploads of files smaller than or equal to the // defined size. ResumeMaxSize int64 `json:"resume_max_size" mapstructure:"resume_max_size"` // TempPath defines the path for temporary files such as those used for atomic uploads or file pipes. 
// If you set this option you must make sure that the defined path exists, is accessible for writing // by the user running SFTPGo, and is on the same filesystem as the users home directories otherwise // the renaming for atomic uploads will become a copy and therefore may take a long time. // The temporary files are not namespaced. The default is generally fine. Leave empty for the default. TempPath string `json:"temp_path" mapstructure:"temp_path"` // Support for HAProxy PROXY protocol. // If you are running SFTPGo behind a proxy server such as HAProxy, AWS ELB or NGNIX, you can enable // the proxy protocol. It provides a convenient way to safely transport connection information // such as a client's address across multiple layers of NAT or TCP proxies to get the real // client IP address instead of the proxy IP. Both protocol versions 1 and 2 are supported. // - 0 means disabled // - 1 means proxy protocol enabled. Proxy header will be used and requests without proxy header will be accepted. // - 2 means proxy protocol required. Proxy header will be used and requests without proxy header will be rejected. // If the proxy protocol is enabled in SFTPGo then you have to enable the protocol in your proxy configuration too, // for example for HAProxy add "send-proxy" or "send-proxy-v2" to each server configuration line. ProxyProtocol int `json:"proxy_protocol" mapstructure:"proxy_protocol"` // List of IP addresses and IP ranges allowed to send the proxy header. // If proxy protocol is set to 1 and we receive a proxy header from an IP that is not in the list then the // connection will be accepted and the header will be ignored. // If proxy protocol is set to 2 and we receive a proxy header from an IP that is not in the list then the // connection will be rejected. 
ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` // List of IP addresses and IP ranges for which not to read the proxy header ProxySkipped []string `json:"proxy_skipped" mapstructure:"proxy_skipped"` // Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. // If you define an HTTP URL it will be invoked using a `GET` request. // Please note that SFTPGo services may not yet be available when this hook is run. // Leave empty do disable. StartupHook string `json:"startup_hook" mapstructure:"startup_hook"` // Absolute path to an external program or an HTTP URL to invoke after a user connects // and before he tries to login. It allows you to reject the connection based on the source // ip address. Leave empty do disable. PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"` // Absolute path to an external program or an HTTP URL to invoke after an SSH/FTP connection ends. // Leave empty do disable. PostDisconnectHook string `json:"post_disconnect_hook" mapstructure:"post_disconnect_hook"` // Maximum number of concurrent client connections. 0 means unlimited MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"` // Maximum number of concurrent client connections from the same host (IP). 0 means unlimited MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"` // Defines the status of the global allow list. 0 means disabled, 1 enabled. // If enabled, only the listed IPs/networks can access the configured services, all other // client connections will be dropped before they even try to authenticate. // Ensure to enable this setting only after adding some allowed ip/networks from the WebAdmin/REST API AllowListStatus int `json:"allowlist_status" mapstructure:"allowlist_status"` // Allow users on this instance to use other users/virtual folders on this instance as storage backend. 
// Enable this setting if you know what you are doing. AllowSelfConnections int `json:"allow_self_connections" mapstructure:"allow_self_connections"` // Defender configuration DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"` // Rate limiter configurations RateLimitersConfig []RateLimiterConfig `json:"rate_limiters" mapstructure:"rate_limiters"` // Umask for new uploads. Leave blank to use the system default. Umask string `json:"umask" mapstructure:"umask"` // Defines the server version ServerVersion string `json:"server_version" mapstructure:"server_version"` // TZ defines the time zone to use for the EventManager scheduler and to // control time-based access restrictions. Set to "local" to use the // server's local time, otherwise UTC will be used. TZ string `json:"tz" mapstructure:"tz"` // Metadata configuration Metadata MetadataConfig `json:"metadata" mapstructure:"metadata"` // EventManager configuration EventManager EventManagerConfig `json:"event_manager" mapstructure:"event_manager"` idleTimeoutAsDuration time.Duration idleLoginTimeout time.Duration defender Defender allowList *dataprovider.IPList rateLimitersList *dataprovider.IPList proxyAllowed []func(net.IP) bool proxySkipped []func(net.IP) bool } // IsAtomicUploadEnabled returns true if atomic upload is enabled func (c *Configuration) IsAtomicUploadEnabled() bool { return c.UploadMode&UploadModeAtomic != 0 || c.UploadMode&UploadModeAtomicWithResume != 0 } func (c *Configuration) initializeProxyProtocol() error { if c.ProxyProtocol > 0 { allowed, err := util.ParseAllowedIPAndRanges(c.ProxyAllowed) if err != nil { return fmt.Errorf("invalid proxy allowed: %w", err) } skipped, err := util.ParseAllowedIPAndRanges(c.ProxySkipped) if err != nil { return fmt.Errorf("invalid proxy skipped: %w", err) } Config.proxyAllowed = allowed Config.proxySkipped = skipped } return nil } // GetProxyListener returns a wrapper for the given listener that supports the // HAProxy Proxy Protocol func (c 
*Configuration) GetProxyListener(listener net.Listener) (net.Listener, error) { if c.ProxyProtocol > 0 { defaultPolicy := proxyproto.REQUIRE if c.ProxyProtocol == 1 { defaultPolicy = proxyproto.IGNORE } return &proxyproto.Listener{ Listener: listener, ConnPolicy: getProxyPolicy(c.proxyAllowed, c.proxySkipped, defaultPolicy), ReadHeaderTimeout: 10 * time.Second, }, nil } return nil, errors.New("proxy protocol not configured") } // GetRateLimitersStatus returns the rate limiters status func (c *Configuration) GetRateLimitersStatus() (bool, []string) { enabled := false var protocols []string for _, rlCfg := range c.RateLimitersConfig { if rlCfg.isEnabled() { enabled = true protocols = append(protocols, rlCfg.Protocols...) } } return enabled, util.RemoveDuplicates(protocols, false) } // IsAllowListEnabled returns true if the global allow list is enabled func (c *Configuration) IsAllowListEnabled() bool { return c.AllowListStatus > 0 } // ExecuteStartupHook runs the startup hook if defined func (c *Configuration) ExecuteStartupHook() error { if c.StartupHook == "" { return nil } if strings.HasPrefix(c.StartupHook, "http") { var url *url.URL url, err := url.Parse(c.StartupHook) if err != nil { logger.Warn(logSender, "", "Invalid startup hook %q: %v", c.StartupHook, err) return err } startTime := time.Now() resp, err := httpclient.RetryableGet(url.String()) if err != nil { logger.Warn(logSender, "", "Error executing startup hook: %v", err) return err } defer resp.Body.Close() logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, response code: %v", time.Since(startTime), resp.StatusCode) return nil } if !filepath.IsAbs(c.StartupHook) { err := fmt.Errorf("invalid startup hook %q", c.StartupHook) logger.Warn(logSender, "", "Invalid startup hook %q", c.StartupHook) return err } startTime := time.Now() timeout, env, args := command.GetConfig(c.StartupHook, command.HookStartup) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() cmd := 
exec.CommandContext(ctx, c.StartupHook, args...) cmd.Env = env err := cmd.Run() logger.Debug(logSender, "", "Startup hook executed, elapsed: %s, error: %v", time.Since(startTime), err) return nil } func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) { startNewHook() defer hookEnded() ipAddr := util.GetIPFromRemoteAddress(remoteAddr) connDuration := int64(time.Since(connectionTime) / time.Millisecond) if strings.HasPrefix(c.PostDisconnectHook, "http") { var url *url.URL url, err := url.Parse(c.PostDisconnectHook) if err != nil { logger.Warn(protocol, connID, "Invalid post disconnect hook %q: %v", c.PostDisconnectHook, err) return } q := url.Query() q.Add("ip", ipAddr) q.Add("protocol", protocol) q.Add("username", username) q.Add("connection_duration", strconv.FormatInt(connDuration, 10)) url.RawQuery = q.Encode() startTime := time.Now() resp, err := httpclient.RetryableGet(url.String()) respCode := 0 if err == nil { respCode = resp.StatusCode resp.Body.Close() } logger.Debug(protocol, connID, "Post disconnect hook response code: %v, elapsed: %v, err: %v", respCode, time.Since(startTime), err) return } if !filepath.IsAbs(c.PostDisconnectHook) { logger.Debug(protocol, connID, "invalid post disconnect hook %q", c.PostDisconnectHook) return } timeout, env, args := command.GetConfig(c.PostDisconnectHook, command.HookPostDisconnect) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() startTime := time.Now() cmd := exec.CommandContext(ctx, c.PostDisconnectHook, args...) 
cmd.Env = append(env, fmt.Sprintf("SFTPGO_CONNECTION_IP=%s", ipAddr), fmt.Sprintf("SFTPGO_CONNECTION_USERNAME=%s", username), fmt.Sprintf("SFTPGO_CONNECTION_DURATION=%d", connDuration), fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%s", protocol)) err := cmd.Run() logger.Debug(protocol, connID, "Post disconnect hook executed, elapsed: %s error: %v", time.Since(startTime), err) } func (c *Configuration) checkPostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) { if c.PostDisconnectHook == "" { return } if !slices.Contains(disconnHookProtocols, protocol) { return } go c.executePostDisconnectHook(remoteAddr, protocol, username, connID, connectionTime) } // ExecutePostConnectHook executes the post connect hook if defined func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error { if c.PostConnectHook == "" { return nil } if strings.HasPrefix(c.PostConnectHook, "http") { var url *url.URL url, err := url.Parse(c.PostConnectHook) if err != nil { logger.Warn(protocol, "", "Login from ip %q denied, invalid post connect hook %q: %v", ipAddr, c.PostConnectHook, err) return getPermissionDeniedError(protocol) } q := url.Query() q.Add("ip", ipAddr) q.Add("protocol", protocol) url.RawQuery = q.Encode() resp, err := httpclient.RetryableGet(url.String()) if err != nil { logger.Warn(protocol, "", "Login from ip %q denied, error executing post connect hook: %v", ipAddr, err) return getPermissionDeniedError(protocol) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.Warn(protocol, "", "Login from ip %q denied, post connect hook response code: %v", ipAddr, resp.StatusCode) return getPermissionDeniedError(protocol) } return nil } if !filepath.IsAbs(c.PostConnectHook) { err := fmt.Errorf("invalid post connect hook %q", c.PostConnectHook) logger.Warn(protocol, "", "Login from ip %q denied: %v", ipAddr, err) return getPermissionDeniedError(protocol) } timeout, env, args := command.GetConfig(c.PostConnectHook, 
command.HookPostConnect) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() cmd := exec.CommandContext(ctx, c.PostConnectHook, args...) cmd.Env = append(env, fmt.Sprintf("SFTPGO_CONNECTION_IP=%s", ipAddr), fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%s", protocol)) err := cmd.Run() if err != nil { logger.Warn(protocol, "", "Login from ip %q denied, connect hook error: %v", ipAddr, err) return getPermissionDeniedError(protocol) } return nil } func getProxyPolicy(allowed, skipped []func(net.IP) bool, def proxyproto.Policy) proxyproto.ConnPolicyFunc { return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { upstreamIP, err := util.GetIPFromNetAddr(connPolicyOptions.Upstream) if err != nil { // Something is wrong with the source IP, better reject the // connection. logger.Error(logSender, "", "reject connection from ip %q, err: %v", connPolicyOptions.Upstream, err) return proxyproto.REJECT, proxyproto.ErrInvalidUpstream } for _, skippedFrom := range skipped { if skippedFrom(upstreamIP) { return proxyproto.SKIP, nil } } for _, allowFrom := range allowed { if allowFrom(upstreamIP) { if def == proxyproto.REQUIRE { return proxyproto.REQUIRE, nil } return proxyproto.USE, nil } } if def == proxyproto.REQUIRE { logger.Debug(logSender, "", "reject connection from ip %q: proxy protocol signature required and not set", upstreamIP) return proxyproto.REJECT, proxyproto.ErrInvalidUpstream } return def, nil } } // SSHConnection defines an ssh connection. 
// Each SSH connection can open several channels for SFTP or SSH commands type SSHConnection struct { id string conn io.Closer lastActivity atomic.Int64 } // NewSSHConnection returns a new SSHConnection func NewSSHConnection(id string, conn io.Closer) *SSHConnection { c := &SSHConnection{ id: id, conn: conn, } c.lastActivity.Store(time.Now().UnixNano()) return c } // GetID returns the ID for this SSHConnection func (c *SSHConnection) GetID() string { return c.id } // UpdateLastActivity updates last activity for this connection func (c *SSHConnection) UpdateLastActivity() { c.lastActivity.Store(time.Now().UnixNano()) } // GetLastActivity returns the last connection activity func (c *SSHConnection) GetLastActivity() time.Time { return time.Unix(0, c.lastActivity.Load()) } // Close closes the underlying network connection func (c *SSHConnection) Close() error { return c.conn.Close() } // ActiveConnections holds the currect active connections with the associated transfers type ActiveConnections struct { // clients contains both authenticated and estabilished connections and the ones waiting // for authentication clients clientsMap // transfers contains active transfers, total and per-user transfers clientsMap transfersCheckStatus atomic.Bool sync.RWMutex connections []ActiveConnection mapping map[string]int sshConnections []*SSHConnection sshMapping map[string]int perUserConns map[string]int } // internal method, must be called within a locked block func (conns *ActiveConnections) addUserConnection(username string) { if username == "" { return } conns.perUserConns[username]++ } // internal method, must be called within a locked block func (conns *ActiveConnections) removeUserConnection(username string) { if username == "" { return } if val, ok := conns.perUserConns[username]; ok { conns.perUserConns[username]-- if val > 1 { return } delete(conns.perUserConns, username) } } // GetActiveSessions returns the number of active sessions for the given username. 
// We return the open sessions for any protocol func (conns *ActiveConnections) GetActiveSessions(username string) int { conns.RLock() defer conns.RUnlock() return conns.perUserConns[username] } // Add adds a new connection to the active ones func (conns *ActiveConnections) Add(c ActiveConnection) error { conns.Lock() defer conns.Unlock() if username := c.GetUsername(); username != "" { if maxSessions := c.GetMaxSessions(); maxSessions > 0 { if val := conns.perUserConns[username]; val >= maxSessions { return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) } if val := conns.transfers.getTotalFrom(username); val >= maxSessions { return fmt.Errorf("too many open transfers: %d/%d", val, maxSessions) } } conns.addUserConnection(username) } conns.mapping[c.GetID()] = len(conns.connections) conns.connections = append(conns.connections, c) metric.UpdateActiveConnectionsSize(len(conns.connections)) logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %q, remote address %q, num open connections: %d", c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections)) return nil } // Swap replaces an existing connection with the given one. 
// This method is useful if you have to change some connection details // for example for FTP is used to update the connection once the user // authenticates func (conns *ActiveConnections) Swap(c ActiveConnection) error { conns.Lock() defer conns.Unlock() if idx, ok := conns.mapping[c.GetID()]; ok { conn := conns.connections[idx] conns.removeUserConnection(conn.GetUsername()) if username := c.GetUsername(); username != "" { if maxSessions := c.GetMaxSessions(); maxSessions > 0 { if val, ok := conns.perUserConns[username]; ok && val >= maxSessions { conns.addUserConnection(conn.GetUsername()) return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) } } conns.addUserConnection(username) } err := conn.CloseFS() conns.connections[idx] = c logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err) conn = nil return nil } return errors.New("connection to swap not found") } // Remove removes a connection from the active ones func (conns *ActiveConnections) Remove(connectionID string) { conns.Lock() defer conns.Unlock() if idx, ok := conns.mapping[connectionID]; ok { conn := conns.connections[idx] err := conn.CloseFS() lastIdx := len(conns.connections) - 1 conns.connections[idx] = conns.connections[lastIdx] conns.connections[lastIdx] = nil conns.connections = conns.connections[:lastIdx] delete(conns.mapping, connectionID) if idx != lastIdx { conns.mapping[conns.connections[idx].GetID()] = idx } conns.removeUserConnection(conn.GetUsername()) metric.UpdateActiveConnectionsSize(lastIdx) logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %q, remote address %q close fs error: %v, num open connections: %d", conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx) if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !slices.Contains(ftpLoginCommands, conn.GetCommand()) { ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress()) logger.ConnectionFailedLog("", ip, 
dataprovider.LoginMethodNoAuthTried, ProtocolFTP, dataprovider.ErrNoAuthTried.Error()) metric.AddNoAuthTried() AddDefenderEvent(ip, ProtocolFTP, HostEventNoLoginTried) dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTried, ip, ProtocolFTP, dataprovider.ErrNoAuthTried) plugin.Handler.NotifyLogEvent(notifier.LogEventTypeNoLoginTried, ProtocolFTP, "", ip, "", dataprovider.ErrNoAuthTried) } Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(), conn.GetID(), conn.GetConnectionTime()) return } logger.Debug(logSender, "", "connection id %q to remove not found!", connectionID) } // Close closes an active connection. // It returns true on success func (conns *ActiveConnections) Close(connectionID, role string) bool { conns.RLock() var result bool if idx, ok := conns.mapping[connectionID]; ok { c := conns.connections[idx] if role == "" || c.GetRole() == role { defer func(conn ActiveConnection) { err := conn.Disconnect() logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err) }(c) result = true } } conns.RUnlock() return result } // AddSSHConnection adds a new ssh connection to the active ones func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) { conns.Lock() defer conns.Unlock() conns.sshMapping[c.GetID()] = len(conns.sshConnections) conns.sshConnections = append(conns.sshConnections, c) logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %d", len(conns.sshConnections)) } // RemoveSSHConnection removes a connection from the active ones func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) { conns.Lock() defer conns.Unlock() if idx, ok := conns.sshMapping[connectionID]; ok { lastIdx := len(conns.sshConnections) - 1 conns.sshConnections[idx] = conns.sshConnections[lastIdx] conns.sshConnections[lastIdx] = nil conns.sshConnections = conns.sshConnections[:lastIdx] delete(conns.sshMapping, 
connectionID) if idx != lastIdx { conns.sshMapping[conns.sshConnections[idx].GetID()] = idx } logger.Debug(logSender, connectionID, "ssh connection removed, num open ssh connections: %d", lastIdx) return } logger.Warn(logSender, "", "ssh connection to remove with id %q not found!", connectionID) } func (conns *ActiveConnections) checkIdles() { conns.RLock() for _, sshConn := range conns.sshConnections { idleTime := time.Since(sshConn.GetLastActivity()) if idleTime > Config.idleTimeoutAsDuration { // we close an SSH connection if it has no active connections associated idToMatch := fmt.Sprintf("_%s_", sshConn.GetID()) toClose := true for _, conn := range conns.connections { if strings.Contains(conn.GetID(), idToMatch) { if time.Since(conn.GetLastActivity()) <= Config.idleTimeoutAsDuration { toClose = false break } } } if toClose { defer func(c *SSHConnection) { err := c.Close() logger.Debug(logSender, c.GetID(), "close idle SSH connection, idle time: %v, close err: %v", time.Since(c.GetLastActivity()), err) }(sshConn) } } } for _, c := range conns.connections { idleTime := time.Since(c.GetLastActivity()) isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && c.GetUsername() == "") if idleTime > Config.idleTimeoutAsDuration || (isUnauthenticatedFTPUser && idleTime > Config.idleLoginTimeout) { defer func(conn ActiveConnection) { err := conn.Disconnect() logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %s, username: %q close err: %v", time.Since(conn.GetLastActivity()), conn.GetUsername(), err) }(c) } else if !isUnauthenticatedFTPUser && !c.isAccessAllowed() { defer func(conn ActiveConnection) { err := conn.Disconnect() logger.Info(conn.GetProtocol(), conn.GetID(), "access conditions not met for user: %q close connection err: %v", conn.GetUsername(), err) }(c) } } conns.RUnlock() } func (conns *ActiveConnections) checkTransfers() { if conns.transfersCheckStatus.Load() { logger.Warn(logSender, "", "the previous transfer check is 
still running, skipping execution") return } conns.transfersCheckStatus.Store(true) defer conns.transfersCheckStatus.Store(false) conns.RLock() if len(conns.connections) < 2 { conns.RUnlock() return } var wg sync.WaitGroup logger.Debug(logSender, "", "start concurrent transfers check") // update the current size for transfers to monitors for _, c := range conns.connections { for _, t := range c.GetTransfers() { if t.HasSizeLimit { wg.Add(1) go func(transfer ConnectionTransfer, connID string) { defer wg.Done() transfersChecker.UpdateTransferCurrentSizes(transfer.ULSize, transfer.DLSize, transfer.ID, connID) }(t, c.GetID()) } } } conns.RUnlock() logger.Debug(logSender, "", "waiting for the update of the transfers current size") wg.Wait() logger.Debug(logSender, "", "getting overquota transfers") overquotaTransfers := transfersChecker.GetOverquotaTransfers() logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers)) if len(overquotaTransfers) == 0 { return } conns.RLock() defer conns.RUnlock() for _, c := range conns.connections { for _, overquotaTransfer := range overquotaTransfers { if c.GetID() == overquotaTransfer.ConnID { logger.Info(logSender, c.GetID(), "user %q is overquota, try to close transfer id %v", c.GetUsername(), overquotaTransfer.TransferID) var err error if overquotaTransfer.TransferType == TransferDownload { err = getReadQuotaExceededError(c.GetProtocol()) } else { err = getQuotaExceededError(c.GetProtocol()) } c.SignalTransferClose(overquotaTransfer.TransferID, err) } } } logger.Debug(logSender, "", "transfers check completed") } // AddClientConnection stores a new client connection func (conns *ActiveConnections) AddClientConnection(ipAddr string) { conns.clients.add(ipAddr) } // RemoveClientConnection removes a disconnected client from the tracked ones func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) { conns.clients.remove(ipAddr) } // GetClientConnections returns the total number of client 
connections func (conns *ActiveConnections) GetClientConnections() int32 { return conns.clients.getTotal() } // GetTotalTransfers returns the total number of active transfers func (conns *ActiveConnections) GetTotalTransfers() int32 { return conns.transfers.getTotal() } // IsNewTransferAllowed returns an error if the maximum number of concurrent allowed // transfers is exceeded func (conns *ActiveConnections) IsNewTransferAllowed(username string) error { if isShuttingDown.Load() { return ErrShuttingDown } if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { return nil } if Config.MaxPerHostConnections > 0 { if transfers := conns.transfers.getTotalFrom(username); transfers >= Config.MaxPerHostConnections { logger.Info(logSender, "", "active transfers from user %q: %d/%d", username, transfers, Config.MaxPerHostConnections) return ErrConnectionDenied } } if Config.MaxTotalConnections > 0 { if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) { logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections) return ErrConnectionDenied } } return nil } // IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed // connections is exceeded or a whitelist is defined and the specified ipAddr is not listed // or the service is shutting down func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string) error { if isShuttingDown.Load() { return ErrShuttingDown } if Config.allowList != nil { isListed, _, err := Config.allowList.IsListed(ipAddr, protocol) if err != nil { logger.Error(logSender, "", "unable to query allow list, connection denied, ip %q, protocol %s, err: %v", ipAddr, protocol, err) return ErrConnectionDenied } if !isListed { return ErrConnectionDenied } } if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { return nil } if Config.MaxPerHostConnections > 0 { if total := conns.clients.getTotalFrom(ipAddr); total > 
Config.MaxPerHostConnections { if !AddDefenderEvent(ipAddr, protocol, HostEventLimitExceeded) { logger.Warn(logSender, "", "connection denied, active connections from IP %q: %d/%d", ipAddr, total, Config.MaxPerHostConnections) return ErrConnectionDenied } logger.Info(logSender, "", "active connections from safe IP %q: %d", ipAddr, total) } } if Config.MaxTotalConnections > 0 { if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) { logger.Info(logSender, "", "active client connections %d/%d", total, Config.MaxTotalConnections) return ErrConnectionDenied } // on a single SFTP connection we could have multiple SFTP channels or commands // so we check the estabilished connections and active uploads too if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) { logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections) return ErrConnectionDenied } conns.RLock() defer conns.RUnlock() if sess := len(conns.connections); sess >= Config.MaxTotalConnections { logger.Info(logSender, "", "active client sessions %d/%d", sess, Config.MaxTotalConnections) return ErrConnectionDenied } } return nil } // GetStats returns stats for active connections func (conns *ActiveConnections) GetStats(role string) []ConnectionStatus { conns.RLock() defer conns.RUnlock() stats := make([]ConnectionStatus, 0, len(conns.connections)) node := dataprovider.GetNodeName() for _, c := range conns.connections { if role == "" || c.GetRole() == role { stat := ConnectionStatus{ Username: c.GetUsername(), ConnectionID: c.GetID(), ClientVersion: c.GetClientVersion(), RemoteAddress: c.GetRemoteAddress(), ConnectionTime: util.GetTimeAsMsSinceEpoch(c.GetConnectionTime()), LastActivity: util.GetTimeAsMsSinceEpoch(c.GetLastActivity()), CurrentTime: util.GetTimeAsMsSinceEpoch(time.Now()), Protocol: c.GetProtocol(), Command: c.GetCommand(), Transfers: c.GetTransfers(), Node: node, } stats = append(stats, stat) } } return 
stats } // ConnectionStatus returns the status for an active connection type ConnectionStatus struct { // Logged in username Username string `json:"username"` // Unique identifier for the connection ConnectionID string `json:"connection_id"` // client's version string ClientVersion string `json:"client_version,omitempty"` // Remote address for this connection RemoteAddress string `json:"remote_address"` // Connection time as unix timestamp in milliseconds ConnectionTime int64 `json:"connection_time"` // Last activity as unix timestamp in milliseconds LastActivity int64 `json:"last_activity"` // Current time as unix timestamp in milliseconds CurrentTime int64 `json:"current_time"` // Protocol for this connection Protocol string `json:"protocol"` // active uploads/downloads Transfers []ConnectionTransfer `json:"active_transfers,omitempty"` // SSH command or WebDAV method Command string `json:"command,omitempty"` // Node identifier, omitted for single node installations Node string `json:"node,omitempty"` } // ActiveQuotaScan defines an active quota scan for a user type ActiveQuotaScan struct { // Username to which the quota scan refers Username string `json:"username"` // quota scan start time as unix timestamp in milliseconds StartTime int64 `json:"start_time"` Role string `json:"-"` } // ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder type ActiveVirtualFolderQuotaScan struct { // folder name to which the quota scan refers Name string `json:"name"` // quota scan start time as unix timestamp in milliseconds StartTime int64 `json:"start_time"` } // ActiveScans holds the active quota scans type ActiveScans struct { sync.RWMutex UserScans []ActiveQuotaScan FolderScans []ActiveVirtualFolderQuotaScan } // GetUsersQuotaScans returns the active users quota scans func (s *ActiveScans) GetUsersQuotaScans(role string) []ActiveQuotaScan { s.RLock() defer s.RUnlock() scans := make([]ActiveQuotaScan, 0, len(s.UserScans)) for _, scan := range 
s.UserScans { if role == "" || role == scan.Role { scans = append(scans, ActiveQuotaScan{ Username: scan.Username, StartTime: scan.StartTime, }) } } return scans } // AddUserQuotaScan adds a user to the ones with active quota scans. // Returns false if the user has a quota scan already running func (s *ActiveScans) AddUserQuotaScan(username, role string) bool { s.Lock() defer s.Unlock() for _, scan := range s.UserScans { if scan.Username == username { return false } } s.UserScans = append(s.UserScans, ActiveQuotaScan{ Username: username, StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), Role: role, }) return true } // RemoveUserQuotaScan removes a user from the ones with active quota scans. // Returns false if the user has no active quota scans func (s *ActiveScans) RemoveUserQuotaScan(username string) bool { s.Lock() defer s.Unlock() for idx, scan := range s.UserScans { if scan.Username == username { lastIdx := len(s.UserScans) - 1 s.UserScans[idx] = s.UserScans[lastIdx] s.UserScans = s.UserScans[:lastIdx] return true } } return false } // GetVFoldersQuotaScans returns the active quota scans for virtual folders func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan { s.RLock() defer s.RUnlock() scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans)) copy(scans, s.FolderScans) return scans } // AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans. // Returns false if the folder has a quota scan already running func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool { s.Lock() defer s.Unlock() for _, scan := range s.FolderScans { if scan.Name == folderName { return false } } s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{ Name: folderName, StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), }) return true } // RemoveVFolderQuotaScan removes a folder from the ones with active quota scans. 
// Returns false if the folder has no active quota scans func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool { s.Lock() defer s.Unlock() for idx, scan := range s.FolderScans { if scan.Name == folderName { lastIdx := len(s.FolderScans) - 1 s.FolderScans[idx] = s.FolderScans[lastIdx] s.FolderScans = s.FolderScans[:lastIdx] return true } } return false } ================================================ FILE: internal/common/common_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "crypto/tls" "encoding/json" "fmt" "net" "os" "os/exec" "path/filepath" "runtime" "slices" "sync" "testing" "time" "github.com/alexedwards/argon2id" "github.com/pires/go-proxyproto" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( logSenderTest = "common_test" httpAddr = "127.0.0.1:9999" osWindows = "windows" userTestUsername = "common_test_username" ) var ( configDir = filepath.Join(".", "..", "..") ) type fakeConnection struct { *BaseConnection command string } func (c *fakeConnection) AddUser(user dataprovider.User) error { _, err := user.GetFilesystem(c.GetID()) if err != nil { return err } c.User = user return nil } func (c *fakeConnection) Disconnect() error { Connections.Remove(c.GetID()) return nil } func (c *fakeConnection) GetClientVersion() string { return "" } func (c *fakeConnection) GetCommand() string { return c.command } func (c *fakeConnection) GetLocalAddress() string { return "" } func (c *fakeConnection) GetRemoteAddress() string { return "" } type customNetConn struct { net.Conn id string isClosed bool } func (c *customNetConn) Close() error { Connections.RemoveSSHConnection(c.id) c.isClosed = true return c.Conn.Close() } func TestConnections(t *testing.T) { c1 := &fakeConnection{ BaseConnection: NewBaseConnection("id1", ProtocolSFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, }, }), } c2 := &fakeConnection{ BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, }, }), } c3 := &fakeConnection{ BaseConnection: NewBaseConnection("id3", 
ProtocolSFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, }, }), } c4 := &fakeConnection{ BaseConnection: NewBaseConnection("id4", ProtocolSFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, }, }), } assert.Equal(t, "SFTP_id1", c1.GetID()) assert.Equal(t, "SFTP_id2", c2.GetID()) assert.Equal(t, "SFTP_id3", c3.GetID()) assert.Equal(t, "SFTP_id4", c4.GetID()) err := Connections.Add(c1) assert.NoError(t, err) err = Connections.Add(c2) assert.NoError(t, err) err = Connections.Add(c3) assert.NoError(t, err) err = Connections.Add(c4) assert.NoError(t, err) Connections.RLock() assert.Len(t, Connections.connections, 4) assert.Len(t, Connections.mapping, 4) _, ok := Connections.mapping[c1.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.mapping[c1.GetID()]) assert.Equal(t, 1, Connections.mapping[c2.GetID()]) assert.Equal(t, 2, Connections.mapping[c3.GetID()]) assert.Equal(t, 3, Connections.mapping[c4.GetID()]) Connections.RUnlock() c2 = &fakeConnection{ BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername + "_mod", }, }), } err = Connections.Swap(c2) assert.NoError(t, err) Connections.RLock() assert.Len(t, Connections.connections, 4) assert.Len(t, Connections.mapping, 4) _, ok = Connections.mapping[c1.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.mapping[c1.GetID()]) assert.Equal(t, 1, Connections.mapping[c2.GetID()]) assert.Equal(t, 2, Connections.mapping[c3.GetID()]) assert.Equal(t, 3, Connections.mapping[c4.GetID()]) assert.Equal(t, userTestUsername+"_mod", Connections.connections[1].GetUsername()) Connections.RUnlock() Connections.Remove(c2.GetID()) Connections.RLock() assert.Len(t, Connections.connections, 3) assert.Len(t, Connections.mapping, 3) _, ok = Connections.mapping[c1.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.mapping[c1.GetID()]) assert.Equal(t, 1, 
Connections.mapping[c4.GetID()]) assert.Equal(t, 2, Connections.mapping[c3.GetID()]) Connections.RUnlock() Connections.Remove(c3.GetID()) Connections.RLock() assert.Len(t, Connections.connections, 2) assert.Len(t, Connections.mapping, 2) _, ok = Connections.mapping[c1.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.mapping[c1.GetID()]) assert.Equal(t, 1, Connections.mapping[c4.GetID()]) Connections.RUnlock() Connections.Remove(c1.GetID()) Connections.RLock() assert.Len(t, Connections.connections, 1) assert.Len(t, Connections.mapping, 1) _, ok = Connections.mapping[c4.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.mapping[c4.GetID()]) Connections.RUnlock() Connections.Remove(c4.GetID()) Connections.RLock() assert.Len(t, Connections.connections, 0) assert.Len(t, Connections.mapping, 0) Connections.RUnlock() } func TestEventManagerCommandsInitialization(t *testing.T) { configCopy := Config c := Configuration{ EventManager: EventManagerConfig{ EnabledCommands: []string{"ls"}, // not an absolute path }, } err := Initialize(c, 0) assert.ErrorContains(t, err, "invalid command") var commands []string if runtime.GOOS == osWindows { commands = []string{"C:\\command"} } else { commands = []string{"/bin/ls"} } c.EventManager.EnabledCommands = commands err = Initialize(c, 0) assert.NoError(t, err) assert.Equal(t, commands, dataprovider.EnabledActionCommands) dataprovider.EnabledActionCommands = configCopy.EventManager.EnabledCommands Config = configCopy } func TestInitializationProxyErrors(t *testing.T) { configCopy := Config c := Configuration{ ProxyProtocol: 1, ProxyAllowed: []string{"1.1.1.1111"}, } err := Initialize(c, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid proxy allowed") } c.ProxyAllowed = nil c.ProxySkipped = []string{"invalid"} err = Initialize(c, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid proxy skipped") } c.ProxyAllowed = []string{"1.1.1.1"} c.ProxySkipped = []string{"2.2.2.2", 
"10.8.0.0/24"} err = Initialize(c, 0) assert.NoError(t, err) assert.Len(t, Config.proxyAllowed, 1) assert.Len(t, Config.proxySkipped, 2) Config = configCopy assert.Equal(t, 0, Config.ProxyProtocol) assert.Len(t, Config.proxyAllowed, 0) assert.Len(t, Config.proxySkipped, 0) } func TestInitializationClosedProvider(t *testing.T) { configCopy := Config providerConf := dataprovider.GetProviderConfig() err := dataprovider.Close() assert.NoError(t, err) config := Configuration{ AllowListStatus: 1, } err = Initialize(config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to initialize the allow list") } config.AllowListStatus = 0 config.RateLimitersConfig = []RateLimiterConfig{ { Average: 100, Period: 1000, Burst: 5, Type: int(rateLimiterTypeGlobal), Protocols: rateLimiterProtocolValues, }, } err = Initialize(config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to initialize ratelimiters list") } config.RateLimitersConfig = nil config.DefenderConfig = DefenderConfig{ Enabled: true, Driver: DefenderDriverProvider, BanTime: 10, BanTimeIncrement: 50, Threshold: 10, ScoreInvalid: 2, ScoreValid: 1, ScoreNoAuth: 2, ObservationTime: 15, EntriesSoftLimit: 100, EntriesHardLimit: 150, } err = Initialize(config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "defender initialization error") } config.DefenderConfig.Driver = DefenderDriverMemory err = Initialize(config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "defender initialization error") } err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) Config = configCopy } func TestSSHConnections(t *testing.T) { conn1, conn2 := net.Pipe() now := time.Now() sshConn1 := NewSSHConnection("id1", conn1) sshConn2 := NewSSHConnection("id2", conn2) sshConn3 := NewSSHConnection("id3", conn2) assert.Equal(t, "id1", sshConn1.GetID()) assert.Equal(t, "id2", sshConn2.GetID()) assert.Equal(t, "id3", sshConn3.GetID()) 
sshConn1.UpdateLastActivity() assert.GreaterOrEqual(t, sshConn1.GetLastActivity().UnixNano(), now.UnixNano()) Connections.AddSSHConnection(sshConn1) Connections.AddSSHConnection(sshConn2) Connections.AddSSHConnection(sshConn3) Connections.RLock() assert.Len(t, Connections.sshConnections, 3) _, ok := Connections.sshMapping[sshConn1.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.sshMapping[sshConn1.GetID()]) assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()]) assert.Equal(t, 2, Connections.sshMapping[sshConn3.GetID()]) Connections.RUnlock() Connections.RemoveSSHConnection(sshConn1.id) Connections.RLock() assert.Len(t, Connections.sshConnections, 2) assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id) assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id) _, ok = Connections.sshMapping[sshConn3.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()]) assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()]) Connections.RUnlock() Connections.RemoveSSHConnection(sshConn1.id) Connections.RLock() assert.Len(t, Connections.sshConnections, 2) assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id) assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id) _, ok = Connections.sshMapping[sshConn3.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()]) assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()]) Connections.RUnlock() Connections.RemoveSSHConnection(sshConn2.id) Connections.RLock() assert.Len(t, Connections.sshConnections, 1) assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id) _, ok = Connections.sshMapping[sshConn3.GetID()] assert.True(t, ok) assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()]) Connections.RUnlock() Connections.RemoveSSHConnection(sshConn3.id) Connections.RLock() assert.Len(t, Connections.sshConnections, 0) assert.Len(t, Connections.sshMapping, 0) Connections.RUnlock() assert.NoError(t, sshConn1.Close()) 
assert.NoError(t, sshConn2.Close()) assert.NoError(t, sshConn3.Close()) } func TestDefenderIntegration(t *testing.T) { // by default defender is nil configCopy := Config wdPath, err := os.Getwd() require.NoError(t, err) pluginsConfig := []plugin.Config{ { Type: "ipfilter", Cmd: filepath.Join(wdPath, "..", "..", "tests", "ipfilter", "ipfilter"), AutoMTLS: true, }, } if runtime.GOOS == osWindows { pluginsConfig[0].Cmd += ".exe" } err = plugin.Initialize(pluginsConfig, "debug") require.NoError(t, err) ip := "127.1.1.1" assert.Nil(t, Reload()) // 192.168.1.12 is banned from the ipfilter plugin assert.True(t, IsBanned("192.168.1.12", ProtocolFTP)) AddDefenderEvent(ip, ProtocolFTP, HostEventNoLoginTried) assert.False(t, IsBanned(ip, ProtocolFTP)) banTime, err := GetDefenderBanTime(ip) assert.NoError(t, err) assert.Nil(t, banTime) assert.False(t, DeleteDefenderHost(ip)) score, err := GetDefenderScore(ip) assert.NoError(t, err) assert.Equal(t, 0, score) _, err = GetDefenderHost(ip) assert.Error(t, err) hosts, err := GetDefenderHosts() assert.NoError(t, err) assert.Nil(t, hosts) Config.DefenderConfig = DefenderConfig{ Enabled: true, Driver: DefenderDriverProvider, BanTime: 10, BanTimeIncrement: 50, Threshold: 0, ScoreInvalid: 2, ScoreValid: 1, ScoreNoAuth: 2, ObservationTime: 15, EntriesSoftLimit: 100, EntriesHardLimit: 150, LoginDelay: LoginDelay{ PasswordFailed: 200, }, } err = Initialize(Config, 0) // ScoreInvalid cannot be greater than threshold assert.Error(t, err) Config.DefenderConfig.Driver = "unsupported" err = Initialize(Config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported defender driver") } Config.DefenderConfig.Driver = DefenderDriverMemory err = Initialize(Config, 0) // ScoreInvalid cannot be greater than threshold assert.Error(t, err) Config.DefenderConfig.Threshold = 3 err = Initialize(Config, 0) assert.NoError(t, err) assert.Nil(t, Reload()) AddDefenderEvent(ip, ProtocolSSH, HostEventNoLoginTried) assert.False(t, IsBanned(ip, 
ProtocolSSH))
	score, err = GetDefenderScore(ip)
	assert.NoError(t, err)
	assert.Equal(t, 2, score)
	entry, err := GetDefenderHost(ip)
	assert.NoError(t, err)
	asJSON, err := json.Marshal(&entry)
	assert.NoError(t, err)
	assert.Equal(t, `{"id":"3132372e312e312e31","ip":"127.1.1.1","score":2}`, string(asJSON), "entry %v", entry)
	assert.True(t, DeleteDefenderHost(ip))
	banTime, err = GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.Nil(t, banTime)
	AddDefenderEvent(ip, ProtocolHTTP, HostEventLoginFailed)
	AddDefenderEvent(ip, ProtocolHTTP, HostEventNoLoginTried)
	assert.True(t, IsBanned(ip, ProtocolHTTP))
	score, err = GetDefenderScore(ip)
	assert.NoError(t, err)
	assert.Equal(t, 0, score)
	banTime, err = GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.NotNil(t, banTime)
	hosts, err = GetDefenderHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 1)
	entry, err = GetDefenderHost(ip)
	assert.NoError(t, err)
	assert.False(t, entry.BanTime.IsZero())
	assert.True(t, DeleteDefenderHost(ip))
	hosts, err = GetDefenderHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 0)
	banTime, err = GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.Nil(t, banTime)
	assert.False(t, DeleteDefenderHost(ip))
	startTime := time.Now()
	DelayLogin(nil)
	elapsed := time.Since(startTime)
	assert.Less(t, elapsed, time.Millisecond*50)
	startTime = time.Now()
	DelayLogin(ErrInternalFailure)
	elapsed = time.Since(startTime)
	assert.Greater(t, elapsed, time.Millisecond*150)
	Config = configCopy
}

// TestRateLimitersIntegration checks rate limiter initialization from
// Config.RateLimitersConfig, the per-protocol limiter registration and that
// sources in the rate limiter safe list are never limited.
func TestRateLimitersIntegration(t *testing.T) {
	configCopy := Config

	enabled, protocols := Config.GetRateLimitersStatus()
	assert.False(t, enabled)
	assert.Len(t, protocols, 0)

	entries := []dataprovider.IPListEntry{
		{
			IPOrNet: "172.16.24.7/32",
			Type:    dataprovider.IPListTypeRateLimiterSafeList,
			Mode:    dataprovider.ListModeAllow,
		},
		{
			IPOrNet: "172.16.0.0/16",
			Type:    dataprovider.IPListTypeRateLimiterSafeList,
			Mode:    dataprovider.ListModeAllow,
		},
	}

	for idx := range entries {
		e := entries[idx]
		err := dataprovider.AddIPListEntry(&e, "", "", "")
		assert.NoError(t, err)
	}

	Config.RateLimitersConfig = []RateLimiterConfig{
		{
			Average:   100,
			Period:    10,
			Burst:     5,
			Type:      int(rateLimiterTypeGlobal),
			Protocols: rateLimiterProtocolValues,
		},
		{
			Average:                1,
			Period:                 1000,
			Burst:                  1,
			Type:                   int(rateLimiterTypeSource),
			Protocols:              []string{ProtocolWebDAV, ProtocolWebDAV, ProtocolFTP},
			GenerateDefenderEvents: true,
			EntriesSoftLimit:       100,
			EntriesHardLimit:       150,
		},
	}
	// with Period 10 the global limiter configuration is rejected
	err := Initialize(Config, 0)
	assert.Error(t, err)
	Config.RateLimitersConfig[0].Period = 1000
	err = Initialize(Config, 0)
	assert.NoError(t, err)
	assert.NotNil(t, Config.rateLimitersList)

	assert.Len(t, rateLimiters, 4)
	assert.Len(t, rateLimiters[ProtocolSSH], 1)
	assert.Len(t, rateLimiters[ProtocolFTP], 2)
	assert.Len(t, rateLimiters[ProtocolWebDAV], 2)
	assert.Len(t, rateLimiters[ProtocolHTTP], 1)

	enabled, protocols = Config.GetRateLimitersStatus()
	assert.True(t, enabled)
	assert.Len(t, protocols, 4)
	assert.Contains(t, protocols, ProtocolFTP)
	assert.Contains(t, protocols, ProtocolSSH)
	assert.Contains(t, protocols, ProtocolHTTP)
	assert.Contains(t, protocols, ProtocolWebDAV)

	source1 := "127.1.1.1"
	source2 := "127.1.1.2"
	source3 := "172.16.24.7" // in safelist

	_, err = LimitRate(ProtocolSSH, source1)
	assert.NoError(t, err)
	_, err = LimitRate(ProtocolFTP, source1)
	assert.NoError(t, err)
	// sleep to allow the configured global burst to be added back to the token bucket.
	// This sleep is not enough to refill the per-source burst
	time.Sleep(20 * time.Millisecond)
	_, err = LimitRate(ProtocolWebDAV, source2)
	assert.NoError(t, err)
	_, err = LimitRate(ProtocolFTP, source1)
	assert.Error(t, err)
	_, err = LimitRate(ProtocolWebDAV, source2)
	assert.Error(t, err)
	_, err = LimitRate(ProtocolSSH, source1)
	assert.NoError(t, err)
	_, err = LimitRate(ProtocolSSH, source2)
	assert.NoError(t, err)

	for i := 0; i < 10; i++ {
		_, err = LimitRate(ProtocolWebDAV, source3)
		assert.NoError(t, err)
	}

	for _, e := range entries {
		err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "")
		assert.NoError(t, err)
	}

	assert.Nil(t, configCopy.rateLimitersList)
	Config = configCopy
}

// TestUserMaxSessions checks that a user limited to one session cannot add a
// second connection, while swapping the existing connection still succeeds.
func TestUserMaxSessions(t *testing.T) {
	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:    userTestUsername,
			MaxSessions: 1,
		},
	})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	err = Connections.Add(fakeConn)
	assert.Error(t, err)
	err = Connections.Swap(fakeConn)
	assert.NoError(t, err)
	Connections.Remove(fakeConn.GetID())
	Connections.Lock()
	Connections.removeUserConnection(userTestUsername)
	Connections.Unlock()
	assert.Len(t, Connections.GetStats(""), 0)
}

// TestMaxConnections checks the MaxTotalConnections limit, the transfer
// admission checks and the shutdown guard applied to new transfers.
func TestMaxConnections(t *testing.T) {
	oldValue := Config.MaxTotalConnections
	perHost := Config.MaxPerHostConnections

	Config.MaxPerHostConnections = 0

	ipAddr := "192.168.7.8"
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
	assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))

	Config.MaxTotalConnections = 1
	Config.MaxPerHostConnections = perHost

	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolHTTP))
	assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))
	isShuttingDown.Store(true)
	assert.ErrorIs(t, Connections.IsNewTransferAllowed(userTestUsername), ErrShuttingDown)
	isShuttingDown.Store(false)

	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, Connections.GetStats(""), 1)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.transfers.add(userTestUsername)
	assert.Error(t, Connections.IsNewTransferAllowed(userTestUsername))
	Connections.transfers.remove(userTestUsername)
	assert.Equal(t, int32(0), Connections.GetTotalTransfers())

	res := Connections.Close(fakeConn.GetID(), "")
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)

	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.AddClientConnection(ipAddr)
	Connections.AddClientConnection(ipAddr)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.RemoveClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV))
	Connections.transfers.add(userTestUsername)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.transfers.remove(userTestUsername)
	Connections.RemoveClientConnection(ipAddr)

	Config.MaxTotalConnections = oldValue
}

// TestConnectionRoles checks that connection stats and Close are filtered by
// the role passed by the caller.
func TestConnectionRoles(t *testing.T) {
	username := "testUsername"
	role1 := "testRole1"
	role2 := "testRole2"
	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Role:     role1,
		},
	})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, Connections.GetStats(""), 1)
	assert.Len(t, Connections.GetStats(role1), 1)
	assert.Len(t, Connections.GetStats(role2), 0)

	res := Connections.Close(fakeConn.GetID(), role2)
	assert.False(t, res)
	assert.Len(t, Connections.GetStats(""), 1)
	res = Connections.Close(fakeConn.GetID(), role1)
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
}

// TestMaxConnectionPerHost checks the per-host connection limit and that an
// IP in the defender allow list is not subject to it.
func TestMaxConnectionPerHost(t *testing.T) {
	defender, err := newInMemoryDefender(&DefenderConfig{
		Enabled:            true,
		Driver:             DefenderDriverMemory,
		BanTime:            30,
		BanTimeIncrement:   50,
		Threshold:          15,
		ScoreInvalid:       2,
		ScoreValid:         1,
		ScoreLimitExceeded: 3,
		ObservationTime:    30,
		EntriesSoftLimit:   100,
		EntriesHardLimit:   150,
	})
	require.NoError(t, err)

	oldMaxPerHostConn := Config.MaxPerHostConnections
	oldDefender := Config.defender

	Config.MaxPerHostConnections = 2
	Config.defender = defender

	ipAddr := "192.168.9.9"
	Connections.AddClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))

	Connections.AddClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV))

	Connections.AddClientConnection(ipAddr)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
	assert.Equal(t, int32(3), Connections.GetClientConnections())
	// Add the IP to the defender safe list
	entry := dataprovider.IPListEntry{
		IPOrNet: ipAddr,
		Type:    dataprovider.IPListTypeDefender,
		Mode:    dataprovider.ListModeAllow,
	}
	err = dataprovider.AddIPListEntry(&entry, "", "", "")
	assert.NoError(t, err)

	Connections.AddClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))

	err = dataprovider.DeleteIPListEntry(entry.IPOrNet, dataprovider.IPListTypeDefender, "", "", "")
	assert.NoError(t, err)

	Connections.RemoveClientConnection(ipAddr)
	Connections.RemoveClientConnection(ipAddr)
	Connections.RemoveClientConnection(ipAddr)
	Connections.RemoveClientConnection(ipAddr)

	assert.Equal(t, int32(0), Connections.GetClientConnections())

	Config.MaxPerHostConnections = oldMaxPerHostConn
	Config.defender = oldDefender
}

// TestIdleConnections checks that the periodic checks remove idle
// connections, connections of expired users and SSH connections that have no
// associated connection left.
func TestIdleConnections(t *testing.T) {
	configCopy := Config

	Config.IdleTimeout = 1
	err := Initialize(Config, 0)
	assert.NoError(t, err)

	conn1, conn2 := net.Pipe()
	customConn1 := &customNetConn{
		Conn: conn1,
		id:   "id1",
	}
	customConn2 := &customNetConn{
		Conn: conn2,
		id:   "id2",
	}
	sshConn1 := NewSSHConnection(customConn1.id, customConn1)
	sshConn2 := NewSSHConnection(customConn2.id, customConn2)

	username := "test_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Status:   1,
		},
	}
	c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
	c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	// both ssh connections are expired but they should get removed only
	// if there is no associated connection
	sshConn1.lastActivity.Store(c.lastActivity.Load())
	sshConn2.lastActivity.Store(c.lastActivity.Load())
	Connections.AddSSHConnection(sshConn1)
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Equal(t, 1, Connections.GetActiveSessions(username))
	c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, "", "", user)
	fakeConn = &fakeConnection{
		BaseConnection: c,
	}
	Connections.AddSSHConnection(sshConn2)
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Equal(t, 2, Connections.GetActiveSessions(username))

	cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
	cFTP.lastActivity.Store(time.Now().UnixNano())
	fakeConn = &fakeConnection{
		BaseConnection: cFTP,
	}
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)
	// the user is expired, this connection will be removed
	cDAV := NewBaseConnection("id3", ProtocolWebDAV, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:       username + "_2",
			Status:         1,
			ExpirationDate: util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)),
		},
	})
	cDAV.lastActivity.Store(time.Now().UnixNano())
	fakeConn = &fakeConnection{
		BaseConnection: cDAV,
	}
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)

	assert.Equal(t, 2, Connections.GetActiveSessions(username))
	assert.Len(t, Connections.GetStats(""), 4)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 2)
	Connections.RUnlock()

	startPeriodicChecks(100*time.Millisecond, 0)
	assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond)
	assert.Eventually(t, func() bool {
		Connections.RLock()
		defer Connections.RUnlock()
		return len(Connections.sshConnections) == 1
	}, 1*time.Second, 200*time.Millisecond)
	stopEventScheduler()
	assert.Len(t, Connections.GetStats(""), 2)
	c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
	cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
	sshConn2.lastActivity.Store(c.lastActivity.Load())
	startPeriodicChecks(100*time.Millisecond, 1)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 2*time.Second, 200*time.Millisecond)
	assert.Eventually(t, func() bool {
		Connections.RLock()
		defer Connections.RUnlock()
		return len(Connections.sshConnections) == 0
	}, 1*time.Second, 200*time.Millisecond)
	assert.Equal(t, int32(0), Connections.GetClientConnections())
	stopEventScheduler()
	assert.True(t, customConn1.isClosed)
	assert.True(t, customConn2.isClosed)

	Config = configCopy
}

// TestCloseConnection checks that Close removes a registered connection and
// returns false once the connection is gone.
func TestCloseConnection(t *testing.T) {
	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	assert.NoError(t, Connections.IsNewConnectionAllowed("127.0.0.1", ProtocolHTTP))
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, Connections.GetStats(""), 1)
	res := Connections.Close(fakeConn.GetID(), "")
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
	res = Connections.Close(fakeConn.GetID(), "")
	assert.False(t, res)
	Connections.Remove(fakeConn.GetID())
}

// TestSwapConnection checks Connections.Swap, including the user's
// MaxSessions limit and swapping a connection that is no longer registered.
func TestSwapConnection(t *testing.T) {
	c := NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	if assert.Len(t, Connections.GetStats(""), 1) {
		assert.Equal(t, "", Connections.GetStats("")[0].Username)
	}
	c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:    userTestUsername,
			MaxSessions: 1,
		},
	})
	fakeConn = &fakeConnection{
		BaseConnection: c,
	}
	c1 := NewBaseConnection("id1", ProtocolFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: userTestUsername,
		},
	})
	fakeConn1 := &fakeConnection{
		BaseConnection: c1,
	}
	err = Connections.Add(fakeConn1)
	assert.NoError(t, err)
	err = Connections.Swap(fakeConn)
	assert.Error(t, err)
	Connections.Remove(fakeConn1.ID)
	err = Connections.Swap(fakeConn)
	assert.NoError(t, err)
	if assert.Len(t, Connections.GetStats(""), 1) {
		assert.Equal(t, userTestUsername, Connections.GetStats("")[0].Username)
	}
	res := Connections.Close(fakeConn.GetID(), "")
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
	err = Connections.Swap(fakeConn)
	assert.Error(t, err)
}

// TestAtomicUpload checks IsAtomicUploadEnabled for each upload mode.
func TestAtomicUpload(t *testing.T) {
	configCopy := Config

	Config.UploadMode = UploadModeStandard
	assert.False(t, Config.IsAtomicUploadEnabled())
	Config.UploadMode = UploadModeAtomic
	assert.True(t, Config.IsAtomicUploadEnabled())
	Config.UploadMode = UploadModeAtomicWithResume
	assert.True(t, Config.IsAtomicUploadEnabled())

	Config = configCopy
}

// TestConnectionStatus checks the stats reported for active connections and
// their transfers, and the transfer abort signaling.
func TestConnectionStatus(t *testing.T) {
	username := "test_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
		},
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	c1 := NewBaseConnection("id1", ProtocolSFTP, "", "", user)
	fakeConn1 := &fakeConnection{
		BaseConnection: c1,
	}
	t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	t1.BytesReceived.Store(123)
	t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	t2.BytesSent.Store(456)
	c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
	fakeConn2 := &fakeConnection{
		BaseConnection: c2,
		command:        "md5sum",
	}
	c3 := NewBaseConnection("id3", ProtocolWebDAV, "", "", user)
	fakeConn3 := &fakeConnection{
		BaseConnection: c3,
		command:        "PROPFIND",
	}
	t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	err := Connections.Add(fakeConn1)
	assert.NoError(t, err)
	err = Connections.Add(fakeConn2)
	assert.NoError(t, err)
	err = Connections.Add(fakeConn3)
	assert.NoError(t, err)

	stats := Connections.GetStats("")
	assert.Len(t, stats, 3)
	for _, stat := range stats {
		assert.Equal(t, username, stat.Username)
		switch stat.ConnectionID {
		case "SFTP_id1":
			assert.Len(t, stat.Transfers, 2)
		case "DAV_id3":
			assert.Len(t, stat.Transfers, 1)
		}
	}

	err = t1.Close()
	assert.NoError(t, err)
	err = t2.Close()
	assert.NoError(t, err)

	err = fakeConn3.SignalTransfersAbort()
	assert.NoError(t, err)
	assert.True(t, t3.AbortTransfer.Load())
	err = t3.Close()
	assert.NoError(t, err)
	err = fakeConn3.SignalTransfersAbort()
	assert.Error(t, err)

	Connections.Remove(fakeConn1.GetID())
	stats = Connections.GetStats("")
	assert.Len(t, stats, 2)
	assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
	assert.Equal(t, fakeConn2.GetID(), stats[1].ConnectionID)
	Connections.Remove(fakeConn2.GetID())
	stats = Connections.GetStats("")
	assert.Len(t, stats, 1)
	assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
	Connections.Remove(fakeConn3.GetID())
	stats = Connections.GetStats("")
	assert.Len(t, stats, 0)
}

// TestQuotaScans checks adding, listing and removing user and virtual folder
// quota scans; the user scans returned by GetUsersQuotaScans must not be
// affected by later changes to QuotaScans.UserScans.
func TestQuotaScans(t *testing.T) {
	username := "username"
	assert.True(t, QuotaScans.AddUserQuotaScan(username, ""))
	assert.False(t, QuotaScans.AddUserQuotaScan(username, ""))
	usersScans := QuotaScans.GetUsersQuotaScans("")
	if assert.Len(t, usersScans, 1) {
		assert.Equal(t, username, usersScans[0].Username)
		assert.Equal(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
		QuotaScans.UserScans[0].StartTime = 0
		assert.NotEqual(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
	}

	assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.Len(t, QuotaScans.GetUsersQuotaScans(""), 0)
	assert.Len(t, usersScans, 1)

	folderName := "folder"
	assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
	assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
	if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
		assert.Equal(t, folderName, QuotaScans.GetVFoldersQuotaScans()[0].Name)
	}

	assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
	assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
	assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
}

// TestQuotaScansRole checks that user quota scans are filtered by role.
func TestQuotaScansRole(t *testing.T) {
	username := "u"
	role1 := "r1"
	role2 := "r2"
	assert.True(t, QuotaScans.AddUserQuotaScan(username, role1))
	assert.False(t, QuotaScans.AddUserQuotaScan(username, ""))
	usersScans := QuotaScans.GetUsersQuotaScans("")
	// require: usersScans[0] is indexed right below
	require.Len(t, usersScans, 1)
	assert.Empty(t, usersScans[0].Role)
	usersScans = QuotaScans.GetUsersQuotaScans(role1)
	assert.Len(t, usersScans, 1)
	usersScans = QuotaScans.GetUsersQuotaScans(role2)
	assert.Len(t, usersScans, 0)
	assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.Len(t, QuotaScans.GetUsersQuotaScans(""), 0)
}

// TestProxyPolicy checks the PROXY protocol policy returned for allowed,
// skipped and unlisted upstream addresses.
func TestProxyPolicy(t *testing.T) {
	addr := net.TCPAddr{}
	downstream := net.TCPAddr{IP: net.ParseIP("1.1.1.1")}
	p := getProxyPolicy(nil, nil, proxyproto.IGNORE)
	policy, err := p(proxyproto.ConnPolicyOptions{
		Upstream:   &addr,
		Downstream: &downstream,
	})
	assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream)
	assert.Equal(t, proxyproto.REJECT, policy)
	ip1 := net.ParseIP("10.8.1.1")
	ip2 := net.ParseIP("10.8.1.2")
	ip3 := net.ParseIP("10.8.1.3")
	allowed, err := util.ParseAllowedIPAndRanges([]string{ip1.String()})
	assert.NoError(t, err)
	skipped, err := util.ParseAllowedIPAndRanges([]string{ip2.String(), ip3.String()})
	assert.NoError(t, err)
	p = getProxyPolicy(allowed, skipped, proxyproto.IGNORE)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip1},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.USE, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip2},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip3},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: net.ParseIP("10.8.1.4")},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.IGNORE, policy)
	p = getProxyPolicy(allowed, skipped, proxyproto.REQUIRE)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip1},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.REQUIRE, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip2},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip3},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: net.ParseIP("10.8.1.5")},
		Downstream: &downstream,
	})
	assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream)
	assert.Equal(t, proxyproto.REJECT, policy)
}

// TestProxyProtocolVersion checks that a proxy listener is returned only when
// a proxy protocol version is configured.
func TestProxyProtocolVersion(t *testing.T) {
	c := Configuration{
		ProxyProtocol: 0,
	}
	_, err := c.GetProxyListener(nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "proxy protocol not configured")
	}
	c.ProxyProtocol = 1
	listener, err := c.GetProxyListener(nil)
	assert.NoError(t, err)
	proxyListener, ok := listener.(*proxyproto.Listener)
	require.True(t, ok)
	assert.NotNil(t, proxyListener.ConnPolicy)

	c.ProxyProtocol = 2
	listener, err = c.GetProxyListener(nil)
	assert.NoError(t, err)
	proxyListener, ok = listener.(*proxyproto.Listener)
	require.True(t, ok)
	assert.NotNil(t, proxyListener.ConnPolicy)
}

// TestStartupHook checks the startup hook with empty, invalid, HTTP and
// command hooks.
func TestStartupHook(t *testing.T) {
	Config.StartupHook = ""

	assert.NoError(t, Config.ExecuteStartupHook())

	Config.StartupHook = "http://foo\x7f.com/startup"
	assert.Error(t, Config.ExecuteStartupHook())

	Config.StartupHook = "http://invalid:5678/"
	assert.Error(t, Config.ExecuteStartupHook())

	Config.StartupHook = fmt.Sprintf("http://%v", httpAddr)
	assert.NoError(t, Config.ExecuteStartupHook())

	Config.StartupHook = "invalidhook"
	assert.Error(t, Config.ExecuteStartupHook())

	if runtime.GOOS != osWindows {
		hookCmd, err := exec.LookPath("true")
		assert.NoError(t, err)
		Config.StartupHook = hookCmd
		assert.NoError(t, Config.ExecuteStartupHook())
	}

	Config.StartupHook = ""
}

// TestPostDisconnectHook exercises the post-disconnect hook with HTTP and
// command hooks; the hook results are not asserted here.
func TestPostDisconnectHook(t *testing.T) {
	Config.PostDisconnectHook = "http://127.0.0.1/"

	remoteAddr := "127.0.0.1:80"
	Config.checkPostDisconnectHook(remoteAddr, ProtocolHTTP, "", "", time.Now())
	Config.checkPostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	Config.PostDisconnectHook = "http://bar\x7f.com/"
	Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	Config.PostDisconnectHook = fmt.Sprintf("http://%v", httpAddr)
	Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	Config.PostDisconnectHook = "relativePath"
	Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	if runtime.GOOS == osWindows {
		Config.PostDisconnectHook = "C:\\a\\bad\\command"
		Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())
	} else {
		Config.PostDisconnectHook = "/invalid/path"
		Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

		hookCmd, err := exec.LookPath("true")
		assert.NoError(t, err)
		Config.PostDisconnectHook = hookCmd
		Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())
	}
	Config.PostDisconnectHook = ""
}

// TestPostConnectHook checks the post-connect hook with empty, invalid, HTTP
// and command hooks.
func TestPostConnectHook(t *testing.T) {
	Config.PostConnectHook = ""

	ipAddr := "127.0.0.1"

	assert.NoError(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	Config.PostConnectHook = "http://foo\x7f.com/"
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))

	Config.PostConnectHook = "http://invalid:1234/"
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))

	Config.PostConnectHook = fmt.Sprintf("http://%v/404", httpAddr)
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	Config.PostConnectHook = fmt.Sprintf("http://%v", httpAddr)
	assert.NoError(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	Config.PostConnectHook = "invalid"
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	if runtime.GOOS == osWindows {
		Config.PostConnectHook = "C:\\bad\\command"
		assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))
	} else {
		Config.PostConnectHook = "/invalid/path"
		assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))

		hookCmd, err := exec.LookPath("true")
		assert.NoError(t, err)
		Config.PostConnectHook = hookCmd
		assert.NoError(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))
	}

	Config.PostConnectHook = ""
}

// TestCryptoConvertFileInfo checks CryptFs.ConvertFileInfo: directories are
// returned unchanged, file sizes are converted and very small files report a
// size of 0.
func TestCryptoConvertFileInfo(t *testing.T) {
	name := "name"
	fs, err := vfs.NewCryptFs("connID1", os.TempDir(), "", vfs.CryptFsConfig{
		Passphrase: kms.NewPlainSecret("secret"),
	})
	require.NoError(t, err)
	// checked assertion: fail the test instead of panicking on an unexpected type
	cryptFs, ok := fs.(*vfs.CryptFs)
	require.True(t, ok)
	info := vfs.NewFileInfo(name, true, 48, time.Now(), false)
	assert.Equal(t, info, cryptFs.ConvertFileInfo(info))
	info = vfs.NewFileInfo(name, false, 48, time.Now(), false)
	assert.NotEqual(t, info.Size(), cryptFs.ConvertFileInfo(info).Size())
	info = vfs.NewFileInfo(name, false, 33, time.Now(), false)
	assert.Equal(t, int64(0), cryptFs.ConvertFileInfo(info).Size())
	info = vfs.NewFileInfo(name, false, 1, time.Now(), false)
	assert.Equal(t, int64(0), cryptFs.ConvertFileInfo(info).Size())
}

// TestFolderCopy checks that BaseVirtualFolder.GetACopy returns a copy that
// is not affected by later changes to the source folder.
func TestFolderCopy(t *testing.T) {
	folder := vfs.BaseVirtualFolder{
		ID:              1,
		Name:            "name",
		MappedPath:      filepath.Clean(os.TempDir()),
		UsedQuotaSize:   4096,
		UsedQuotaFiles:  2,
		LastQuotaUpdate: util.GetTimeAsMsSinceEpoch(time.Now()),
		Users:           []string{"user1", "user2"},
	}
	folderCopy := folder.GetACopy()
	folder.ID = 2
	folder.Users = []string{"user3"}
	require.Len(t, folderCopy.Users, 2)
	require.True(t, slices.Contains(folderCopy.Users, "user1"))
	require.True(t, slices.Contains(folderCopy.Users, "user2"))
	require.Equal(t, int64(1), folderCopy.ID)
	require.Equal(t, folder.Name, folderCopy.Name)
	require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
	require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize)
	require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles)
	require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate)

	folder.FsConfig = vfs.Filesystem{
		CryptConfig: vfs.CryptFsConfig{
			Passphrase: kms.NewPlainSecret("crypto secret"),
		},
	}
	folderCopy = folder.GetACopy()
	folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
	require.Len(t, folderCopy.Users, 1)
	require.True(t, slices.Contains(folderCopy.Users, "user3"))
	require.Equal(t, int64(2), folderCopy.ID)
	require.Equal(t, folder.Name, folderCopy.Name)
	require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
	require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize)
	require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles)
	require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate)
	require.Equal(t, "crypto secret", folderCopy.FsConfig.CryptConfig.Passphrase.GetPayload())
}

// TestCachedFs checks that a connection keeps using the filesystem resolved
// when it was created, even if the user is changed afterwards.
func TestCachedFs(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			HomeDir: filepath.Clean(os.TempDir()),
		},
	}
	conn := NewBaseConnection("id", ProtocolSFTP, "", "", user)
	// changing the user should not affect the connection
	user.HomeDir = filepath.Join(os.TempDir(), "temp")
	err := os.Mkdir(user.HomeDir, os.ModePerm)
	assert.NoError(t, err)
	fs, err := user.GetFilesystem("")
	assert.NoError(t, err)
	p, err := fs.ResolvePath("/")
	assert.NoError(t, err)
	assert.Equal(t, user.GetHomeDir(), p)

	_, p, err = conn.GetFsAndResolvedPath("/")
	assert.NoError(t, err)
	assert.Equal(t, filepath.Clean(os.TempDir()), p)
	// the filesystem is cached changing the provider will not affect the connection
	conn.User.FsConfig.Provider = sdk.S3FilesystemProvider
	_, p, err = conn.GetFsAndResolvedPath("/")
	assert.NoError(t, err)
	assert.Equal(t, filepath.Clean(os.TempDir()), p)
	user = dataprovider.User{}
	user.HomeDir = filepath.Join(os.TempDir(), "temp")
	user.FsConfig.Provider = sdk.S3FilesystemProvider
	_, err = user.GetFilesystem("")
	assert.Error(t, err)

	err = os.Remove(user.HomeDir)
	assert.NoError(t, err)
}

// TestParseAllowedIPAndRanges checks parsing and matching of IP addresses and
// CIDR ranges.
func TestParseAllowedIPAndRanges(t *testing.T) {
	_, err := util.ParseAllowedIPAndRanges([]string{"1.1.1.1", "not an ip"})
	assert.Error(t, err)
	_, err = util.ParseAllowedIPAndRanges([]string{"1.1.1.5", "192.168.1.0/240"})
	assert.Error(t, err)
	allow, err := util.ParseAllowedIPAndRanges([]string{"192.168.1.2", "172.16.0.0/24"})
	assert.NoError(t, err)
	assert.True(t, allow[0](net.ParseIP("192.168.1.2")))
	assert.False(t, allow[0](net.ParseIP("192.168.2.2")))
	assert.True(t, allow[1](net.ParseIP("172.16.0.1")))
	assert.False(t, allow[1](net.ParseIP("172.16.1.1")))
}

// TestHideConfidentialData runs the rendering/hiding helpers for users and
// folders with each listed filesystem provider, and for an admin; it passes
// as long as none of them panics.
func TestHideConfidentialData(_ *testing.T) {
	for _, provider := range []sdk.FilesystemProvider{sdk.LocalFilesystemProvider,
		sdk.CryptedFilesystemProvider, sdk.S3FilesystemProvider, sdk.GCSFilesystemProvider,
		sdk.AzureBlobFilesystemProvider, sdk.SFTPFilesystemProvider,
	} {
		u := dataprovider.User{
			FsConfig: vfs.Filesystem{
				Provider: provider,
			},
		}
		u.PrepareForRendering()
		f := vfs.BaseVirtualFolder{
			FsConfig: vfs.Filesystem{
				Provider: provider,
			},
		}
		f.PrepareForRendering()
	}
	a := dataprovider.Admin{}
	a.HideConfidentialData()
}

// TestUserPerms checks the HasAnyPerm, HasPermsDeleteAll and
// HasPermsRenameAll permission helpers.
func TestUserPerms(t *testing.T) {
	u := dataprovider.User{}
	u.Permissions = make(map[string][]string)
	u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermDelete}
	assert.True(t, u.HasAnyPerm([]string{dataprovider.PermRename, dataprovider.PermDelete}, "/"))
	assert.False(t, u.HasAnyPerm([]string{dataprovider.PermRename, dataprovider.PermCreateDirs}, "/"))
	u.Permissions["/"] = []string{dataprovider.PermDelete, dataprovider.PermCreateDirs}
	assert.True(t, u.HasPermsDeleteAll("/"))
	assert.False(t, u.HasPermsRenameAll("/"))
	u.Permissions["/"] = []string{dataprovider.PermDeleteDirs, dataprovider.PermDeleteFiles, dataprovider.PermRenameDirs}
	assert.True(t, u.HasPermsDeleteAll("/"))
	assert.False(t, u.HasPermsRenameAll("/"))
	u.Permissions["/"] = []string{dataprovider.PermDeleteDirs, dataprovider.PermRenameFiles, dataprovider.PermRenameDirs}
	assert.False(t, u.HasPermsDeleteAll("/"))
	assert.True(t, u.HasPermsRenameAll("/"))
}

// TestGetTLSVersion checks the mapping from configuration values to TLS
// versions; the values 0 and 2 map to TLS 1.2 here.
func TestGetTLSVersion(t *testing.T) {
	tlsVer := util.GetTLSVersion(0)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsVer)
	tlsVer = util.GetTLSVersion(12)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsVer)
	tlsVer = util.GetTLSVersion(2)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsVer)
	tlsVer = util.GetTLSVersion(13)
	assert.Equal(t, uint16(tls.VersionTLS13), tlsVer)
}

// TestCleanPath checks util.CleanPath normalization with forward slashes and
// with the OS path separator.
func TestCleanPath(t *testing.T) {
	assert.Equal(t, "/", util.CleanPath("/"))
	assert.Equal(t, "/", util.CleanPath("."))
	assert.Equal(t, "/", util.CleanPath(""))
	assert.Equal(t, "/", util.CleanPath("/."))
	assert.Equal(t, "/", util.CleanPath("/a/.."))
	assert.Equal(t, "/a", util.CleanPath("/a/"))
	assert.Equal(t, "/a", util.CleanPath("a/"))
	// filepath.ToSlash does not touch \ as char on unix systems
	// so os.PathSeparator is used for windows compatible tests
	bslash := string(os.PathSeparator)
	assert.Equal(t, "/", util.CleanPath(bslash))
	assert.Equal(t, "/", util.CleanPath(bslash+bslash))
	assert.Equal(t, "/a", util.CleanPath(bslash+"a"+bslash))
	assert.Equal(t, "/a", util.CleanPath("a"+bslash))
	assert.Equal(t, "/a/b/c", util.CleanPath(bslash+"a"+bslash+bslash+"b"+bslash+bslash+"c"+bslash))
	assert.Equal(t, "/C:/a", util.CleanPath("C:"+bslash+"a"))
}

// TestUserRecentActivity checks HasRecentActivity for a missing, a current
// and future last login timestamps.
func TestUserRecentActivity(t *testing.T) {
	u := dataprovider.User{}
	res := u.HasRecentActivity()
	assert.False(t, res)

	u.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now())
	res = u.HasRecentActivity()
	assert.True(t, res)

	u.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute))
	res = u.HasRecentActivity()
	assert.False(t, res)

	u.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Second))
	res = u.HasRecentActivity()
	assert.True(t, res)
}

// TestVfsSameResource checks Filesystem.IsSameResource for S3, GCS, Azure
// Blob and HTTP filesystems.
func TestVfsSameResource(t *testing.T) {
	fs := vfs.Filesystem{}
	other := vfs.Filesystem{}
	res := fs.IsSameResource(other)
	assert.True(t, res)
	fs = vfs.Filesystem{
		Provider: sdk.S3FilesystemProvider,
		S3Config: vfs.S3FsConfig{
			BaseS3FsConfig: sdk.BaseS3FsConfig{
				Bucket: "a",
				Region: "b",
			},
		},
	}
	other = vfs.Filesystem{
		Provider: sdk.S3FilesystemProvider,
		S3Config: vfs.S3FsConfig{
			BaseS3FsConfig: sdk.BaseS3FsConfig{
				Bucket: "a",
				Region: "c",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.False(t, res)
	other = vfs.Filesystem{
		Provider: sdk.S3FilesystemProvider,
		S3Config: vfs.S3FsConfig{
			BaseS3FsConfig: sdk.BaseS3FsConfig{
				Bucket: "a",
				Region: "b",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.True(t, res)
	fs = vfs.Filesystem{
		Provider: sdk.GCSFilesystemProvider,
		GCSConfig: vfs.GCSFsConfig{
			BaseGCSFsConfig: sdk.BaseGCSFsConfig{
				Bucket: "b",
			},
		},
	}
	other = vfs.Filesystem{
		Provider: sdk.GCSFilesystemProvider,
		GCSConfig: vfs.GCSFsConfig{
			BaseGCSFsConfig: sdk.BaseGCSFsConfig{
				Bucket: "c",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.False(t, res)
	other = vfs.Filesystem{
		Provider: sdk.GCSFilesystemProvider,
		GCSConfig: vfs.GCSFsConfig{
			BaseGCSFsConfig: sdk.BaseGCSFsConfig{
				Bucket: "b",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.True(t, res)
	sasURL := kms.NewPlainSecret("http://127.0.0.1/sasurl")
	fs = vfs.Filesystem{
		Provider: sdk.AzureBlobFilesystemProvider,
		AzBlobConfig: vfs.AzBlobFsConfig{
			BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
				AccountName: "a",
			},
			SASURL: sasURL,
		},
	}
	err := fs.Validate("data1")
	assert.NoError(t, err)
	other = vfs.Filesystem{
		Provider: sdk.AzureBlobFilesystemProvider,
		AzBlobConfig: vfs.AzBlobFsConfig{
			BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
				AccountName: "a",
			},
			SASURL: sasURL,
		},
	}
	err = other.Validate("data2")
	assert.NoError(t, err)
	err = fs.AzBlobConfig.SASURL.TryDecrypt()
	assert.NoError(t, err)
	err = other.AzBlobConfig.SASURL.TryDecrypt()
	assert.NoError(t, err)
	res = fs.IsSameResource(other)
	assert.True(t, res)
	fs.AzBlobConfig.AccountName = "b"
	res = fs.IsSameResource(other)
	assert.False(t, res)
	fs.AzBlobConfig.AccountName = "a"
	other.AzBlobConfig.SASURL = kms.NewPlainSecret("http://127.1.1.1/sasurl")
	err = other.Validate("data2")
	assert.NoError(t, err)
	err = other.AzBlobConfig.SASURL.TryDecrypt()
	assert.NoError(t, err)
	res = fs.IsSameResource(other)
	assert.False(t, res)
	fs = vfs.Filesystem{
		Provider: sdk.HTTPFilesystemProvider,
		HTTPConfig: vfs.HTTPFsConfig{
			BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
				Endpoint: "http://127.0.0.1/httpfs",
				Username: "a",
			},
		},
	}
	other = vfs.Filesystem{
		Provider: sdk.HTTPFilesystemProvider,
		HTTPConfig: vfs.HTTPFsConfig{
			BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
				Endpoint: "http://127.0.0.1/httpfs",
				Username: "b",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.True(t, res)
	fs.HTTPConfig.EqualityCheckMode = 1
	res = fs.IsSameResource(other)
	assert.False(t, res)
}

// TestUpdateTransferTimestamps checks that the first upload and first download
// timestamps are set independently and that updating them again fails.
func TestUpdateTransferTimestamps(t *testing.T) {
	username := "user_test_timestamps"
	user := &dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	err := dataprovider.AddUser(user, "", "", "")
	assert.NoError(t, err)
	assert.Equal(t, int64(0), user.FirstUpload)
	assert.Equal(t, int64(0), user.FirstDownload)

	err = dataprovider.UpdateUserTransferTimestamps(username, true)
	assert.NoError(t, err)
	userGet, err := dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	assert.Greater(t, userGet.FirstUpload, int64(0))
	// check the user re-read from the provider, not the stale local struct
	assert.Equal(t, int64(0), userGet.FirstDownload)
	err = dataprovider.UpdateUserTransferTimestamps(username, false)
	assert.NoError(t, err)
	userGet, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	assert.Greater(t, userGet.FirstUpload, int64(0))
	assert.Greater(t, userGet.FirstDownload, int64(0))
	// updating again must fail
	err = dataprovider.UpdateUserTransferTimestamps(username, true)
	assert.Error(t, err)
	err = dataprovider.UpdateUserTransferTimestamps(username, false)
	assert.Error(t, err)
	// cleanup
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}

// TestIPList checks IP list matching by address and protocol, first in memory
// mode and then with memory mode disabled.
func TestIPList(t *testing.T) {
	type test struct {
		ip            string
		protocol      string
		expectedMatch bool
		expectedMode  int
		expectedErr   bool
	}

	entries := []dataprovider.IPListEntry{
		{
			IPOrNet: "192.168.0.0/25",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeAllow,
		},
		{
			IPOrNet:   "192.168.0.128/25",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeDeny,
			Protocols: 3,
		},
		{
			IPOrNet:   "192.168.2.128/32",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 5,
		},
		{
			IPOrNet:   "::/0",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeDeny,
			Protocols: 4,
		},
		{
			IPOrNet:   "2001:4860:4860::8888/120",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeDeny,
			Protocols: 1,
		},
		{
			IPOrNet:   "2001:4860:4860::8988/120",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 3,
		},
		{
			IPOrNet:   "::1/128",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 0,
		},
	}
	ipList, err := dataprovider.NewIPList(dataprovider.IPListTypeDefender)
	require.NoError(t, err)
	for idx := range entries {
		e := entries[idx]
		err := dataprovider.AddIPListEntry(&e, "", "", "")
		assert.NoError(t, err)
	}
	tests := []test{
		{ip: "1.1.1.1", protocol: ProtocolSSH, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "invalid ip", protocol: ProtocolSSH, expectedMatch: false, expectedMode: 0, expectedErr: true},
		{ip: "192.168.0.1", protocol: ProtocolFTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.2", protocol: ProtocolHTTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.3", protocol: ProtocolWebDAV, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.4", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.156", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "192.168.0.158", protocol: ProtocolFTP, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "192.168.0.158", protocol: ProtocolHTTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "192.168.2.128", protocol: ProtocolHTTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "192.168.2.128", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "::2", protocol: ProtocolSSH, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "::2", protocol: ProtocolWebDAV, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "::1", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "::1", protocol: ProtocolHTTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:8889", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:8889", protocol: ProtocolFTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:8989", protocol: ProtocolFTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:89F1", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:89F1", protocol: ProtocolHTTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
	}

	for _, tc := range tests {
		match, mode, err := ipList.IsListed(tc.ip, tc.protocol)
		if tc.expectedErr {
			assert.Error(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		} else {
			assert.NoError(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		}
		assert.Equal(t, tc.expectedMatch, match, "ip %s, protocol %s", tc.ip, tc.protocol)
		assert.Equal(t, tc.expectedMode, mode, "ip %s, protocol %s", tc.ip, tc.protocol)
	}

	ipList.DisableMemoryMode()

	for _, tc := range tests {
		match, mode, err := ipList.IsListed(tc.ip, tc.protocol)
		if tc.expectedErr {
			assert.Error(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		} else {
			assert.NoError(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		}
		assert.Equal(t, tc.expectedMatch, match, "ip %s, protocol %s", tc.ip, tc.protocol)
		assert.Equal(t, tc.expectedMode, mode, "ip %s, protocol %s", tc.ip, tc.protocol)
	}

	for _, e := range entries {
		err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "")
		assert.NoError(t, err)
	}
}

func TestSQLPlaceholderLimits(t *testing.T) {
	numGroups := 120
	numUsers := 120
	var groupMapping []sdk.GroupMapping

	folder := vfs.BaseVirtualFolder{
		Name:       "testfolder",
		MappedPath: filepath.Join(os.TempDir(), "folder"),
	}
	err := dataprovider.AddFolder(&folder, "", "", "")
	assert.NoError(t, err)

	for i := 0; i < numGroups; i++ {
		group := dataprovider.Group{
			BaseGroup: sdk.BaseGroup{
				Name: fmt.Sprintf("testgroup%d", i),
			},
			UserSettings: dataprovider.GroupUserSettings{
				BaseGroupUserSettings: sdk.BaseGroupUserSettings{
					Permissions: map[string][]string{
						fmt.Sprintf("/dir%d", i): {dataprovider.PermAny},
					},
				},
			},
		}
		group.VirtualFolders = append(group.VirtualFolders, vfs.VirtualFolder{
			BaseVirtualFolder: folder,
			VirtualPath:       "/vdir",
		})
		err := dataprovider.AddGroup(&group, "", "", "")
		assert.NoError(t, err)
groupMapping = append(groupMapping, sdk.GroupMapping{ Name: group.Name, Type: sdk.GroupTypeSecondary, }) } user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "testusername", HomeDir: filepath.Join(os.TempDir(), "testhome"), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, Groups: groupMapping, } err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) users, err := dataprovider.GetUsersForQuotaCheck(map[string]bool{user.Username: true}) assert.NoError(t, err) if assert.Len(t, users, 1) { for i := 0; i < numGroups; i++ { _, ok := users[0].Permissions[fmt.Sprintf("/dir%d", i)] assert.True(t, ok) } } err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) for i := 0; i < numUsers; i++ { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: fmt.Sprintf("testusername%d", i), HomeDir: filepath.Join(os.TempDir()), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, Groups: []sdk.GroupMapping{ { Name: "testgroup0", Type: sdk.GroupTypePrimary, }, }, } err := dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) } time.Sleep(100 * time.Millisecond) err = dataprovider.DeleteFolder(folder.Name, "", "", "") assert.NoError(t, err) for i := 0; i < numUsers; i++ { username := fmt.Sprintf("testusername%d", i) user, err := dataprovider.UserExists(username, "") assert.NoError(t, err) assert.Greater(t, user.UpdatedAt, user.CreatedAt) err = dataprovider.DeleteUser(username, "", "", "") assert.NoError(t, err) } for i := 0; i < numGroups; i++ { groupName := fmt.Sprintf("testgroup%d", i) err = dataprovider.DeleteGroup(groupName, "", "", "") assert.NoError(t, err) } } func TestALPNProtocols(t *testing.T) { protocols := util.GetALPNProtocols(nil) assert.Equal(t, []string{"http/1.1", "h2"}, protocols) protocols = util.GetALPNProtocols([]string{"invalid1", "invalid2"}) assert.Equal(t, []string{"http/1.1", "h2"}, protocols) protocols = 
util.GetALPNProtocols([]string{"invalid1", "h2", "invalid2"}) assert.Equal(t, []string{"h2"}, protocols) protocols = util.GetALPNProtocols([]string{"h2", "http/1.1"}) assert.Equal(t, []string{"h2", "http/1.1"}, protocols) } func TestServerVersion(t *testing.T) { appName := "SFTPGo" version.SetConfig("") v := version.GetServerVersion("_", false) assert.Equal(t, fmt.Sprintf("%s_%s", appName, version.Get().Version), v) v = version.GetServerVersion("-", true) assert.Equal(t, fmt.Sprintf("%s-%s-", appName, version.Get().Version), v) version.SetConfig("short") v = version.GetServerVersion("_", false) assert.Equal(t, appName, v) v = version.GetServerVersion("_", true) assert.Equal(t, appName+"_", v) version.SetConfig("") } func BenchmarkBcryptHashing(b *testing.B) { bcryptPassword := "bcryptpassword" for i := 0; i < b.N; i++ { _, err := bcrypt.GenerateFromPassword([]byte(bcryptPassword), 10) if err != nil { panic(err) } } } func BenchmarkCompareBcryptPassword(b *testing.B) { bcryptPassword := "$2a$10$lPDdnDimJZ7d5/GwL6xDuOqoZVRXok6OHHhivCnanWUtcgN0Zafki" for i := 0; i < b.N; i++ { err := bcrypt.CompareHashAndPassword([]byte(bcryptPassword), []byte("password")) if err != nil { panic(err) } } } func BenchmarkArgon2Hashing(b *testing.B) { argonPassword := "argon2password" for i := 0; i < b.N; i++ { _, err := argon2id.CreateHash(argonPassword, argon2id.DefaultParams) if err != nil { panic(err) } } } func BenchmarkCompareArgon2Password(b *testing.B) { argon2Password := "$argon2id$v=19$m=65536,t=1,p=2$aOoAOdAwvzhOgi7wUFjXlw$wn/y37dBWdKHtPXHR03nNaKHWKPXyNuVXOknaU+YZ+s" for i := 0; i < b.N; i++ { _, err := argon2id.ComparePasswordAndHash("password", argon2Password) if err != nil { panic(err) } } } func BenchmarkAddRemoveConnections(b *testing.B) { var conns []ActiveConnection for i := 0; i < 100; i++ { conns = append(conns, &fakeConnection{ BaseConnection: NewBaseConnection(fmt.Sprintf("id%d", i), ProtocolSFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ Username: 
userTestUsername, }, }), }) } b.ResetTimer() for i := 0; i < b.N; i++ { for _, c := range conns { if err := Connections.Add(c); err != nil { panic(err) } } var wg sync.WaitGroup for idx := len(conns) - 1; idx >= 0; idx-- { wg.Add(1) go func(index int) { defer wg.Done() Connections.Remove(conns[index].GetID()) }(idx) } wg.Wait() } } func BenchmarkAddRemoveSSHConnections(b *testing.B) { conn1, conn2 := net.Pipe() var conns []*SSHConnection for i := 0; i < 2000; i++ { conns = append(conns, NewSSHConnection(fmt.Sprintf("id%d", i), conn1)) } b.ResetTimer() for i := 0; i < b.N; i++ { for _, c := range conns { Connections.AddSSHConnection(c) } for idx := len(conns) - 1; idx >= 0; idx-- { Connections.RemoveSSHConnection(conns[idx].GetID()) } } conn1.Close() conn2.Close() } ================================================ FILE: internal/common/connection.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "errors" "fmt" "io" "io/fs" "os" "path" "slices" "strings" "sync" "sync/atomic" "time" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/pkg/sftp" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // BaseConnection defines common fields for a connection using any supported protocol type BaseConnection struct { // last activity for this connection. // Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment lastActivity atomic.Int64 uploadDone atomic.Bool downloadDone atomic.Bool // unique ID for a transfer. // This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment transferID atomic.Int64 // Unique identifier for the connection ID string // user associated with this connection if any User dataprovider.User // start time for this connection startTime time.Time protocol string remoteAddr string localAddr string sync.RWMutex activeTransfers []ActiveTransfer } // NewBaseConnection returns a new BaseConnection func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprovider.User) *BaseConnection { connID := id if slices.Contains(supportedProtocols, protocol) { connID = fmt.Sprintf("%s_%s", protocol, id) } user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID) c := &BaseConnection{ ID: connID, User: user, startTime: time.Now(), protocol: protocol, localAddr: localAddr, remoteAddr: remoteAddr, } c.transferID.Store(0) c.lastActivity.Store(time.Now().UnixNano()) return c } // Log outputs a log entry to the configured logger func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) { logger.Log(level, c.protocol, c.ID, format, v...) 
} // GetTransferID returns an unique transfer ID for this connection func (c *BaseConnection) GetTransferID() int64 { return c.transferID.Add(1) } // GetID returns the connection ID func (c *BaseConnection) GetID() string { return c.ID } // GetUsername returns the authenticated username associated with this connection if any func (c *BaseConnection) GetUsername() string { return c.User.Username } // GetRole returns the role for the user associated with this connection func (c *BaseConnection) GetRole() string { return c.User.Role } // GetMaxSessions returns the maximum number of concurrent sessions allowed func (c *BaseConnection) GetMaxSessions() int { return c.User.MaxSessions } // isAccessAllowed returns true if the user's access conditions are met func (c *BaseConnection) isAccessAllowed() bool { if err := c.User.CheckLoginConditions(); err != nil { return false } return true } // GetProtocol returns the protocol for the connection func (c *BaseConnection) GetProtocol() string { return c.protocol } // GetRemoteIP returns the remote ip address func (c *BaseConnection) GetRemoteIP() string { return util.GetIPFromRemoteAddress(c.remoteAddr) } // SetProtocol sets the protocol for this connection func (c *BaseConnection) SetProtocol(protocol string) { c.protocol = protocol if slices.Contains(supportedProtocols, c.protocol) { c.ID = fmt.Sprintf("%v_%v", c.protocol, c.ID) } } // GetConnectionTime returns the initial connection time func (c *BaseConnection) GetConnectionTime() time.Time { return c.startTime } // UpdateLastActivity updates last activity for this connection func (c *BaseConnection) UpdateLastActivity() { c.lastActivity.Store(time.Now().UnixNano()) } // GetLastActivity returns the last connection activity func (c *BaseConnection) GetLastActivity() time.Time { return time.Unix(0, c.lastActivity.Load()) } // CloseFS closes the underlying fs func (c *BaseConnection) CloseFS() error { return c.User.CloseFs() } // AddTransfer associates a new transfer to this 
connection func (c *BaseConnection) AddTransfer(t ActiveTransfer) { Connections.transfers.add(c.User.Username) c.Lock() defer c.Unlock() c.activeTransfers = append(c.activeTransfers, t) c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers)) if t.HasSizeLimit() { folderName := "" if t.GetType() == TransferUpload { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath())) if err == nil { if !vfolder.IsIncludedInUserQuota() { folderName = vfolder.Name } } } go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{ ID: t.GetID(), Type: t.GetType(), ConnID: c.ID, Username: c.GetUsername(), FolderName: folderName, IP: c.GetRemoteIP(), TruncatedSize: t.GetTruncatedSize(), CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) } } // RemoveTransfer removes the specified transfer from the active ones func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) { Connections.transfers.remove(c.User.Username) c.Lock() defer c.Unlock() if t.HasSizeLimit() { go transfersChecker.RemoveTransfer(t.GetID(), c.ID) } for idx, transfer := range c.activeTransfers { if transfer.GetID() == t.GetID() { lastIdx := len(c.activeTransfers) - 1 c.activeTransfers[idx] = c.activeTransfers[lastIdx] c.activeTransfers[lastIdx] = nil c.activeTransfers = c.activeTransfers[:lastIdx] c.Log(logger.LevelDebug, "transfer removed, id: %v active transfers: %v", t.GetID(), len(c.activeTransfers)) return } } c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID()) } // SignalTransferClose makes the transfer fail on the next read/write with the // specified error func (c *BaseConnection) SignalTransferClose(transferID int64, err error) { c.RLock() defer c.RUnlock() for _, t := range c.activeTransfers { if t.GetID() == transferID { c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID) t.SignalClose(err) } } } // GetTransfers returns the active 
transfers func (c *BaseConnection) GetTransfers() []ConnectionTransfer { c.RLock() defer c.RUnlock() transfers := make([]ConnectionTransfer, 0, len(c.activeTransfers)) for _, t := range c.activeTransfers { var operationType string switch t.GetType() { case TransferDownload: operationType = operationDownload case TransferUpload: operationType = operationUpload } transfers = append(transfers, ConnectionTransfer{ ID: t.GetID(), OperationType: operationType, StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()), Size: t.GetSize(), VirtualPath: t.GetVirtualPath(), HasSizeLimit: t.HasSizeLimit(), ULSize: t.GetUploadedSize(), DLSize: t.GetDownloadedSize(), }) } return transfers } // SignalTransfersAbort signals to the active transfers to exit as soon as possible func (c *BaseConnection) SignalTransfersAbort() error { c.RLock() defer c.RUnlock() if len(c.activeTransfers) == 0 { return errors.New("no active transfer found") } for _, t := range c.activeTransfers { t.SignalClose(ErrTransferAborted) } return nil } func (c *BaseConnection) getRealFsPath(fsPath string) string { c.RLock() defer c.RUnlock() for _, t := range c.activeTransfers { if p := t.GetRealFsPath(fsPath); p != "" { return p } } return fsPath } func (c *BaseConnection) setTimes(fsPath string, atime time.Time, mtime time.Time) bool { c.RLock() defer c.RUnlock() for _, t := range c.activeTransfers { if t.SetTimes(fsPath, atime, mtime) { return true } } return false } // getInfoForOngoingUpload returns upload statistics for an upload currently in // progress on this connection. 
func (c *BaseConnection) getInfoForOngoingUpload(fsPath string) (os.FileInfo, error) { c.RLock() defer c.RUnlock() for _, t := range c.activeTransfers { if t.GetType() == TransferUpload && t.GetFsPath() == fsPath { return vfs.NewFileInfo(t.GetVirtualPath(), false, t.GetSize(), t.GetStartTime(), false), nil } } return nil, os.ErrNotExist } func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, error) { c.RLock() defer c.RUnlock() for _, t := range c.activeTransfers { initialSize, err := t.Truncate(fsPath, size) if err != errTransferMismatch { return initialSize, err } } return 0, errNoTransfer } // ListDir reads the directory matching virtualPath and returns a list of directory entries func (c *BaseConnection) ListDir(virtualPath string) (*DirListerAt, error) { if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) { return nil, c.GetPermissionDeniedError() } fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { return nil, err } lister, err := fs.ReadDir(fsPath) if err != nil { c.Log(logger.LevelDebug, "error listing directory: %+v", err) return nil, c.GetFsError(fs, err) } return &DirListerAt{ virtualPath: virtualPath, conn: c, fs: fs, info: c.User.GetVirtualFoldersInfo(virtualPath), lister: lister, }, nil } // CheckParentDirs tries to create the specified directory and any missing parent dirs func (c *BaseConnection) CheckParentDirs(virtualPath string) error { fs, err := c.User.GetFilesystemForPath(virtualPath, c.GetID()) if err != nil { return err } if fs.HasVirtualFolders() { return nil } if _, err := c.DoStat(virtualPath, 0, false); !c.IsNotExistError(err) { return err } dirs := util.GetDirsForVirtualPath(virtualPath) for idx := len(dirs) - 1; idx >= 0; idx-- { fs, err = c.User.GetFilesystemForPath(dirs[idx], c.GetID()) if err != nil { return err } if fs.HasVirtualFolders() { continue } if err = c.createDirIfMissing(dirs[idx]); err != nil { return fmt.Errorf("unable to check/create missing parent dir %q for virtual 
path %q: %w", dirs[idx], virtualPath, err) } } return nil } // GetCreateChecks returns the checks for creating new files func (c *BaseConnection) GetCreateChecks(virtualPath string, isNewFile bool, isResume bool) int { result := 0 if !isNewFile { if isResume { result += vfs.CheckResume } return result } if !c.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(virtualPath)) { result += vfs.CheckParentDir return result } return result } // CreateDir creates a new directory at the specified fsPath func (c *BaseConnection) CreateDir(virtualPath string, checkFilePatterns bool) error { if !c.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(virtualPath)) { return c.GetPermissionDeniedError() } if checkFilePatterns { if ok, _ := c.User.IsFileAllowed(virtualPath); !ok { return c.GetPermissionDeniedError() } } if c.User.IsVirtualFolder(virtualPath) { c.Log(logger.LevelWarn, "mkdir not allowed %q is a virtual folder", virtualPath) return c.GetPermissionDeniedError() } fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { return err } startTime := time.Now() if err := fs.Mkdir(fsPath); err != nil { c.Log(logger.LevelError, "error creating dir: %q error: %+v", fsPath, err) return c.GetFsError(fs, err) } vfs.SetPathPermissions(fs, fsPath, c.User.GetUID(), c.User.GetGID()) elapsed := time.Since(startTime).Nanoseconds() / 1000000 logger.CommandLog(mkdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1, c.localAddr, c.remoteAddr, elapsed) ExecuteActionNotification(c, operationMkdir, fsPath, virtualPath, "", "", "", 0, nil, elapsed, nil) //nolint:errcheck return nil } // IsRemoveFileAllowed returns an error if removing this file is not allowed func (c *BaseConnection) IsRemoveFileAllowed(virtualPath string) error { if !c.User.HasAnyPerm([]string{dataprovider.PermDeleteFiles, dataprovider.PermDelete}, path.Dir(virtualPath)) { return c.GetPermissionDeniedError() } if ok, policy := c.User.IsFileAllowed(virtualPath); !ok { 
c.Log(logger.LevelDebug, "removing file %q is not allowed", virtualPath) return c.GetErrorForDeniedFile(policy) } return nil } // RemoveFile removes a file at the specified fsPath func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error { if err := c.IsRemoveFileAllowed(virtualPath); err != nil { return err } size := info.Size() status, err := ExecutePreAction(c, operationPreDelete, fsPath, virtualPath, size, 0) if err != nil { c.Log(logger.LevelDebug, "delete for file %q denied by pre action: %v", virtualPath, err) return c.GetPermissionDeniedError() } updateQuota := true startTime := time.Now() if err := fs.Remove(fsPath, false); err != nil { if status > 0 && fs.IsNotExist(err) { // file removed in the pre-action, if the file was deleted from the EventManager the quota is already updated c.Log(logger.LevelDebug, "file deleted from the hook, status: %d", status) updateQuota = (status == 1) } else { c.Log(logger.LevelError, "failed to remove file/symlink %q: %+v", fsPath, err) return c.GetFsError(fs, err) } } elapsed := time.Since(startTime).Nanoseconds() / 1000000 logger.CommandLog(removeLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1, c.localAddr, c.remoteAddr, elapsed) if updateQuota && info.Mode()&os.ModeSymlink == 0 { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) if err == nil { dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, -1, -size, false) } else { dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck } } ExecuteActionNotification(c, operationDelete, fsPath, virtualPath, "", "", "", size, nil, elapsed, nil) //nolint:errcheck return nil } // IsRemoveDirAllowed returns an error if removing this directory is not allowed func (c *BaseConnection) IsRemoveDirAllowed(fs vfs.Fs, fsPath, virtualPath string) error { if virtualPath == "/" || fs.GetRelativePath(fsPath) == "/" { c.Log(logger.LevelWarn, "removing root dir is not allowed") 
return c.GetPermissionDeniedError() } if c.User.IsVirtualFolder(virtualPath) { c.Log(logger.LevelWarn, "removing a virtual folder is not allowed: %q", virtualPath) return fmt.Errorf("removing virtual folders is not allowed: %w", c.GetPermissionDeniedError()) } if c.User.HasVirtualFoldersInside(virtualPath) { c.Log(logger.LevelWarn, "removing a directory with a virtual folder inside is not allowed: %q", virtualPath) return fmt.Errorf("cannot remove directory %q with virtual folders inside: %w", virtualPath, c.GetOpUnsupportedError()) } if c.User.IsMappedPath(fsPath) { c.Log(logger.LevelWarn, "removing a directory mapped as virtual folder is not allowed: %q", fsPath) return fmt.Errorf("removing the directory %q mapped as virtual folder is not allowed: %w", virtualPath, c.GetPermissionDeniedError()) } if !c.User.HasAnyPerm([]string{dataprovider.PermDeleteDirs, dataprovider.PermDelete}, path.Dir(virtualPath)) { return c.GetPermissionDeniedError() } if ok, policy := c.User.IsFileAllowed(virtualPath); !ok { c.Log(logger.LevelDebug, "removing directory %q is not allowed", virtualPath) return c.GetErrorForDeniedFile(policy) } return nil } // RemoveDir removes a directory at the specified fsPath func (c *BaseConnection) RemoveDir(virtualPath string) error { fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { return err } if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil { return err } var fi os.FileInfo if fi, err = fs.Lstat(fsPath); err != nil { // see #149 if fs.IsNotExist(err) && fs.HasVirtualFolders() { return nil } c.Log(logger.LevelError, "failed to remove a dir %q: stat error: %+v", fsPath, err) return c.GetFsError(fs, err) } if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 { c.Log(logger.LevelError, "cannot remove %q is not a directory", fsPath) return c.GetGenericError(nil) } startTime := time.Now() if err := fs.Remove(fsPath, true); err != nil { c.Log(logger.LevelError, "failed to remove directory %q: %+v", fsPath, err) return 
c.GetFsError(fs, err) } elapsed := time.Since(startTime).Nanoseconds() / 1000000 logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1, c.localAddr, c.remoteAddr, elapsed) ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil, elapsed, nil) //nolint:errcheck return nil } func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo, recursion int) error { fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { return err } return c.doRecursiveRemove(fs, fsPath, virtualPath, info, recursion) } func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo, recursion int) error { if info.IsDir() { if recursion >= util.MaxRecursion { c.Log(logger.LevelError, "recursive rename failed, recursion too depth: %d", recursion) return util.ErrRecursionTooDeep } recursion++ lister, err := c.ListDir(virtualPath) if err != nil { return fmt.Errorf("unable to get lister for dir %q: %w", virtualPath, err) } defer lister.Close() for { entries, err := lister.Next(vfs.ListerBatchSize) finished := errors.Is(err, io.EOF) if err != nil && !finished { return fmt.Errorf("unable to get content for dir %q: %w", virtualPath, err) } for _, fi := range entries { targetPath := path.Join(virtualPath, fi.Name()) if err := c.doRecursiveRemoveDirEntry(targetPath, fi, recursion); err != nil { return err } } if finished { lister.Close() break } } return c.RemoveDir(virtualPath) } return c.RemoveFile(fs, fsPath, virtualPath, info) } // RemoveAll removes the specified path and any children it contains func (c *BaseConnection) RemoveAll(virtualPath string) error { fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { return err } fi, err := fs.Lstat(fsPath) if err != nil { c.Log(logger.LevelDebug, "failed to remove path %q: stat error: %+v", fsPath, err) return c.GetFsError(fs, err) } if fi.IsDir() && fi.Mode()&os.ModeSymlink == 
0 { if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil { return err } return c.doRecursiveRemove(fs, fsPath, virtualPath, fi, 0) } return c.RemoveFile(fs, fsPath, virtualPath, fi) } func (c *BaseConnection) checkCopy(srcInfo, dstInfo os.FileInfo, virtualSource, virtualTarget string) error { _, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSource) if err != nil { return err } _, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTarget) if err != nil { return err } if srcInfo.IsDir() { if dstInfo != nil && !dstInfo.IsDir() { return fmt.Errorf("cannot overwrite file %q with dir %q: %w", virtualTarget, virtualSource, c.GetOpUnsupportedError()) } if util.IsDirOverlapped(virtualSource, virtualTarget, true, "/") { return fmt.Errorf("nested copy %q => %q is not supported: %w", virtualSource, virtualTarget, c.GetOpUnsupportedError()) } if util.IsDirOverlapped(fsSourcePath, fsTargetPath, true, c.User.FsConfig.GetPathSeparator()) { c.Log(logger.LevelWarn, "nested fs copy %q => %q not allowed", fsSourcePath, fsTargetPath) return fmt.Errorf("nested fs copy is not supported: %w", c.GetOpUnsupportedError()) } return nil } if dstInfo != nil && dstInfo.IsDir() { return fmt.Errorf("cannot overwrite file %q with dir %q: %w", virtualSource, virtualTarget, c.GetOpUnsupportedError()) } if c.IsSameResource(virtualSource, virtualTarget) { if fsSourcePath == fsTargetPath { return fmt.Errorf("the copy source and target cannot be the same: %w", c.GetOpUnsupportedError()) } } return nil } func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo) error { if !c.User.HasPerm(dataprovider.PermCopy, virtualSourcePath) || !c.User.HasPerm(dataprovider.PermCopy, virtualTargetPath) { return c.GetPermissionDeniedError() } if ok, _ := c.User.IsFileAllowed(virtualTargetPath); !ok { return fmt.Errorf("file %q is not allowed: %w", virtualTargetPath, c.GetPermissionDeniedError()) } if c.IsSameResource(virtualSourcePath, virtualTargetPath) { 
fs, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath)
		if err != nil {
			return err
		}
		if copier, ok := fs.(vfs.FsFileCopier); ok {
			_, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath)
			if err != nil {
				return err
			}
			startTime := time.Now()
			numFiles, sizeDiff, err := copier.CopyFile(fsSourcePath, fsTargetPath, srcInfo)
			elapsed := time.Since(startTime).Nanoseconds() / 1000000
			updateUserQuotaAfterFileWrite(c, virtualTargetPath, numFiles, sizeDiff)
			logger.CommandLog(copyLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1,
				"", "", "", srcInfo.Size(), c.localAddr, c.remoteAddr, elapsed)
			ExecuteActionNotification(c, operationCopy, fsSourcePath, virtualSourcePath, fsTargetPath, virtualTargetPath, "", srcInfo.Size(), err, elapsed, nil) //nolint:errcheck
			return err
		}
	}

	reader, rCancelFn, err := getFileReader(c, virtualSourcePath)
	if err != nil {
		return fmt.Errorf("unable to get reader for path %q: %w", virtualSourcePath, err)
	}
	defer rCancelFn()
	defer reader.Close()

	writer, numFiles, truncatedSize, wCancelFn, err := getFileWriter(c, virtualTargetPath, srcInfo.Size())
	if err != nil {
		return fmt.Errorf("unable to get writer for path %q: %w", virtualTargetPath, err)
	}
	defer wCancelFn()

	startTime := time.Now()
	_, err = io.Copy(writer, reader)
	return closeWriterAndUpdateQuota(writer, c, virtualSourcePath, virtualTargetPath, numFiles, truncatedSize, err,
		operationCopy, startTime)
}

// doRecursiveCopy copies virtualSourcePath to virtualTargetPath.
//
// Directories are created in the target (only if createTargetDir is true) and their
// contents are listed in batches of vfs.ListerBatchSize entries and copied via
// recursiveCopyEntries. The nesting depth is tracked by recursion and bounded by
// util.MaxRecursion. Entries that are neither directories nor regular files are
// skipped with an info log. Regular files are copied by copyFile.
func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo,
	createTargetDir bool, recursion int,
) error {
	if srcInfo.IsDir() {
		if recursion >= util.MaxRecursion {
			c.Log(logger.LevelError, "recursive copy failed, recursion too deep: %d", recursion)
			return util.ErrRecursionTooDeep
		}
		recursion++
		if createTargetDir {
			if err := c.CreateDir(virtualTargetPath, false); err != nil {
				return fmt.Errorf("unable to create directory %q: %w", virtualTargetPath, err)
			}
		}
		lister, err := c.ListDir(virtualSourcePath)
		if err != nil {
			return fmt.Errorf("unable to get lister for dir %q: %w", virtualSourcePath, err)
		}
		defer lister.Close()

		for {
			entries, err := lister.Next(vfs.ListerBatchSize)
			// io.EOF marks the last batch: its entries must still be processed
			finished := errors.Is(err, io.EOF)
			if err != nil && !finished {
				return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err)
			}
			if err := c.recursiveCopyEntries(virtualSourcePath, virtualTargetPath, entries, recursion); err != nil {
				return err
			}
			if finished {
				return nil
			}
		}
	}
	if !srcInfo.Mode().IsRegular() {
		c.Log(logger.LevelInfo, "skipping copy for non regular file %q", virtualSourcePath)
		return nil
	}

	return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo)
}

// recursiveCopyEntries copies each entry of a listed source directory into
// virtualTargetPath using doRecursiveCopy, always asking it to create target dirs.
//
// An entry is skipped when both the source entry and its existing target are
// directories. NOTE(review): such a directory is not descended into, so its
// contents are not merged into the existing target dir — confirm this is intended.
// An entry is also skipped (and logged) when doRecursiveCopy returns a not-exist
// error for it, presumably because it disappeared while copying.
func (c *BaseConnection) recursiveCopyEntries(virtualSourcePath, virtualTargetPath string, entries []os.FileInfo, recursion int) error {
	for _, info := range entries {
		sourcePath := path.Join(virtualSourcePath, info.Name())
		targetPath := path.Join(virtualTargetPath, info.Name())
		targetInfo, err := c.DoStat(targetPath, 1, false)
		if err == nil {
			if info.IsDir() && targetInfo.IsDir() {
				c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath)
				continue
			}
		}
		if err != nil && !c.IsNotExistError(err) {
			return err
		}
		if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil {
			return err
		}
		if err := c.doRecursiveCopy(sourcePath, targetPath, info, true, recursion); err != nil {
			if c.IsNotExistError(err) {
				c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err)
				continue
			}
			return err
		}
	}
	return nil
}

// Copy virtualSourcePath to virtualTargetPath.
//
// Both paths are cleaned before use and must differ. Symlinks cannot be copied.
// The final destination is virtualTargetPath itself, or virtualTargetPath joined
// with the base name of the source ("copy inside the target") when either:
//   - virtualTargetPath ends with "/", or
//   - the source does not end with "/" and the target exists as a directory
//     (this overrides a trailing "/" on the target).
//
// The top-level target directory is created only if the destination does not
// already exist as a directory.
func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error {
	copyFromSource := strings.HasSuffix(virtualSourcePath, "/")
	copyInTarget := strings.HasSuffix(virtualTargetPath, "/")
	virtualSourcePath = path.Clean(virtualSourcePath)
	virtualTargetPath = path.Clean(virtualTargetPath)
	if virtualSourcePath == virtualTargetPath {
		return fmt.Errorf("the copy source and target cannot be the same: %w", c.GetOpUnsupportedError())
	}
	srcInfo, err := c.DoStat(virtualSourcePath, 1, false)
	if err != nil {
		return err
	}
	if srcInfo.Mode()&os.ModeSymlink != 0 {
		return fmt.Errorf("copying symlinks is not supported: %w", c.GetOpUnsupportedError())
	}
	dstInfo, err := c.DoStat(virtualTargetPath, 1, false)
	if err == nil && !copyFromSource {
		copyInTarget = dstInfo.IsDir()
	}
	if err != nil && !c.IsNotExistError(err) {
		return err
	}
	destPath := virtualTargetPath
	if copyInTarget {
		destPath = path.Join(virtualTargetPath, path.Base(virtualSourcePath))
		dstInfo, err = c.DoStat(destPath, 1, false)
		if err != nil && !c.IsNotExistError(err) {
			return err
		}
	}
	createTargetDir := dstInfo == nil || !dstInfo.IsDir()
	if err := c.checkCopy(srcInfo, dstInfo, virtualSourcePath, destPath); err != nil {
		return err
	}
	if err := c.CheckParentDirs(path.Dir(destPath)); err != nil {
		return err
	}
	stopKeepAlive := keepConnectionAlive(c, 2*time.Minute)
	defer stopKeepAlive()

	return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir, 0)
}

// Rename renames (moves) virtualSourcePath to virtualTargetPath
func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) error {
	return c.renameInternal(virtualSourcePath, virtualTargetPath, false, vfs.CheckParentDir)
}

// renameInternal performs the permission, overwrite and quota checks for a rename
// and then delegates the actual move to the destination fs Rename.
//
// checkParentDestination requests a CheckParentDirs call for the target parent
// dir; it is ignored when the target already exists. checks is forwarded
// unchanged to the destination fs Rename. Overwriting an existing directory is
// not supported; overwriting a file or symlink requires the overwrite permission.
// After a successful rename, quota, command log and action notification are updated.
func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath string, //nolint:gocyclo
	checkParentDestination bool, checks int,
) error {
	if virtualSourcePath == virtualTargetPath {
		return fmt.Errorf("the rename source and target cannot be the same: %w", c.GetOpUnsupportedError())
	}
	fsSrc, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath)
	if err != nil {
		return err
	}
	fsDst, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath)
	if err != nil {
		return err
	}
	startTime := time.Now()
	srcInfo, err := fsSrc.Lstat(fsSourcePath)
	if err != nil {
		return c.GetFsError(fsSrc, err)
	}
	if !c.isRenamePermitted(fsSrc, fsDst, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath, srcInfo) {
		return c.GetPermissionDeniedError()
	}
	// initialSize stays -1 unless we are overwriting an existing regular file
	initialSize := int64(-1)
	dstInfo, err := fsDst.Lstat(fsTargetPath)
	if err != nil && !fsDst.IsNotExist(err) {
		// convert to a protocol error, consistently with the source Lstat above
		return c.GetFsError(fsDst, err)
	}
	if err == nil {
		checkParentDestination = false
		if dstInfo.IsDir() {
			c.Log(logger.LevelWarn, "attempted to rename %q overwriting an existing directory %q",
				fsSourcePath, fsTargetPath)
			return c.GetOpUnsupportedError()
		}
		// we are overwriting an existing file/symlink
		if dstInfo.Mode().IsRegular() {
			initialSize = dstInfo.Size()
		}
		if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualTargetPath)) {
			c.Log(logger.LevelDebug, "renaming %q -> %q is not allowed. Target exists but the user %q "+
				"has no overwrite permission", virtualSourcePath, virtualTargetPath, c.User.Username)
			return c.GetPermissionDeniedError()
		}
	}
	if srcInfo.IsDir() {
		if err := c.checkFolderRename(fsSrc, fsDst, fsSourcePath, fsTargetPath, virtualSourcePath,
			virtualTargetPath, srcInfo); err != nil {
			return err
		}
	}
	if !c.hasSpaceForRename(fsSrc, virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath, srcInfo) {
		c.Log(logger.LevelInfo, "denying cross rename due to space limit")
		return c.GetGenericError(ErrQuotaExceeded)
	}
	if checkParentDestination {
		c.CheckParentDirs(path.Dir(virtualTargetPath)) //nolint:errcheck
	}
	stopKeepAlive := keepConnectionAlive(c, 2*time.Minute)
	defer stopKeepAlive()

	files, size, err := fsDst.Rename(fsSourcePath, fsTargetPath, checks)
	if err != nil {
		c.Log(logger.LevelError, "failed to rename %q -> %q: %+v", fsSourcePath, fsTargetPath, err)
		return c.GetFsError(fsSrc, err)
	}
	vfs.SetPathPermissions(fsDst, fsTargetPath, c.User.GetUID(), c.User.GetGID())
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	c.updateQuotaAfterRename(fsDst, virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize, files, size) //nolint:errcheck
	logger.CommandLog(renameLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1,
		"", "", "", -1, c.localAddr, c.remoteAddr, elapsed)
	ExecuteActionNotification(c, operationRename, fsSourcePath, virtualSourcePath, fsTargetPath, //nolint:errcheck
		virtualTargetPath, "", 0, nil, elapsed, nil)

	return nil
}

// CreateSymlink creates fsTargetPath as a symbolic link to fsSourcePath.
//
// A relative virtualSourcePath is resolved against the directory of
// virtualTargetPath for all checks, but the link itself is created with the
// original relative value. Cross folder links and links from or to the root
// dir are refused.
func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath string) error {
	var relativePath string
	if !path.IsAbs(virtualSourcePath) {
		relativePath = virtualSourcePath
		virtualSourcePath = path.Join(path.Dir(virtualTargetPath), relativePath)
		c.Log(logger.LevelDebug, "link relative path %q resolved as %q, target path %q",
			relativePath, virtualSourcePath, virtualTargetPath)
	}
	if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) {
		c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath)
		return c.GetOpUnsupportedError()
	}
	// we cannot have a cross folder request here so only one fs is enough
	fs, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath)
	if err != nil {
		return err
	}
	fsTargetPath, err := fs.ResolvePath(virtualTargetPath)
	if err != nil {
		return c.GetFsError(fs, err)
	}
	if fs.GetRelativePath(fsSourcePath) == "/" {
		c.Log(logger.LevelError, "symlinking root dir is not allowed")
		return c.GetPermissionDeniedError()
	}
	if fs.GetRelativePath(fsTargetPath) == "/" {
		c.Log(logger.LevelError, "symlinking to root dir is not allowed")
		return c.GetPermissionDeniedError()
	}
	if !c.User.HasPerm(dataprovider.PermCreateSymlinks, path.Dir(virtualTargetPath)) {
		return c.GetPermissionDeniedError()
	}
	ok, policy := c.User.IsFileAllowed(virtualSourcePath)
	if !ok && policy == sdk.DenyPolicyHide {
		c.Log(logger.LevelError, "symlink source path %q is not allowed", virtualSourcePath)
		return c.GetNotExistError()
	}
	if ok, _ = c.User.IsFileAllowed(virtualTargetPath); !ok {
		c.Log(logger.LevelError, "symlink target path %q is not allowed", virtualTargetPath)
		return c.GetPermissionDeniedError()
	}
	if relativePath != "" {
		fsSourcePath = relativePath
	}
	startTime := time.Now()
	if err := fs.Symlink(fsSourcePath, fsTargetPath); err != nil {
		c.Log(logger.LevelError, "failed to create symlink %q -> %q: %+v", fsSourcePath, fsTargetPath, err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	logger.CommandLog(symlinkLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, "",
		"", "", -1, c.localAddr, c.remoteAddr, elapsed)
	return nil
}

// doStatInternal runs Lstat when mode is 1 and Stat otherwise.
//
// Virtual folders always stat as directories. If checkFilePatterns is set, files
// hidden by the user's patterns are reported as not existing. If the path does
// not exist on the fs, an ongoing upload on this connection is used as fallback.
// With convertResult, info returned by a CryptFs is converted via ConvertFileInfo.
func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFilePatterns,
	convertResult bool,
) (os.FileInfo, error) {
	// for some vfs we don't create intermediary folders so we cannot simply check
	// if virtualPath is a virtual folder. Allowing stat for hidden virtual folders
	// is by purpose.
	vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
	if _, ok := vfolders[virtualPath]; ok {
		return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil
	}
	if checkFilePatterns && virtualPath != "/" {
		ok, policy := c.User.IsFileAllowed(virtualPath)
		if !ok && policy == sdk.DenyPolicyHide {
			return nil, c.GetNotExistError()
		}
	}

	var info os.FileInfo

	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return nil, err
	}

	if mode == 1 {
		info, err = fs.Lstat(c.getRealFsPath(fsPath))
	} else {
		info, err = fs.Stat(c.getRealFsPath(fsPath))
	}
	if err != nil {
		isNotExist := fs.IsNotExist(err)
		if isNotExist {
			// This is primarily useful for atomic storage backends, where files
			// become visible only after they are closed. However, since we may
			// be proxying (for example) an SFTP server backed by atomic
			// storage, and this search only inspects transfers active on the
			// current connection (typically just one), the check is inexpensive
			// and safe to perform unconditionally.
			if info, err := c.getInfoForOngoingUpload(fsPath); err == nil {
				return info, nil
			}
		}
		if !isNotExist {
			c.Log(logger.LevelWarn, "stat error for path %q: %+v", virtualPath, err)
		}
		return nil, c.GetFsError(fs, err)
	}
	if convertResult && vfs.IsCryptOsFs(fs) {
		info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
	}
	return info, nil
}

// DoStat execute a Stat if mode = 0, Lstat if mode = 1
func (c *BaseConnection) DoStat(virtualPath string, mode int, checkFilePatterns bool) (os.FileInfo, error) {
	return c.doStatInternal(virtualPath, mode, checkFilePatterns, true)
}

// createDirIfMissing creates the directory name if a Stat reports it as not
// existing; any other stat error is returned as is.
func (c *BaseConnection) createDirIfMissing(name string) error {
	_, err := c.DoStat(name, 0, false)
	if c.IsNotExistError(err) {
		return c.CreateDir(name, false)
	}
	return err
}

// ignoreSetStat reports whether chmod/chown must be silently skipped for fs:
// always when Config.SetstatMode is 1, and with SetstatMode 2 for any fs that is
// neither local, SFTP nor CryptFs.
func (c *BaseConnection) ignoreSetStat(fs vfs.Fs) bool {
	if Config.SetstatMode == 1 {
		return true
	}
	if Config.SetstatMode == 2 && !vfs.IsLocalOrSFTPFs(fs) && !vfs.IsCryptOsFs(fs) {
		return true
	}
	return false
}

// handleChmod checks the chmod permission on pathForPerms and applies
// attributes.Mode to fsPath, unless ignoreSetStat says to skip it.
func (c *BaseConnection) handleChmod(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
	if !c.User.HasPerm(dataprovider.PermChmod, pathForPerms) {
		return c.GetPermissionDeniedError()
	}
	if c.ignoreSetStat(fs) {
		return nil
	}
	startTime := time.Now()
	if err := fs.Chmod(c.getRealFsPath(fsPath), attributes.Mode); err != nil {
		c.Log(logger.LevelError, "failed to chmod path %q, mode: %v, err: %+v", fsPath, attributes.Mode.String(), err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	logger.CommandLog(chmodLogSender, fsPath, "", c.User.Username, attributes.Mode.String(), c.ID, c.protocol,
		-1, -1, "", "", "", -1, c.localAddr, c.remoteAddr, elapsed)
	return nil
}

// handleChown checks the chown permission on pathForPerms and applies
// attributes.UID/GID to fsPath, unless ignoreSetStat says to skip it.
func (c *BaseConnection) handleChown(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
	if !c.User.HasPerm(dataprovider.PermChown, pathForPerms) {
		return c.GetPermissionDeniedError()
	}
	if c.ignoreSetStat(fs) {
		return nil
	}
	startTime := time.Now()
	if err := fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil {
		c.Log(logger.LevelError, "failed to chown path %q, uid: %v, gid: %v, err: %+v", fsPath, attributes.UID,
			attributes.GID, err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	logger.CommandLog(chownLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, attributes.UID, attributes.GID,
		"", "", "", -1, c.localAddr, c.remoteAddr, elapsed)
	return nil
}

// handleChtimes checks the chtimes permission on pathForPerms and sets the
// access/modification times of fsPath. It is skipped when Config.SetstatMode is 1;
// with SetstatMode 2 an vfs.ErrVfsUnsupported failure is ignored. The times are
// also recorded on a matching ongoing transfer (via setTimes) and cleared again if
// the fs call fails.
func (c *BaseConnection) handleChtimes(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
	if !c.User.HasPerm(dataprovider.PermChtimes, pathForPerms) {
		return c.GetPermissionDeniedError()
	}
	if Config.SetstatMode == 1 {
		return nil
	}
	startTime := time.Now()
	isUploading := c.setTimes(fsPath, attributes.Atime, attributes.Mtime)
	if err := fs.Chtimes(c.getRealFsPath(fsPath), attributes.Atime, attributes.Mtime, isUploading); err != nil {
		c.setTimes(fsPath, time.Time{}, time.Time{})
		if errors.Is(err, vfs.ErrVfsUnsupported) && Config.SetstatMode == 2 {
			return nil
		}
		c.Log(logger.LevelError, "failed to chtimes for path %q, access time: %v, modification time: %v, err: %+v",
			fsPath, attributes.Atime, attributes.Mtime, err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	accessTimeString := attributes.Atime.Format(chtimesFormat)
	modificationTimeString := attributes.Mtime.Format(chtimesFormat)
	logger.CommandLog(chtimesLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1,
		accessTimeString, modificationTimeString, "", -1, c.localAddr, c.remoteAddr, elapsed)
	return nil
}

// SetStat set StatAttributes for the specified fsPath.
//
// The attributes selected by attributes.Flags are applied in this order: times,
// permissions, owner, size (truncate, which requires the overwrite permission).
// Processing stops at the first error.
func (c *BaseConnection) SetStat(virtualPath string, attributes *StatAttributes) error {
	if ok, policy := c.User.IsFileAllowed(virtualPath); !ok {
		return c.GetErrorForDeniedFile(policy)
	}
	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return err
	}
	pathForPerms := path.Dir(virtualPath)

	if attributes.Flags&StatAttrTimes != 0 {
		if err = c.handleChtimes(fs, fsPath, pathForPerms, attributes); err != nil {
			return err
		}
	}

	if attributes.Flags&StatAttrPerms != 0 {
		if err = c.handleChmod(fs, fsPath, pathForPerms, attributes); err != nil {
			return err
		}
	}

	if attributes.Flags&StatAttrUIDGID != 0 {
		if err = c.handleChown(fs, fsPath, pathForPerms, attributes); err != nil {
			return err
		}
	}

	if attributes.Flags&StatAttrSize != 0 {
		if !c.User.HasPerm(dataprovider.PermOverwrite, pathForPerms) {
			return c.GetPermissionDeniedError()
		}

		startTime := time.Now()
		if err = c.truncateFile(fs, fsPath, virtualPath, attributes.Size); err != nil {
			c.Log(logger.LevelError, "failed to truncate path %q, size: %v, err: %+v", fsPath, attributes.Size, err)
			return c.GetFsError(fs, err)
		}
		elapsed := time.Since(startTime).Nanoseconds() / 1000000
		logger.CommandLog(truncateLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "",
			"", attributes.Size, c.localAddr, c.remoteAddr, elapsed)
	}

	return nil
}

// truncateFile truncates fsPath to size and, if the fs supports truncate, updates
// the quota (virtual folder quota or user quota) by the size difference.
func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, size int64) error {
	// check first if we have an open transfer for the given path and try to truncate the file already opened
	// if we found no transfer we truncate by path.
	var initialSize int64
	var err error
	initialSize, err = c.truncateOpenHandle(fsPath, size)
	if err == errNoTransfer {
		c.Log(logger.LevelDebug, "file path %q not found in active transfers, execute truncate by path", fsPath)
		var info os.FileInfo
		info, err = fs.Stat(fsPath)
		if err != nil {
			return err
		}
		initialSize = info.Size()
		err = fs.Truncate(fsPath, size)
	}
	if err == nil && vfs.HasTruncateSupport(fs) {
		sizeDiff := initialSize - size
		// this err intentionally shadows the outer one: the function returns the
		// (nil) truncate error, not the folder lookup result
		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
		if err == nil {
			dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -sizeDiff, false)
		} else {
			dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
		}
	}
	return err
}

// checkRecursiveRenameDirPermissions verifies that a directory rename is allowed.
//
// When neither path has per-subdir permissions and the user holds all rename
// permissions on both parent dirs, checking the directory itself is enough.
// Otherwise the source tree is walked and every entry is checked. If the fs rename
// is not atomic and Config.RenameMode is 0, non empty directories are refused.
func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath,
	virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo,
) error {
	if !c.User.HasPermissionsInside(virtualSourcePath) &&
		!c.User.HasPermissionsInside(virtualTargetPath) {
		if !c.isRenamePermitted(fsSrc, fsDst, sourcePath, targetPath, virtualSourcePath, virtualTargetPath, srcInfo) {
			c.Log(logger.LevelInfo, "rename %q -> %q is not allowed, virtual destination path: %q",
				sourcePath, targetPath, virtualTargetPath)
			return c.GetPermissionDeniedError()
		}
		// if all rename permissions are granted we have finished, otherwise we have to walk
		// because we could have the rename dir permission but not the rename file and the dir to
		// rename could contain files
		if c.User.HasPermsRenameAll(path.Dir(virtualSourcePath)) && c.User.HasPermsRenameAll(path.Dir(virtualTargetPath)) {
			return nil
		}
	}

	return fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
		if err != nil {
			return c.GetFsError(fsSrc, err)
		}
		if walkedPath != sourcePath && !vfs.IsRenameAtomic(fsSrc) && Config.RenameMode == 0 {
			c.Log(logger.LevelInfo, "cannot rename non empty directory %q on this filesystem", virtualSourcePath)
			return c.GetOpUnsupportedError()
		}
		dstPath := strings.Replace(walkedPath, sourcePath, targetPath, 1)
		virtualSrcPath := fsSrc.GetRelativePath(walkedPath)
		virtualDstPath := fsDst.GetRelativePath(dstPath)
		if !c.isRenamePermitted(fsSrc, fsDst, walkedPath, dstPath, virtualSrcPath, virtualDstPath, info) {
			c.Log(logger.LevelInfo, "rename %q -> %q is not allowed, virtual destination path: %q",
				walkedPath, dstPath, virtualDstPath)
			return c.GetPermissionDeniedError()
		}
		return nil
	})
}

// hasRenamePerms reports whether the user can rename virtualSourcePath to
// virtualTargetPath: either all rename permissions on both parent dirs, or the
// dir/file specific rename permission (or the generic one) on both parent dirs.
// A nil fi is only accepted in the "all permissions" case.
func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
	if c.User.HasPermsRenameAll(path.Dir(virtualSourcePath)) &&
		c.User.HasPermsRenameAll(path.Dir(virtualTargetPath)) {
		return true
	}
	if fi == nil {
		// we don't know if this is a file or a directory and we don't have all the rename perms, return false
		return false
	}
	if fi.IsDir() {
		perms := []string{
			dataprovider.PermRenameDirs,
			dataprovider.PermRename,
		}
		return c.User.HasAnyPerm(perms, path.Dir(virtualSourcePath)) &&
			c.User.HasAnyPerm(perms, path.Dir(virtualTargetPath))
	}
	// file or symlink
	perms := []string{
		dataprovider.PermRenameFiles,
		dataprovider.PermRename,
	}
	return c.User.HasAnyPerm(perms, path.Dir(virtualSourcePath)) &&
		c.User.HasAnyPerm(perms, path.Dir(virtualTargetPath))
}

// checkFolderRename refuses directory renames whose virtual or fs paths overlap,
// or whose source or target contains virtual folders. It then checks the
// recursive rename permissions.
func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
	virtualTargetPath string, srcInfo os.FileInfo) error {
	if util.IsDirOverlapped(virtualSourcePath, virtualTargetPath, true, "/") {
		c.Log(logger.LevelDebug, "renaming the folder %q->%q is not supported: nested folders",
			virtualSourcePath, virtualTargetPath)
		return fmt.Errorf("nested rename %q => %q is not supported: %w",
			virtualSourcePath, virtualTargetPath, c.GetOpUnsupportedError())
	}
	if util.IsDirOverlapped(fsSourcePath, fsTargetPath, true, c.User.FsConfig.GetPathSeparator()) {
		c.Log(logger.LevelDebug, "renaming the folder %q->%q is not supported: nested fs folders",
			fsSourcePath, fsTargetPath)
		return fmt.Errorf("nested fs rename %q => %q is not supported: %w",
			fsSourcePath, fsTargetPath, c.GetOpUnsupportedError())
	}
	if c.User.HasVirtualFoldersInside(virtualSourcePath) {
		c.Log(logger.LevelDebug, "renaming the folder %q is not supported: it has virtual folders inside it",
			virtualSourcePath)
		return fmt.Errorf("folder %q has virtual folders inside it: %w", virtualSourcePath, c.GetOpUnsupportedError())
	}
	if c.User.HasVirtualFoldersInside(virtualTargetPath) {
		c.Log(logger.LevelDebug, "renaming the folder %q is not supported, the target %q has virtual folders inside it",
			virtualSourcePath, virtualTargetPath)
		return fmt.Errorf("folder %q has virtual folders inside it: %w", virtualTargetPath, c.GetOpUnsupportedError())
	}
	if err := c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath,
		virtualSourcePath, virtualTargetPath, srcInfo); err != nil {
		c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %q: %+v", fsSourcePath, err)
		return err
	}
	return nil
}

// isRenamePermitted reports whether a single rename is allowed. The paths must
// be on the same resource. Mapped dirs of local/crypt fs, the root dir and
// virtual folders cannot be renamed. Both paths must pass the file patterns, and
// the user needs the rename permissions checked by hasRenamePerms.
func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
	virtualTargetPath string, srcInfo os.FileInfo,
) bool {
	if !c.IsSameResource(virtualSourcePath, virtualTargetPath) {
		c.Log(logger.LevelInfo, "rename %q->%q is not allowed: the paths must be on the same resource",
			virtualSourcePath, virtualTargetPath)
		return false
	}
	if c.User.IsMappedPath(fsSourcePath) && vfs.IsLocalOrCryptoFs(fsSrc) {
		c.Log(logger.LevelWarn, "renaming a directory mapped as virtual folder is not allowed: %q", fsSourcePath)
		return false
	}
	if c.User.IsMappedPath(fsTargetPath) && vfs.IsLocalOrCryptoFs(fsDst) {
		c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %q", fsTargetPath)
		return false
	}
	if virtualSourcePath == "/" || virtualTargetPath == "/" || fsSrc.GetRelativePath(fsSourcePath) == "/" {
		c.Log(logger.LevelWarn, "renaming root dir is not allowed")
		return false
	}
	if c.User.IsVirtualFolder(virtualSourcePath) || c.User.IsVirtualFolder(virtualTargetPath) {
		c.Log(logger.LevelWarn, "renaming a virtual folder is not allowed")
		return false
	}
	isSrcAllowed, _ := c.User.IsFileAllowed(virtualSourcePath)
	isDstAllowed, _ := c.User.IsFileAllowed(virtualTargetPath)
	if !isSrcAllowed || !isDstAllowed {
		c.Log(logger.LevelDebug, "renaming source: %q to target: %q not allowed", virtualSourcePath,
			virtualTargetPath)
		return false
	}
	return c.hasRenamePerms(virtualSourcePath, virtualTargetPath, srcInfo)
}

// hasSpaceForRename reports whether the target quota can accept the renamed
// content. Quota is only checked when the rename moves data between different
// quota scopes; otherwise (or when quota tracking is disabled) it returns true.
func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath string, initialSize int64,
	sourcePath string, srcInfo os.FileInfo) bool {
	if dataprovider.GetQuotaTracking() == 0 {
		return true
	}
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(path.Dir(virtualSourcePath))
	dstFolder, errDst := c.User.GetVirtualFolderForPath(path.Dir(virtualTargetPath))
	if errSrc != nil && errDst != nil {
		// rename inside the user home dir
		return true
	}
	if errSrc == nil && errDst == nil {
		// rename between virtual folders
		if sourceFolder.Name == dstFolder.Name {
			// rename inside the same virtual folder
			return true
		}
	}
	if errSrc != nil && dstFolder.IsIncludedInUserQuota() {
		// rename between user root dir and a virtual folder included in user quota
		return true
	}
	if errDst != nil && sourceFolder.IsIncludedInUserQuota() {
		// rename between a virtual folder included in user quota and the user root dir
		return true
	}
	quotaResult, _ := c.HasSpace(true, false, virtualTargetPath)
	if quotaResult.HasSpace && quotaResult.QuotaSize == 0 && quotaResult.QuotaFiles == 0 {
		// no quota restrictions
		return true
	}
	return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, sourcePath, srcInfo)
}

// hasSpaceForCrossRename checks the quota after a rename between different folders.
//
// initialSize is -1 unless an existing file is being overwritten, in which case
// the overwritten size is deducted. For directories the size and number of files
// are computed with GetDirSize on the source.
func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.QuotaCheckResult,
	initialSize int64, sourcePath string, srcInfo os.FileInfo,
) bool {
	if !quotaResult.HasSpace && initialSize == -1 {
		// we are over quota and this is not a file replace
		return false
	}
	var sizeDiff int64
	var filesDiff int
	var err error
	if srcInfo.Mode().IsRegular() {
		sizeDiff = srcInfo.Size()
		filesDiff = 1
		if initialSize != -1 {
			sizeDiff -= initialSize
			filesDiff = 0
		}
	} else if srcInfo.IsDir() {
		filesDiff, sizeDiff, err = fs.GetDirSize(sourcePath)
		if err != nil {
			c.Log(logger.LevelError, "cross rename denied, error getting size for directory %q: %v", sourcePath, err)
			return false
		}
	}
	if !quotaResult.HasSpace && initialSize != -1 {
		// we are over quota but we are overwriting an existing file so we check if the quota size after the rename is ok
		if quotaResult.QuotaSize == 0 {
			return true
		}
		c.Log(logger.LevelDebug, "cross rename overwrite, source %q, used size %d, size to add %d",
			sourcePath, quotaResult.UsedSize, sizeDiff)
		quotaResult.UsedSize += sizeDiff
		return quotaResult.GetRemainingSize() >= 0
	}
	if quotaResult.QuotaFiles > 0 {
		remainingFiles := quotaResult.GetRemainingFiles()
		c.Log(logger.LevelDebug, "cross rename, source %q remaining file %d to add %d", sourcePath,
			remainingFiles, filesDiff)
		if remainingFiles < filesDiff {
			return false
		}
	}
	if quotaResult.QuotaSize > 0 {
		remainingSize := quotaResult.GetRemainingSize()
		c.Log(logger.LevelDebug, "cross rename, source %q remaining size %d to add %d", srcInfo.Name(),
			remainingSize, sizeDiff)
		if remainingSize < sizeDiff {
			return false
		}
	}
	return true
}

// GetMaxWriteSize returns the allowed size for an upload or an error
// if no enough size is available for a resume/append.
//
// fileSize is the size of the existing file. For a resume, the per-file upload
// limit is reduced by fileSize. Otherwise, fileSize is added back to a positive
// remaining quota because the file will be replaced. A result of 0 appears to mean
// "no limit" (presumably as returned by GetRemainingSize — TODO confirm).
func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64,
	isUploadResumeSupported bool,
) (int64, error) {
	maxWriteSize := quotaResult.GetRemainingSize()

	if isResume {
		if !isUploadResumeSupported {
			return 0, c.GetOpUnsupportedError()
		}
		if c.User.Filters.MaxUploadFileSize > 0 && c.User.Filters.MaxUploadFileSize <= fileSize {
			return 0, c.GetQuotaExceededError()
		}
		if c.User.Filters.MaxUploadFileSize > 0 {
			maxUploadSize := c.User.Filters.MaxUploadFileSize - fileSize
			if maxUploadSize < maxWriteSize || maxWriteSize == 0 {
				maxWriteSize = maxUploadSize
			}
		}
	} else {
		if maxWriteSize > 0 {
			maxWriteSize += fileSize
		}
		if c.User.Filters.MaxUploadFileSize > 0 && (c.User.Filters.MaxUploadFileSize < maxWriteSize || maxWriteSize == 0) {
			maxWriteSize = c.User.Filters.MaxUploadFileSize
		}
	}

	return maxWriteSize, nil
}

// GetTransferQuota returns the data transfers quota
func (c *BaseConnection) GetTransferQuota() dataprovider.TransferQuota {
	result, _, _ := c.checkUserQuota()
	return result
}

// checkUserQuota returns the user's transfer quota limits and the allowed
// remaining transfer sizes, together with the used files and size.
// The used files/size are -1 when the user has no transfer quota restrictions or
// the used quota cannot be read; in the latter case AllowedTotalSize is -1 too.
func (c *BaseConnection) checkUserQuota() (dataprovider.TransferQuota, int, int64) {
	ul, dl, total := c.User.GetDataTransferLimits()
	result := dataprovider.TransferQuota{
		ULSize:           ul,
		DLSize:           dl,
		TotalSize:        total,
		AllowedULSize:    0,
		AllowedDLSize:    0,
		AllowedTotalSize: 0,
	}
	if !c.User.HasTransferQuotaRestrictions() {
		return result, -1, -1
	}
	usedFiles, usedSize, usedULSize, usedDLSize, err := dataprovider.GetUsedQuota(c.User.Username)
	if err != nil {
		c.Log(logger.LevelError, "error getting used quota for %q: %v", c.User.Username, err)
		result.AllowedTotalSize = -1
		return result, -1, -1
	}
	if result.TotalSize > 0 {
		result.AllowedTotalSize = result.TotalSize - (usedULSize + usedDLSize)
	}
	if result.ULSize > 0 {
		result.AllowedULSize = result.ULSize - usedULSize
	}
	if result.DLSize > 0 {
		result.AllowedDLSize = result.DLSize - usedDLSize
	}
	return result, usedFiles, usedSize
}

// HasSpace checks user's quota usage.
//
// The quota of the virtual folder containing requestPath is used when that folder
// is not included in the user quota; otherwise the user quota is used. checkFiles
// also enforces the files limit. getUsage forces reading the usage even if there
// are no restrictions. A failure reading the usage is reported as HasSpace=false.
func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) (vfs.QuotaCheckResult,
	dataprovider.TransferQuota,
) {
	result := vfs.QuotaCheckResult{
		HasSpace:     true,
		AllowedSize:  0,
		AllowedFiles: 0,
		UsedSize:     0,
		UsedFiles:    0,
		QuotaSize:    0,
		QuotaFiles:   0,
	}
	if dataprovider.GetQuotaTracking() == 0 {
		return result, dataprovider.TransferQuota{}
	}
	transferQuota, usedFiles, usedSize := c.checkUserQuota()

	var err error
	var vfolder vfs.VirtualFolder
	vfolder, err = c.User.GetVirtualFolderForPath(path.Dir(requestPath))
	if err == nil && !vfolder.IsIncludedInUserQuota() {
		if vfolder.HasNoQuotaRestrictions(checkFiles) && !getUsage {
			return result, transferQuota
		}
		result.QuotaSize = vfolder.QuotaSize
		result.QuotaFiles = vfolder.QuotaFiles
		result.UsedFiles, result.UsedSize, err = dataprovider.GetUsedVirtualFolderQuota(vfolder.Name)
	} else {
		if c.User.HasNoQuotaRestrictions(checkFiles) && !getUsage {
			return result, transferQuota
		}
		result.QuotaSize = c.User.QuotaSize
		result.QuotaFiles = c.User.QuotaFiles
		if usedSize == -1 {
			result.UsedFiles, result.UsedSize, _, _, err = dataprovider.GetUsedQuota(c.User.Username)
		} else {
			// reuse the usage already read by checkUserQuota and reset the
			// (expected) "no virtual folder" lookup error
			err = nil
			result.UsedFiles = usedFiles
			result.UsedSize = usedSize
		}
	}
	if err != nil {
		c.Log(logger.LevelError, "error getting used quota for %q request path %q: %v", c.User.Username, requestPath, err)
		result.HasSpace = false
		return result, transferQuota
	}
	result.AllowedFiles = result.QuotaFiles - result.UsedFiles
	result.AllowedSize = result.QuotaSize - result.UsedSize
	if (checkFiles && result.QuotaFiles > 0 && result.UsedFiles >= result.QuotaFiles) ||
		(result.QuotaSize > 0 && result.UsedSize >= result.QuotaSize) {
		c.Log(logger.LevelDebug, "quota exceed for user %q, request path %q, num files: %d/%d, size: %d/%d check files: %t",
			c.User.Username, requestPath, result.UsedFiles, result.QuotaFiles, result.UsedSize, result.QuotaSize, checkFiles)
		result.HasSpace = false
		return result, transferQuota
	}
	return result, transferQuota
}

// IsSameResource returns true if source and target paths are on the same resource
func (c *BaseConnection) IsSameResource(virtualSourcePath, virtualTargetPath string) bool {
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
	dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
	if errSrc != nil && errDst != nil {
		return true
	}
	if errSrc == nil && errDst == nil {
		if sourceFolder.Name == dstFolder.Name {
			return true
		}
		// we have different folders, check if they point to the same resource
		return sourceFolder.FsConfig.IsSameResource(dstFolder.FsConfig)
	}
	if errSrc == nil {
		return sourceFolder.FsConfig.IsSameResource(c.User.FsConfig)
	}
	return dstFolder.FsConfig.IsSameResource(c.User.FsConfig)
}

// isCrossFoldersRequest reports whether the two paths belong to different
// virtual folders, or one is in a virtual folder and the other is not.
func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetPath string) bool {
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
	dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
	if errSrc != nil && errDst != nil {
		return false
	}
	if errSrc == nil && errDst == nil {
		return sourceFolder.Name != dstFolder.Name
	}
	return true
}

// updateQuotaMoveBetweenVFolders updates quota after a rename whose source and
// target are both inside virtual folders (possibly the same one).
func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize,
	filesSize int64, numFiles int) {
	if sourceFolder.Name == dstFolder.Name {
		// both files are inside the same virtual folder
		if initialSize != -1 {
			dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, -numFiles, -initialSize, false)
		}
		return
	}
	// files are inside different virtual folders
	dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
	if initialSize == -1 {
		dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
		return
	}
	// we cannot have a directory here, initialSize != -1 only for files
	dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
}

// updateQuotaMoveFromVFolder updates quota after a rename from a virtual folder
// to the user home dir.
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
	// move between a virtual folder and the user home dir
	dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
	if initialSize == -1 {
		dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
		return
	}
	// we cannot have a directory here, initialSize != -1 only for files
	dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
}

// updateQuotaMoveToVFolder updates quota after a rename from the user home dir
// to a virtual folder.
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
	// move between the user home dir and a virtual folder
	dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
	if initialSize == -1 {
		dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
		return
	}
	// we cannot have a directory here, initialSize != -1 only for files
	dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
}

// updateQuotaAfterRename updates user/virtual folder quotas after a successful
// rename. filesSize == -1 means the fs Rename did not report the affected files
// and size, so they are computed by inspecting targetPath.
func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string,
	initialSize int64, numFiles int, filesSize int64,
) error {
	if dataprovider.GetQuotaTracking() == 0 {
		return nil
	}
	// we don't allow to overwrite an existing directory so targetPath can be:
	// - a new file, a symlink is as a new file here
	// - a file overwriting an existing one
	// - a new directory
	// initialSize != -1 only when overwriting files
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(path.Dir(virtualSourcePath))
	dstFolder, errDst := c.User.GetVirtualFolderForPath(path.Dir(virtualTargetPath))
	if errSrc != nil && errDst != nil {
		// both files are contained inside the user home dir
		if initialSize != -1 {
			// we cannot have a directory here, we are overwriting an existing file
			// we need to subtract the size of the overwritten file from the user quota
			dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck
		}
		return nil
	}

	if filesSize == -1 {
		// fs.Rename didn't return the affected files/sizes, we need to calculate them
		numFiles = 1
		if fi, err := fs.Stat(targetPath); err == nil {
			if fi.Mode().IsDir() {
				numFiles, filesSize, err = fs.GetDirSize(targetPath)
				if err != nil {
					c.Log(logger.LevelError, "failed to update quota after rename, error scanning moved folder %q: %+v",
						targetPath, err)
					return err
				}
			} else {
				filesSize = fi.Size()
			}
		} else {
			c.Log(logger.LevelError, "failed to update quota after renaming, file %q stat error: %+v", targetPath, err)
			return err
		}
		c.Log(logger.LevelDebug, "calculated renamed files: %d, size: %d bytes", numFiles, filesSize)
	} else {
		c.Log(logger.LevelDebug, "returned renamed files: %d, size: %d bytes", numFiles, filesSize)
	}
	if errSrc == nil && errDst == nil {
		c.updateQuotaMoveBetweenVFolders(&sourceFolder, &dstFolder, initialSize, filesSize, numFiles)
	}
	if errSrc == nil && errDst != nil {
		c.updateQuotaMoveFromVFolder(&sourceFolder, initialSize, filesSize, numFiles)
	}
	if errSrc != nil && errDst == nil {
		c.updateQuotaMoveToVFolder(&dstFolder, initialSize, filesSize, numFiles)
	}
	return nil
}

// IsNotExistError returns true if the specified fs error is not exist for the connection protocol
func (c *BaseConnection) IsNotExistError(err error) bool {
	switch c.protocol {
	case ProtocolSFTP:
		return errors.Is(err, sftp.ErrSSHFxNoSuchFile)
	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
		return errors.Is(err, os.ErrNotExist)
	default:
		return errors.Is(err, ErrNotExist)
	}
}

// GetErrorForDeniedFile return permission denied or not exist error based on the specified policy
func (c *BaseConnection) GetErrorForDeniedFile(policy int) error {
	switch policy {
	case sdk.DenyPolicyHide:
		return c.GetNotExistError()
	default:
		return c.GetPermissionDeniedError()
	}
}

// GetPermissionDeniedError returns an appropriate permission denied error for the connection protocol
func (c *BaseConnection) GetPermissionDeniedError() error {
	return getPermissionDeniedError(c.protocol)
}

// GetNotExistError returns an appropriate not exist error for the connection protocol
func (c *BaseConnection) GetNotExistError() error {
	switch c.protocol {
	case ProtocolSFTP:
		return sftp.ErrSSHFxNoSuchFile
	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
		return os.ErrNotExist
	default:
		return ErrNotExist
	}
}

// GetOpUnsupportedError returns an appropriate operation not supported error for the connection protocol
func (c *BaseConnection) GetOpUnsupportedError() error {
	switch c.protocol {
	case ProtocolSFTP:
		return sftp.ErrSSHFxOpUnsupported
	default:
		return ErrOpUnsupported
	}
}

// getQuotaExceededError returns the storage quota exceeded error for protocol.
// For SFTP, both sftp.ErrSSHFxFailure and ErrQuotaExceeded are wrapped.
func getQuotaExceededError(protocol string) error {
	switch protocol {
	case ProtocolSFTP:
		return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, ErrQuotaExceeded)
	case ProtocolFTP:
		return ftpserver.ErrStorageExceeded
	default:
		return ErrQuotaExceeded
	}
}

// getReadQuotaExceededError returns the read quota exceeded error for protocol.
// For SFTP, both sftp.ErrSSHFxFailure and ErrReadQuotaExceeded are wrapped.
func getReadQuotaExceededError(protocol string) error {
	switch protocol {
	case ProtocolSFTP:
		return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, ErrReadQuotaExceeded)
	default:
		return ErrReadQuotaExceeded
	}
}

// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
func (c *BaseConnection) GetQuotaExceededError() error {
	return getQuotaExceededError(c.protocol)
}

// GetReadQuotaExceededError returns an appropriate read quota limit exceeded error for the connection protocol
func (c *BaseConnection) GetReadQuotaExceededError() error {
	return getReadQuotaExceededError(c.protocol)
}

// IsQuotaExceededError returns true if the given error is a quota exceeded error
func (c *BaseConnection) IsQuotaExceededError(err error) bool {
	switch c.protocol {
	case ProtocolSFTP:
		if err == nil {
			return false
		}
		if errors.Is(err, ErrQuotaExceeded) {
			return true
		}
		// fallback for errors that only carry the quota message in their text
		return errors.Is(err, sftp.ErrSSHFxFailure) && strings.Contains(err.Error(), ErrQuotaExceeded.Error())
	case ProtocolFTP:
		return errors.Is(err, ftpserver.ErrStorageExceeded) || errors.Is(err, ErrQuotaExceeded)
	default:
		return errors.Is(err, ErrQuotaExceeded)
	}
}

// isSFTPGoError reports whether err wraps one of the SFTPGo sentinel errors
// listed below.
func isSFTPGoError(err error) bool {
	return errors.Is(err, ErrPermissionDenied) || errors.Is(err, ErrNotExist) ||
		errors.Is(err, ErrOpUnsupported) || errors.Is(err, ErrQuotaExceeded) ||
		errors.Is(err, ErrReadQuotaExceeded) || errors.Is(err, vfs.ErrStorageSizeUnavailable) ||
		errors.Is(err, ErrShuttingDown)
}

// GetGenericError returns an appropriate generic error for the connection protocol
func (c *BaseConnection) GetGenericError(err error) error {
	switch c.protocol {
	case ProtocolSFTP:
		if errors.Is(err, vfs.ErrStorageSizeUnavailable) || errors.Is(err, ErrOpUnsupported) || errors.Is(err, sftp.ErrSSHFxOpUnsupported) {
			return fmt.Errorf("%w: %w", sftp.ErrSSHFxOpUnsupported, err)
		}
		if isSFTPGoError(err) {
			return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, err)
		}
		if err != nil {
			var pathError *fs.PathError
			if errors.As(err, &pathError) {
				c.Log(logger.LevelError, "generic path error: %+v", pathError)
				return fmt.Errorf("%w: %v %v", sftp.ErrSSHFxFailure, pathError.Op, pathError.Err.Error())
			}
			c.Log(logger.LevelError, "generic error: %+v", err)
		}
		return sftp.ErrSSHFxFailure
	default:
		if isSFTPGoError(err) {
			return err
		}
		c.Log(logger.LevelError, "generic error: %+v", err)
		return ErrGenericFailure
	}
}

// GetFsError converts a filesystem error to a protocol error
func (c *BaseConnection) GetFsError(fs vfs.Fs, err error) error {
	if fs.IsNotExist(err) {
		return c.GetNotExistError()
	} else if fs.IsPermission(err) {
		return c.GetPermissionDeniedError()
	} else if fs.IsNotSupported(err) {
		return c.GetOpUnsupportedError()
	} else if err != nil {
		return c.GetGenericError(err)
	}
	return nil
}

// getNotificationStatus maps err to a notification status code:
// 1 when err is nil, 3 for quota exceeded errors, 2 for any other error.
func (c *BaseConnection) getNotificationStatus(err error) int {
	if err == nil {
		return 1
	}
	if c.IsQuotaExceededError(err) {
		return 3
	}
	return 2
}

// GetFsAndResolvedPath returns the fs and the fs path matching virtualPath
func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, string, error) {
	fs, err := c.User.GetFilesystemForPath(virtualPath, c.ID)
	if err != nil {
		if c.protocol == ProtocolWebDAV && strings.Contains(err.Error(), vfs.ErrSFTPLoop.Error()) {
			// if there is an SFTP loop we return a permission error, for WebDAV, so the problematic folder
			// will not be listed
			return nil, "", util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
		}
		return nil, "", c.GetGenericError(err)
	}

	if isShuttingDown.Load() {
		return nil, "", c.GetFsError(fs, ErrShuttingDown)
	}

	fsPath, err :=
fs.ResolvePath(virtualPath) if err != nil { return nil, "", c.GetFsError(fs, err) } return fs, fsPath, nil } // DirListerAt defines a directory lister implementing the ListAt method. type DirListerAt struct { virtualPath string conn *BaseConnection fs vfs.Fs info []os.FileInfo mu sync.Mutex lister vfs.DirLister } // Prepend adds the given os.FileInfo as first element of the internal cache func (l *DirListerAt) Prepend(fi os.FileInfo) { l.mu.Lock() defer l.mu.Unlock() l.info = slices.Insert(l.info, 0, fi) } // ListAt implements sftp.ListerAt func (l *DirListerAt) ListAt(f []os.FileInfo, _ int64) (int, error) { l.mu.Lock() defer l.mu.Unlock() if len(f) == 0 { return 0, errors.New("invalid ListAt destination, zero size") } if len(f) <= len(l.info) { files := make([]os.FileInfo, 0, len(f)) for idx := range l.info { files = append(files, l.info[idx]) if len(files) == len(f) { l.info = l.info[idx+1:] n := copy(f, files) return n, nil } } } limit := len(f) - len(l.info) files, err := l.Next(limit) n := copy(f, files) return n, err } // Next reads the directory and returns a slice of up to n FileInfo values. 
func (l *DirListerAt) Next(limit int) ([]os.FileInfo, error) { for { files, err := l.lister.Next(limit) if err != nil && !errors.Is(err, io.EOF) { l.conn.Log(logger.LevelDebug, "error retrieving directory entries: %+v", err) return files, l.conn.GetFsError(l.fs, err) } files = l.conn.User.FilterListDir(files, l.virtualPath) if len(l.info) > 0 { files = slices.Concat(l.info, files) l.info = nil } if err != nil || len(files) > 0 { return files, err } } } // Close closes the DirListerAt func (l *DirListerAt) Close() error { l.mu.Lock() defer l.mu.Unlock() return l.lister.Close() } func (l *DirListerAt) convertError(err error) error { if errors.Is(err, io.EOF) { return nil } return err } func getPermissionDeniedError(protocol string) error { switch protocol { case ProtocolSFTP: return sftp.ErrSSHFxPermissionDenied case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention: return os.ErrPermission default: return ErrPermissionDenied } } func keepConnectionAlive(c *BaseConnection, interval time.Duration) func() { var timer *time.Timer var closed atomic.Bool task := func() { c.UpdateLastActivity() if !closed.Load() { timer.Reset(interval) } } timer = time.AfterFunc(interval, task) return func() { closed.Store(true) timer.Stop() } } ================================================ FILE: internal/common/connection_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "errors" "fmt" "io" "os" "path" "path/filepath" "runtime" "slices" "strconv" "testing" "time" "github.com/pkg/sftp" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( errWalkDir = errors.New("err walk dir") ) // MockOsFs mockable OsFs type MockOsFs struct { vfs.Fs hasVirtualFolders bool name string err error } // Name returns the name for the Fs implementation func (fs *MockOsFs) Name() string { if fs.name != "" { return fs.name } return "mockOsFs" } // HasVirtualFolders returns true if folders are emulated func (fs *MockOsFs) HasVirtualFolders() bool { return fs.hasVirtualFolders } func (fs *MockOsFs) IsUploadResumeSupported() bool { return !fs.hasVirtualFolders } func (fs *MockOsFs) Chtimes(_ string, _, _ time.Time, _ bool) error { return vfs.ErrVfsUnsupported } func (fs *MockOsFs) Lstat(name string) (os.FileInfo, error) { if fs.err != nil { return nil, fs.err } return fs.Fs.Lstat(name) } // Walk returns a duplicate path for testing func (fs *MockOsFs) Walk(_ string, walkFn filepath.WalkFunc) error { if fs.err == errWalkDir { walkFn("fsdpath", vfs.NewFileInfo("dpath", true, 0, time.Now(), false), nil) //nolint:errcheck return walkFn("fsdpath", vfs.NewFileInfo("dpath", true, 0, time.Now(), false), nil) //nolint:errcheck } walkFn("fsfpath", vfs.NewFileInfo("fpath", false, 0, time.Now(), false), nil) //nolint:errcheck return fs.err } func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir, name string, err error) vfs.Fs { return &MockOsFs{ Fs: vfs.NewOsFs(connectionID, rootDir, "", nil), name: name, hasVirtualFolders: hasVirtualFolders, err: err, } } func 
TestRemoveErrors(t *testing.T) { mappedPath := filepath.Join(os.TempDir(), "map") homePath := filepath.Join(os.TempDir(), "home") user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "remove_errors_user", HomeDir: homePath, }, VirtualFolders: []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: filepath.Base(mappedPath), MappedPath: mappedPath, }, VirtualPath: "/virtualpath", }, }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := NewBaseConnection("", ProtocolFTP, "", "", user) err := conn.IsRemoveDirAllowed(fs, mappedPath, "/virtualpath1") if assert.Error(t, err) { assert.Contains(t, err.Error(), "permission denied") } err = conn.RemoveFile(fs, filepath.Join(homePath, "missing_file"), "/missing_file", vfs.NewFileInfo("info", false, 100, time.Now(), false)) assert.Error(t, err) } func TestSetStatMode(t *testing.T) { oldSetStatMode := Config.SetstatMode Config.SetstatMode = 1 fakePath := "fake path" user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: os.TempDir(), }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := newMockOsFs(true, "", user.GetHomeDir(), "", nil) conn := NewBaseConnection("", ProtocolWebDAV, "", "", user) err := conn.handleChmod(fs, fakePath, fakePath, nil) assert.NoError(t, err) err = conn.handleChown(fs, fakePath, fakePath, nil) assert.NoError(t, err) err = conn.handleChtimes(fs, fakePath, fakePath, nil) assert.NoError(t, err) Config.SetstatMode = 2 err = conn.handleChmod(fs, fakePath, fakePath, nil) assert.NoError(t, err) err = conn.handleChtimes(fs, fakePath, fakePath, &StatAttributes{ Atime: time.Now(), Mtime: time.Now(), }) assert.NoError(t, err) Config.SetstatMode = oldSetStatMode } func TestRecursiveRenameWalkError(t *testing.T) { fs := vfs.NewOsFs("", filepath.Clean(os.TempDir()), "", nil) conn := NewBaseConnection("", ProtocolWebDAV, "", "", 
dataprovider.User{ BaseUser: sdk.BaseUser{ Permissions: map[string][]string{ "/": {dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermRenameDirs}, }, }, }) err := conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"), filepath.Join(os.TempDir(), "/target"), "/source", "/target", vfs.NewFileInfo("source", true, 0, time.Now(), false)) assert.ErrorIs(t, err, os.ErrNotExist) fs = newMockOsFs(false, "mockID", filepath.Clean(os.TempDir()), "S3Fs", errWalkDir) err = conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"), filepath.Join(os.TempDir(), "/target"), "/source", "/target", vfs.NewFileInfo("source", true, 0, time.Now(), false)) if assert.Error(t, err) { assert.Equal(t, err.Error(), conn.GetOpUnsupportedError().Error()) } conn.User.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermRenameFiles} // no dir rename permission, the quick check path returns permission error without walking err = conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"), filepath.Join(os.TempDir(), "/target"), "/source", "/target", vfs.NewFileInfo("source", true, 0, time.Now(), false)) if assert.Error(t, err) { assert.EqualError(t, err, conn.GetPermissionDeniedError().Error()) } } func TestCrossRenameFsErrors(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{}) dirPath := filepath.Join(os.TempDir(), "d") err := os.Mkdir(dirPath, os.ModePerm) assert.NoError(t, err) err = os.Chmod(dirPath, 0001) assert.NoError(t, err) srcInfo := vfs.NewFileInfo(filepath.Base(dirPath), true, 0, time.Now(), false) res := conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, dirPath, srcInfo) assert.False(t, res) err = 
os.Chmod(dirPath, os.ModePerm) assert.NoError(t, err) err = os.Remove(dirPath) assert.NoError(t, err) } func TestRenameVirtualFolders(t *testing.T) { vdir := "/avdir" u := dataprovider.User{} u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "name", MappedPath: "mappedPath", }, VirtualPath: vdir, }) fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := NewBaseConnection("", ProtocolFTP, "", "", u) res := conn.isRenamePermitted(fs, fs, "source", "target", vdir, "vdirtarget", nil) assert.False(t, res) } func TestRenamePerms(t *testing.T) { src := "source" target := "target" sub := "/sub" subTarget := sub + "/target" u := dataprovider.User{} u.Permissions = map[string][]string{} u.Permissions["/"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermCreateSymlinks, dataprovider.PermDeleteFiles} conn := NewBaseConnection("", ProtocolSFTP, "", "", u) assert.False(t, conn.hasRenamePerms(src, target, nil)) u.Permissions["/"] = []string{dataprovider.PermRename} assert.True(t, conn.hasRenamePerms(src, target, nil)) u.Permissions["/"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermDeleteFiles, dataprovider.PermDeleteDirs} assert.False(t, conn.hasRenamePerms(src, target, nil)) info := vfs.NewFileInfo(src, true, 0, time.Now(), false) u.Permissions["/"] = []string{dataprovider.PermRenameFiles} assert.False(t, conn.hasRenamePerms(src, target, info)) u.Permissions["/"] = []string{dataprovider.PermRenameDirs} assert.True(t, conn.hasRenamePerms(src, target, info)) u.Permissions["/"] = []string{dataprovider.PermRename} assert.True(t, conn.hasRenamePerms(src, target, info)) u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDeleteDirs} assert.False(t, conn.hasRenamePerms(src, target, info)) // test with different permissions between source and target u.Permissions["/"] = []string{dataprovider.PermRename} 
u.Permissions[sub] = []string{dataprovider.PermRenameFiles} assert.False(t, conn.hasRenamePerms(src, subTarget, info)) u.Permissions[sub] = []string{dataprovider.PermRenameDirs} assert.True(t, conn.hasRenamePerms(src, subTarget, info)) // test files info = vfs.NewFileInfo(src, false, 0, time.Now(), false) u.Permissions["/"] = []string{dataprovider.PermRenameDirs} assert.False(t, conn.hasRenamePerms(src, target, info)) u.Permissions["/"] = []string{dataprovider.PermRenameFiles} assert.True(t, conn.hasRenamePerms(src, target, info)) u.Permissions["/"] = []string{dataprovider.PermRename} assert.True(t, conn.hasRenamePerms(src, target, info)) // test with different permissions between source and target u.Permissions["/"] = []string{dataprovider.PermRename} u.Permissions[sub] = []string{dataprovider.PermRenameDirs} assert.False(t, conn.hasRenamePerms(src, subTarget, info)) u.Permissions[sub] = []string{dataprovider.PermRenameFiles} assert.True(t, conn.hasRenamePerms(src, subTarget, info)) } func TestRenameNestedFolders(t *testing.T) { u := dataprovider.User{} u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "vfolder", MappedPath: filepath.Join(os.TempDir(), "f"), }, VirtualPath: "/vdirs/f", }) conn := NewBaseConnection("", ProtocolSFTP, "", "", u) err := conn.checkFolderRename(nil, nil, filepath.Clean(os.TempDir()), filepath.Join(os.TempDir(), "subdir"), "/src", "/dst", nil) assert.Error(t, err) err = conn.checkFolderRename(nil, nil, filepath.Join(os.TempDir(), "subdir"), filepath.Clean(os.TempDir()), "/src", "/dst", nil) assert.Error(t, err) err = conn.checkFolderRename(nil, nil, "", "", "/src/sub", "/src", nil) assert.Error(t, err) err = conn.checkFolderRename(nil, nil, filepath.Join(os.TempDir(), "src"), filepath.Join(os.TempDir(), "vdirs"), "/src", "/vdirs", nil) assert.Error(t, err) } func TestUpdateQuotaAfterRename(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: 
userTestUsername, HomeDir: filepath.Join(os.TempDir(), "home"), }, } mappedPath := filepath.Join(os.TempDir(), "vdir") user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: mappedPath, }, VirtualPath: "/vdir", QuotaFiles: -1, QuotaSize: -1, }) user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: mappedPath, }, VirtualPath: "/vdir1", QuotaFiles: -1, QuotaSize: -1, }) err := os.MkdirAll(user.GetHomeDir(), os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(mappedPath, os.ModePerm) assert.NoError(t, err) fs, err := user.GetFilesystem("id") assert.NoError(t, err) c := NewBaseConnection("", ProtocolSFTP, "", "", user) request := sftp.NewRequest("Rename", "/testfile") if runtime.GOOS != osWindows { request.Filepath = "/dir" request.Target = path.Join("/vdir", "dir") testDirPath := filepath.Join(mappedPath, "dir") err := os.MkdirAll(testDirPath, os.ModePerm) assert.NoError(t, err) err = os.Chmod(testDirPath, 0001) assert.NoError(t, err) err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, testDirPath, 0, -1, -1) assert.Error(t, err) err = os.Chmod(testDirPath, os.ModePerm) assert.NoError(t, err) } testFile1 := "/testfile1" request.Target = testFile1 request.Filepath = path.Join("/vdir", "file") err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 0, -1, -1) assert.Error(t, err) err = os.WriteFile(filepath.Join(mappedPath, "file"), []byte("test content"), os.ModePerm) assert.NoError(t, err) request.Filepath = testFile1 request.Target = path.Join("/vdir", "file") err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, -1, -1) assert.NoError(t, err) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "testfile1"), 
[]byte("test content"), os.ModePerm) assert.NoError(t, err) request.Target = testFile1 request.Filepath = path.Join("/vdir", "file") err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, -1, -1) assert.NoError(t, err) request.Target = path.Join("/vdir1", "file") request.Filepath = path.Join("/vdir", "file") err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, -1, -1) assert.NoError(t, err) err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, 1, 100) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestErrorsMapping(t *testing.T) { fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}}) osErrorsProtocols := []string{ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare, ProtocolDataRetention, ProtocolOIDC, protocolEventAction} for _, protocol := range supportedProtocols { conn.SetProtocol(protocol) err := conn.GetFsError(fs, os.ErrNotExist) if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile) } else if slices.Contains(osErrorsProtocols, protocol) { assert.EqualError(t, err, os.ErrNotExist.Error()) } else { assert.EqualError(t, err, ErrNotExist.Error()) } err = conn.GetFsError(fs, os.ErrPermission) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxPermissionDenied.Error()) } else { assert.EqualError(t, err, ErrPermissionDenied.Error()) } err = conn.GetFsError(fs, os.ErrClosed) if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) } else { assert.EqualError(t, err, ErrGenericFailure.Error()) } err = conn.GetFsError(fs, ErrPermissionDenied) if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) } else { 
assert.EqualError(t, err, ErrPermissionDenied.Error()) } err = conn.GetFsError(fs, vfs.ErrVfsUnsupported) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) } else { assert.EqualError(t, err, ErrOpUnsupported.Error()) } err = conn.GetFsError(fs, vfs.ErrStorageSizeUnavailable) if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxOpUnsupported) assert.Contains(t, err.Error(), vfs.ErrStorageSizeUnavailable.Error()) } else { assert.EqualError(t, err, vfs.ErrStorageSizeUnavailable.Error()) } err = conn.GetQuotaExceededError() assert.True(t, conn.IsQuotaExceededError(err)) err = conn.GetReadQuotaExceededError() if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) } else { assert.ErrorIs(t, err, ErrReadQuotaExceeded) } err = conn.GetNotExistError() assert.True(t, conn.IsNotExistError(err)) err = conn.GetFsError(fs, nil) assert.NoError(t, err) err = conn.GetOpUnsupportedError() if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) } else { assert.EqualError(t, err, ErrOpUnsupported.Error()) } err = conn.GetFsError(fs, ErrShuttingDown) if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) assert.Contains(t, err.Error(), ErrShuttingDown.Error()) } else { assert.EqualError(t, err, ErrShuttingDown.Error()) } } } func TestMaxWriteSize(t *testing.T) { permissions := make(map[string][]string) permissions["/"] = []string{dataprovider.PermAny} user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, Permissions: permissions, HomeDir: filepath.Clean(os.TempDir()), }, } fs, err := user.GetFilesystem("123") assert.NoError(t, err) conn := NewBaseConnection("", ProtocolFTP, "", "", user) quotaResult := vfs.QuotaCheckResult{ HasSpace: true, } size, err := conn.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, 
int64(0), size) conn.User.Filters.MaxUploadFileSize = 100 size, err = conn.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(100), size) quotaResult.QuotaSize = 1000 size, err = conn.GetMaxWriteSize(quotaResult, false, 50, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(100), size) quotaResult.QuotaSize = 1000 quotaResult.UsedSize = 990 size, err = conn.GetMaxWriteSize(quotaResult, false, 50, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(60), size) quotaResult.QuotaSize = 0 quotaResult.UsedSize = 0 size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported()) assert.True(t, conn.IsQuotaExceededError(err)) assert.Equal(t, int64(0), size) size, err = conn.GetMaxWriteSize(quotaResult, true, 10, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(90), size) fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), "", nil) size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported()) assert.EqualError(t, err, ErrOpUnsupported.Error()) assert.Equal(t, int64(0), size) } func TestCheckParentDirsErrors(t *testing.T) { permissions := make(map[string][]string) permissions["/"] = []string{dataprovider.PermAny} user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, Permissions: permissions, HomeDir: filepath.Clean(os.TempDir()), }, FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, }, } c := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) err := c.CheckParentDirs("/a/dir") assert.Error(t, err) user.FsConfig.Provider = sdk.LocalFilesystemProvider user.VirtualFolders = nil user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, }, }, VirtualPath: "/vdir", }) user.VirtualFolders = append(user.VirtualFolders, 
vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Clean(os.TempDir()), }, VirtualPath: "/vdir/sub", }) c = NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) err = c.CheckParentDirs("/vdir/sub/dir") assert.Error(t, err) user = dataprovider.User{ BaseUser: sdk.BaseUser{ Username: userTestUsername, Permissions: permissions, HomeDir: filepath.Clean(os.TempDir()), }, FsConfig: vfs.Filesystem{ Provider: sdk.S3FilesystemProvider, S3Config: vfs.S3FsConfig{ BaseS3FsConfig: sdk.BaseS3FsConfig{ Bucket: "buck", Region: "us-east-1", AccessKey: "key", }, AccessSecret: kms.NewPlainSecret("s3secret"), }, }, } c = NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) err = c.CheckParentDirs("/a/dir") assert.NoError(t, err) user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Clean(os.TempDir()), }, VirtualPath: "/local/dir", }) c = NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) err = c.CheckParentDirs("/local/dir/sub-dir") assert.NoError(t, err) err = os.RemoveAll(filepath.Join(os.TempDir(), "sub-dir")) assert.NoError(t, err) } func TestErrorResolvePath(t *testing.T) { u := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Join(os.TempDir(), "u"), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, } u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") u.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "f", MappedPath: filepath.Join(os.TempDir(), "f"), }, VirtualPath: "/f", }, } conn := NewBaseConnection("", ProtocolSFTP, "", "", u) err := conn.doRecursiveRemoveDirEntry("/vpath", nil, 0) assert.Error(t, err) err = conn.doRecursiveRemove(nil, "/fspath", "/vpath", vfs.NewFileInfo("vpath", true, 0, time.Now(), false), 2000) 
assert.Error(t, err, util.ErrRecursionTooDeep) err = conn.doRecursiveCopy("/src", "/dst", vfs.NewFileInfo("src", true, 0, time.Now(), false), false, 2000) assert.Error(t, err, util.ErrRecursionTooDeep) err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/source", "/target") assert.Error(t, err) sourceFile := filepath.Join(os.TempDir(), "f", "source") err = os.MkdirAll(filepath.Dir(sourceFile), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(sourceFile, []byte(""), 0666) assert.NoError(t, err) err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/f/source", "/target") assert.Error(t, err) err = conn.checkCopy(vfs.NewFileInfo("source", false, 0, time.Unix(0, 0), false), vfs.NewFileInfo("target", true, 0, time.Unix(0, 0), false), "/f/source", "/f/target") assert.Error(t, err) err = os.RemoveAll(filepath.Dir(sourceFile)) assert.NoError(t, err) } func TestConnectionKeepAlive(t *testing.T) { conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{}) lastActivity := conn.GetLastActivity() stop := keepConnectionAlive(conn, 50*time.Millisecond) defer stop() time.Sleep(200 * time.Millisecond) assert.Greater(t, conn.GetLastActivity(), lastActivity) } func TestFsFileCopier(t *testing.T) { fs := vfs.Fs(&vfs.AzureBlobFs{}) _, ok := fs.(vfs.FsFileCopier) assert.True(t, ok) fs = vfs.Fs(&vfs.OsFs{}) _, ok = fs.(vfs.FsFileCopier) assert.False(t, ok) fs = vfs.Fs(&vfs.SFTPFs{}) _, ok = fs.(vfs.FsFileCopier) assert.False(t, ok) fs = vfs.Fs(&vfs.GCSFs{}) _, ok = fs.(vfs.FsFileCopier) assert.True(t, ok) fs = vfs.Fs(&vfs.S3Fs{}) _, ok = fs.(vfs.FsFileCopier) assert.True(t, ok) } func TestFilePatterns(t *testing.T) { filters := dataprovider.UserFilters{ BaseUserFilters: sdk.BaseUserFilters{ FilePatterns: []sdk.PatternsFilter{ { Path: "/dir1", DenyPolicy: sdk.DenyPolicyDefault, AllowedPatterns: []string{"*.jpg"}, }, { Path: "/dir2", DenyPolicy: sdk.DenyPolicyHide, AllowedPatterns: []string{"*.jpg"}, }, { 
Path: "/dir3", DenyPolicy: sdk.DenyPolicyDefault, DeniedPatterns: []string{"*.jpg"}, }, { Path: "/dir4", DenyPolicy: sdk.DenyPolicyHide, DeniedPatterns: []string{"*"}, }, }, }, } virtualFolders := []vfs.VirtualFolder{ { VirtualPath: "/dir1/vdir1", }, { VirtualPath: "/dir1/vdir2", }, { VirtualPath: "/dir1/vdir3", }, { VirtualPath: "/dir2/vdir1", }, { VirtualPath: "/dir2/vdir2", }, { VirtualPath: "/dir2/vdir3.jpg", }, } user := dataprovider.User{ Filters: filters, VirtualFolders: virtualFolders, } getFilteredInfo := func(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { result := user.FilterListDir(dirContents, virtualPath) result = append(result, user.GetVirtualFoldersInfo(virtualPath)...) return result } dirContents := []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } // dirContents are modified in place, we need to redefine them each time filtered := getFilteredInfo(dirContents, "/dir1") assert.Len(t, filtered, 5) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir1/vdir1") assert.Len(t, filtered, 2) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir2/vdir2") require.Len(t, filtered, 1) assert.Equal(t, "file1.jpg", filtered[0].Name()) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir2/vdir2/sub") require.Len(t, filtered, 1) assert.Equal(t, "file1.jpg", filtered[0].Name()) res, _ := user.IsFileAllowed("/dir1/vdir1/file.txt") assert.False(t, res) res, _ = user.IsFileAllowed("/dir1/vdir1/sub/file.txt") assert.False(t, res) res, _ = 
user.IsFileAllowed("/dir1/vdir1/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir1/vdir1/sub/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/dir1/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/dir1/sub/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir4/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir4/dir1/sub/file.jpg") assert.False(t, res) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir4") require.Len(t, filtered, 0) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir4/vdir2/sub") require.Len(t, filtered, 0) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir2") assert.Len(t, filtered, 2) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir4") assert.Len(t, filtered, 0) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir4/sub") assert.Len(t, filtered, 0) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir1") assert.Len(t, filtered, 5) filtered = getFilteredInfo(dirContents, "/dir2") if assert.Len(t, filtered, 1) { assert.True(t, 
filtered[0].IsDir()) } user.VirtualFolders = nil dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir1") assert.Len(t, filtered, 2) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir2") if assert.Len(t, filtered, 1) { assert.False(t, filtered[0].IsDir()) } dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir2") if assert.Len(t, filtered, 2) { assert.False(t, filtered[0].IsDir()) assert.False(t, filtered[1].IsDir()) } user.VirtualFolders = virtualFolders user.Filters = filters filtered = getFilteredInfo(nil, "/dir1") assert.Len(t, filtered, 3) filtered = getFilteredInfo(nil, "/dir2") assert.Len(t, filtered, 1) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jPg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir2") assert.Len(t, filtered, 2) user = dataprovider.User{ Filters: dataprovider.UserFilters{ BaseUserFilters: sdk.BaseUserFilters{ FilePatterns: []sdk.PatternsFilter{ { Path: "/dir3", AllowedPatterns: []string{"ic35"}, DeniedPatterns: []string{"*"}, DenyPolicy: sdk.DenyPolicyHide, }, }, }, }, } dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), 
false), vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3") assert.Len(t, filtered, 0) dirContents = nil for i := 0; i < 100; i++ { dirContents = append(dirContents, vfs.NewFileInfo(fmt.Sprintf("ic%02d", i), i%2 == 0, int64(i), time.Now(), false)) } dirContents = append(dirContents, vfs.NewFileInfo("ic350", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo(".ic35", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("ic35.", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("*ic35", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("ic35*", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("ic35.*", false, 123, time.Now(), false)) dirContents = append(dirContents, vfs.NewFileInfo("file.jpg", false, 123, time.Now(), false)) filtered = getFilteredInfo(dirContents, "/dir3") require.Len(t, filtered, 1) assert.Equal(t, "ic35", filtered[0].Name()) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic36") require.Len(t, filtered, 0) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35") require.Len(t, filtered, 3) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub") require.Len(t, filtered, 3) res, _ = 
user.IsFileAllowed("/dir3/file.txt") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35a") assert.False(t, res) res, policy := user.IsFileAllowed("/dir3/ic35a/file") assert.False(t, res) assert.Equal(t, sdk.DenyPolicyHide, policy) res, _ = user.IsFileAllowed("/dir3/ic35") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/file.txt") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub/file.txt") assert.True(t, res) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub") require.Len(t, filtered, 3) user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ Path: "/dir3/ic35/sub1", AllowedPatterns: []string{"*.jpg"}, DenyPolicy: sdk.DenyPolicyDefault, }) user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ Path: "/dir3/ic35/sub2", DeniedPatterns: []string{"*.jpg"}, DenyPolicy: sdk.DenyPolicyHide, }) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub1") require.Len(t, filtered, 3) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2") require.Len(t, filtered, 2) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } 
filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2/sub1") require.Len(t, filtered, 2) res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/file.txt") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub/dir/file.txt") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub/dir/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub1/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub1/file.txt") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub1/sub/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub1/sub2/file.txt") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.txt") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub1/file.txt") assert.True(t, res) user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ Path: "/dir3/ic35", DeniedPatterns: []string{"*.txt"}, DenyPolicy: sdk.DenyPolicyHide, }) res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/file.txt") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/adir/sub/file.jpg") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/adir/file.txt") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.txt") assert.True(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub/file.jpg") assert.False(t, res) res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub1/file.txt") assert.True(t, res) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 
123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35") require.Len(t, filtered, 1) dirContents = []os.FileInfo{ vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), } filtered = getFilteredInfo(dirContents, "/dir3/ic35/abc") require.Len(t, filtered, 1) } func TestStatForOngoingTransfers(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: xid.New().String(), Password: xid.New().String(), HomeDir: filepath.Clean(os.TempDir()), Status: 1, Permissions: map[string][]string{ "/": {"*"}, }, }, } fileName := "file.txt" conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) fs := vfs.NewOsFs("", os.TempDir(), "", nil) tr := NewBaseTransfer(nil, conn, nil, filepath.Join(os.TempDir(), fileName), filepath.Join(os.TempDir(), fileName), fileName, TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) _, err := conn.DoStat("/file.txt", 0, false) assert.NoError(t, err) err = tr.Close() assert.NoError(t, err) tr = NewBaseTransfer(nil, conn, nil, filepath.Join(os.TempDir(), fileName), filepath.Join(os.TempDir(), fileName), fileName, TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) _, err = conn.DoStat("/file.txt", 0, false) assert.Error(t, err) err = tr.Close() assert.NoError(t, err) err = conn.CloseFS() assert.NoError(t, err) } func TestListerAt(t *testing.T) { dir := t.TempDir() user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "u", Password: "p", HomeDir: dir, Status: 1, Permissions: map[string][]string{ "/": {"*"}, }, }, } conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) lister, err := conn.ListDir("/") require.NoError(t, err) files, err := lister.Next(1) require.ErrorIs(t, err, io.EOF) require.Len(t, files, 0) err = lister.Close() require.NoError(t, err) conn.User.VirtualFolders = []vfs.VirtualFolder{ { 
VirtualPath: "p1", }, { VirtualPath: "p2", }, { VirtualPath: "p3", }, } lister, err = conn.ListDir("/") require.NoError(t, err) files, err = lister.Next(2) // virtual directories exceeds the limit require.ErrorIs(t, err, io.EOF) require.Len(t, files, 3) files, err = lister.Next(2) require.ErrorIs(t, err, io.EOF) require.Len(t, files, 0) _, err = lister.Next(-1) require.ErrorContains(t, err, conn.GetGenericError(err).Error()) err = lister.Close() require.NoError(t, err) lister, err = conn.ListDir("/") require.NoError(t, err) _, err = lister.ListAt(nil, 0) require.ErrorContains(t, err, "zero size") err = lister.Close() require.NoError(t, err) for i := 0; i < 100; i++ { f, err := os.Create(filepath.Join(dir, strconv.Itoa(i))) require.NoError(t, err) err = f.Close() require.NoError(t, err) } lister, err = conn.ListDir("/") require.NoError(t, err) files = make([]os.FileInfo, 18) n, err := lister.ListAt(files, 0) require.NoError(t, err) require.Equal(t, 18, n) n, err = lister.ListAt(files, 0) require.NoError(t, err) require.Equal(t, 18, n) files = make([]os.FileInfo, 100) n, err = lister.ListAt(files, 0) require.NoError(t, err) require.Equal(t, 64+3, n) n, err = lister.ListAt(files, 0) require.ErrorIs(t, err, io.EOF) require.Equal(t, 0, n) n, err = lister.ListAt(files, 0) require.ErrorIs(t, err, io.EOF) require.Equal(t, 0, n) err = lister.Close() require.NoError(t, err) n, err = lister.ListAt(files, 0) assert.Error(t, err) assert.NotErrorIs(t, err, io.EOF) require.Equal(t, 0, n) lister, err = conn.ListDir("/") require.NoError(t, err) lister.Prepend(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false)) lister.Prepend(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false)) files = make([]os.FileInfo, 1) n, err = lister.ListAt(files, 0) require.NoError(t, err) require.Equal(t, 1, n) assert.Equal(t, ".", files[0].Name()) files = make([]os.FileInfo, 2) n, err = lister.ListAt(files, 0) require.NoError(t, err) require.Equal(t, 2, n) assert.Equal(t, "..", files[0].Name()) 
vfolders := []string{files[1].Name()} files = make([]os.FileInfo, 200) n, err = lister.ListAt(files, 0) require.NoError(t, err) require.Equal(t, 102, n) vfolders = append(vfolders, files[0].Name()) vfolders = append(vfolders, files[1].Name()) assert.Contains(t, vfolders, "p1") assert.Contains(t, vfolders, "p2") assert.Contains(t, vfolders, "p3") err = lister.Close() require.NoError(t, err) } func TestGetFsAndResolvedPath(t *testing.T) { homeDir := filepath.Join(os.TempDir(), "home_test") localVdir := filepath.Join(os.TempDir(), "local_mount_test") err := os.MkdirAll(homeDir, 0777) require.NoError(t, err) err = os.MkdirAll(localVdir, 0777) require.NoError(t, err) t.Cleanup(func() { os.RemoveAll(homeDir) os.RemoveAll(localVdir) }) user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: xid.New().String(), Status: 1, HomeDir: homeDir, }, VirtualFolders: []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "s3", MappedPath: "", FsConfig: vfs.Filesystem{ Provider: sdk.S3FilesystemProvider, S3Config: vfs.S3FsConfig{ BaseS3FsConfig: sdk.BaseS3FsConfig{ Bucket: "my-test-bucket", Region: "us-east-1", }, }, }, }, VirtualPath: "/s3", }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "local", MappedPath: localVdir, FsConfig: vfs.Filesystem{ Provider: sdk.LocalFilesystemProvider, }, }, VirtualPath: "/local", }, }, } conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) tests := []struct { name string inputVirtualPath string expectedFsType string expectedPhyPath string // The resolved path on the target FS expectedRelativePath string }{ { name: "Root File", inputVirtualPath: "/file.txt", expectedFsType: "osfs", expectedPhyPath: filepath.Join(homeDir, "file.txt"), expectedRelativePath: "/file.txt", }, { name: "Standard S3 File", inputVirtualPath: "/s3/image.png", expectedFsType: "S3Fs", expectedPhyPath: "image.png", expectedRelativePath: "/s3/image.png", }, { name: "Standard Local Mount File", inputVirtualPath: "/local/config.json", 
expectedFsType: "osfs", expectedPhyPath: filepath.Join(localVdir, "config.json"), expectedRelativePath: "/local/config.json", }, { name: "Backslash Separator -> Should hit S3", inputVirtualPath: "\\s3\\doc.txt", expectedFsType: "S3Fs", expectedPhyPath: "doc.txt", expectedRelativePath: "/s3/doc.txt", }, { name: "Mixed Separators -> Should hit Local Mount", inputVirtualPath: "/local\\subdir/test.txt", expectedFsType: "osfs", expectedPhyPath: filepath.Join(localVdir, "subdir", "test.txt"), expectedRelativePath: "/local/subdir/test.txt", }, { name: "Double Slash -> Should normalize and hit S3", inputVirtualPath: "//s3//dir @1/data.csv", expectedFsType: "S3Fs", expectedPhyPath: "dir @1/data.csv", expectedRelativePath: "/s3/dir @1/data.csv", }, { name: "Local Mount Traversal (Attempt to escape)", inputVirtualPath: "/local/../../etc/passwd", expectedFsType: "osfs", expectedPhyPath: filepath.Join(homeDir, "/etc/passwd"), expectedRelativePath: "/etc/passwd", }, { name: "Traversal Out of S3 (Valid)", inputVirtualPath: "/s3/../../secret.txt", expectedFsType: "osfs", expectedPhyPath: filepath.Join(homeDir, "secret.txt"), expectedRelativePath: "/secret.txt", }, { name: "Traversal Inside S3", inputVirtualPath: "/s3/subdir/../image.png", expectedFsType: "S3Fs", expectedPhyPath: "image.png", expectedRelativePath: "/s3/image.png", }, { name: "Mount Point Bypass -> Target Local Mount", inputVirtualPath: "/s3\\..\\local\\secret.txt", expectedFsType: "osfs", expectedPhyPath: filepath.Join(localVdir, "secret.txt"), expectedRelativePath: "/local/secret.txt", }, { name: "Dirty Relative Path (Your Case)", inputVirtualPath: "test\\..\\..\\oops/file.txt", expectedFsType: "osfs", expectedPhyPath: filepath.Join(homeDir, "oops", "file.txt"), expectedRelativePath: "/oops/file.txt", }, { name: "Relative Path targeting S3 (No leading slash)", inputVirtualPath: "s3//sub/../image.png", expectedFsType: "S3Fs", expectedPhyPath: "image.png", expectedRelativePath: "/s3/image.png", }, { name: "Windows 
Path starting with Backslash", inputVirtualPath: "\\s3\\doc/dir\\doc.txt", expectedFsType: "S3Fs", expectedPhyPath: "doc/dir/doc.txt", expectedRelativePath: "/s3/doc/dir/doc.txt", }, { name: "Filesystem Juggling (Relative)", inputVirtualPath: "local/../s3/file.txt", expectedFsType: "S3Fs", expectedPhyPath: "file.txt", expectedRelativePath: "/s3/file.txt", }, { name: "Triple Dot Filename (Valid Name)", inputVirtualPath: "/...hidden/secret", expectedFsType: "osfs", expectedPhyPath: filepath.Join(homeDir, "...hidden", "secret"), expectedRelativePath: "/...hidden/secret", }, { name: "Dot Slash Prefix", inputVirtualPath: "./local/file.txt", expectedFsType: "osfs", expectedPhyPath: filepath.Join(localVdir, "file.txt"), expectedRelativePath: "/local/file.txt", }, { name: "Root of Local Mount Exactly", inputVirtualPath: "/local/", expectedFsType: "osfs", expectedPhyPath: localVdir, expectedRelativePath: "/local", }, { name: "Root of S3 Mount Exactly", inputVirtualPath: "/s3/", expectedFsType: "S3Fs", expectedPhyPath: "", expectedRelativePath: "/s3", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // The input path is sanitized by the protocol handler // implementations before reaching GetFsAndResolvedPath. 
cleanInput := util.CleanPath(tc.inputVirtualPath) fs, resolvedPath, err := conn.GetFsAndResolvedPath(cleanInput) if assert.NoError(t, err, "did not expect error for path: %q, got: %v", tc.inputVirtualPath, err) { assert.Contains(t, fs.Name(), tc.expectedFsType, "routing error: input %q but expected fs %q, got %q", tc.inputVirtualPath, tc.expectedFsType, fs.Name()) assert.Equal(t, tc.expectedPhyPath, resolvedPath, "resolution error: input %q resolved to %q expected %q", tc.inputVirtualPath, resolvedPath, tc.expectedPhyPath) relativePath := fs.GetRelativePath(resolvedPath) assert.Equal(t, tc.expectedRelativePath, relativePath, "relative path error, input %q, got %q, expected %q", tc.inputVirtualPath, tc.expectedRelativePath, relativePath) } }) } } func TestOsFsGetRelativePath(t *testing.T) { homeDir := filepath.Join(os.TempDir(), "home_test") localVdir := filepath.Join(os.TempDir(), "local_mount_test") err := os.MkdirAll(homeDir, 0777) require.NoError(t, err) err = os.MkdirAll(localVdir, 0777) require.NoError(t, err) t.Cleanup(func() { os.RemoveAll(homeDir) os.RemoveAll(localVdir) }) user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: xid.New().String(), Status: 1, HomeDir: homeDir, }, VirtualFolders: []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "local", MappedPath: localVdir, FsConfig: vfs.Filesystem{ Provider: sdk.LocalFilesystemProvider, }, }, VirtualPath: "/local", }, }, } connID := xid.New().String() rootFs, err := user.GetFilesystemForPath("/", connID) require.NoError(t, err) localFs, err := user.GetFilesystemForPath("/local", connID) require.NoError(t, err) tests := []struct { name string fs vfs.Fs inputPath string // The physical path to reverse-map expectedRel string // The expected virtual path }{ { name: "Root FS - Inside root", fs: rootFs, inputPath: filepath.Join(homeDir, "docs", "file.txt"), expectedRel: "/docs/file.txt", }, { name: "Root FS - Exact root directory", fs: rootFs, inputPath: homeDir, expectedRel: "/", }, 
{ name: "Root FS - External absolute path (Jail to /)", fs: rootFs, inputPath: "/etc/passwd", expectedRel: "/", }, { name: "Root FS - Traversal escape (Jail to /)", fs: rootFs, inputPath: filepath.Join(homeDir, "..", "escaped.txt"), expectedRel: "/", }, { name: "Root FS - Valid file named with triple dots", fs: rootFs, inputPath: filepath.Join(homeDir, "..."), expectedRel: "/...", }, { name: "Local FS - Up path in dir", fs: rootFs, inputPath: homeDir + "/../" + filepath.Base(homeDir) + "/dir/test.txt", expectedRel: "/dir/test.txt", }, { name: "Local FS - Inside mount", fs: localFs, inputPath: filepath.Join(localVdir, "data", "config.json"), expectedRel: "/local/data/config.json", }, { name: "Local FS - Exact mount directory", fs: localFs, inputPath: localVdir, expectedRel: "/local", }, { name: "Local FS - External absolute path (Jail to /local)", fs: localFs, inputPath: "/var/log/syslog", expectedRel: "/local", }, { name: "Local FS - Traversal escape (Jail to /local)", fs: localFs, inputPath: filepath.Join(localVdir, "..", "..", "etc", "passwd"), expectedRel: "/local", }, { name: "Local FS - Partial prefix (Jail to /local)", fs: localFs, inputPath: localVdir + "_backup", expectedRel: "/local", }, { name: "Local FS - Relative traversal matching virual dir", fs: localFs, inputPath: localVdir + "/../" + filepath.Base(localVdir) + "/dir/test.txt", expectedRel: "/local/dir/test.txt", }, { name: "Local FS - Valid file starting with two dots", fs: localFs, inputPath: filepath.Join(localVdir, "..hidden_file.txt"), expectedRel: "/local/..hidden_file.txt", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { actualRel := tc.fs.GetRelativePath(tc.inputPath) assert.Equal(t, tc.expectedRel, actualRel, "Failed mapping physical path %q on FS %q", tc.inputPath, tc.fs.Name()) }) } } ================================================ FILE: internal/common/dataretention.go ================================================ // Copyright (C) 2019 Nicola Murino // // This 
program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "errors" "fmt" "io" "os" "path" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( // RetentionChecks is the list of active retention checks RetentionChecks ActiveRetentionChecks ) // ActiveRetentionChecks holds the active retention checks type ActiveRetentionChecks struct { sync.RWMutex Checks []RetentionCheck } // Get returns the active retention checks func (c *ActiveRetentionChecks) Get(role string) []RetentionCheck { c.RLock() defer c.RUnlock() checks := make([]RetentionCheck, 0, len(c.Checks)) for _, check := range c.Checks { if role == "" || role == check.Role { foldersCopy := make([]dataprovider.FolderRetention, len(check.Folders)) copy(foldersCopy, check.Folders) checks = append(checks, RetentionCheck{ Username: check.Username, StartTime: check.StartTime, Folders: foldersCopy, }) } } return checks } // Add a new retention check, returns nil if a retention check for the given // username is already active. 
The returned result can be used to start the check
//
// The caller's user is not modified: the connection used by the check is
// built from a copy of it with the file patterns dropped, and it then gets
// full permissions on every configured path.
func (c *ActiveRetentionChecks) Add(check RetentionCheck, user *dataprovider.User) *RetentionCheck {
	c.Lock()
	defer c.Unlock()

	for _, val := range c.Checks {
		if val.Username == user.Username {
			return nil
		}
	}

	// we silently ignore file patterns; work on a copy so that this does not
	// leak back into the caller's user
	retentionUser := *user
	retentionUser.Filters.FilePatterns = nil
	conn := NewBaseConnection("", "", "", "", retentionUser)
	conn.SetProtocol(ProtocolDataRetention)
	conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
	check.Username = user.Username
	check.Role = user.Role
	check.StartTime = util.GetTimeAsMsSinceEpoch(time.Now())
	check.conn = conn
	check.updateUserPermissions()
	c.Checks = append(c.Checks, check)

	return &check
}

// remove deletes the retention check for username, if any, and reports
// whether a check was removed. The last check is moved into the freed slot,
// so the order of the remaining checks is not preserved.
func (c *ActiveRetentionChecks) remove(username string) bool {
	c.Lock()
	defer c.Unlock()

	for idx, check := range c.Checks {
		if check.Username == username {
			lastIdx := len(c.Checks) - 1
			c.Checks[idx] = c.Checks[lastIdx]
			c.Checks = c.Checks[:lastIdx]
			return true
		}
	}

	return false
}

// folderRetentionCheckResult summarizes the retention check for a single folder.
type folderRetentionCheckResult struct {
	Path         string        `json:"path"`
	Retention    int           `json:"retention"`
	DeletedFiles int           `json:"deleted_files"`
	DeletedSize  int64         `json:"deleted_size"`
	Elapsed      time.Duration `json:"-"`
	Info         string        `json:"info,omitempty"`
	Error        string        `json:"error,omitempty"`
}

// RetentionCheck defines an active retention check
type RetentionCheck struct {
	// Username to which the retention check refers
	Username string `json:"username"`
	// retention check start time as unix timestamp in milliseconds
	StartTime int64 `json:"start_time"`
	// affected folders
	Folders []dataprovider.FolderRetention `json:"folders"`
	Role    string                         `json:"-"`
	// Cleanup results
	results []folderRetentionCheckResult `json:"-"`
	conn    *BaseConnection              `json:"-"`
}

// updateUserPermissions grants dataprovider.PermAny on every path that has
// permissions configured for the connection user.
// A new map is assigned instead of updating the existing one in place. The
// existing map may be shared with the dataprovider.User the connection was
// built from (copying a struct copies only the map reference), and mutating it
// would also widen that user's permissions.
func (c *RetentionCheck) updateUserPermissions() {
	permissions := make(map[string][]string, len(c.conn.User.Permissions))
	for dirPath := range c.conn.User.Permissions {
		permissions[dirPath] = []string{dataprovider.PermAny}
	}
	c.conn.User.Permissions = permissions
}

// getFolderRetention returns the retention configured for folderPath. The
// candidate directories from util.GetDirsForVirtualPath are tried in order, and
// the first one matching a configured folder wins. An error is returned if no
// configured folder matches.
func (c *RetentionCheck) getFolderRetention(folderPath string) (dataprovider.FolderRetention, error) {
	dirsForPath := util.GetDirsForVirtualPath(folderPath)
	for _, dirPath := range dirsForPath {
		for _, folder := range c.Folders {
			if folder.Path == dirPath {
				return folder, nil
			}
		}
	}

	return dataprovider.FolderRetention{}, fmt.Errorf("unable to find folder retention for %q", folderPath)
}

// removeFile resolves the filesystem serving virtualPath and removes the file
// through the check connection.
func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error {
	fs, fsPath, err := c.conn.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return err
	}
	return c.conn.RemoveFile(fs, fsPath, virtualPath, info)
}

// cleanupFolder applies the retention configured for folderPath to the
// entries inside it. Files whose modification time plus the retention (in
// hours) is in the past are removed. Subdirectories are processed recursively,
// and recursion stops with util.ErrRecursionTooDeep at util.MaxRecursion. A
// result for folderPath is appended to c.results on every return path.
func (c *RetentionCheck) cleanupFolder(folderPath string, recursion int) error {
	startTime := time.Now()
	result := folderRetentionCheckResult{
		Path: folderPath,
	}
	// the closure reads result when the function returns, so every update
	// made below is recorded
	defer func() {
		c.results = append(c.results, result)
	}()
	if recursion >= util.MaxRecursion {
		result.Elapsed = time.Since(startTime)
		result.Info = "data retention check skipped: recursion too deep"
		c.conn.Log(logger.LevelError, "data retention check skipped, recursion too deep for %q: %d", folderPath, recursion)
		return util.ErrRecursionTooDeep
	}
	recursion++

	folderRetention, err := c.getFolderRetention(folderPath)
	if err != nil {
		result.Elapsed = time.Since(startTime)
		result.Error = "unable to get folder retention"
		c.conn.Log(logger.LevelError, "unable to get folder retention for path %q", folderPath)
		return err
	}
	result.Retention = folderRetention.Retention
	if folderRetention.Retention == 0 {
		result.Elapsed = time.Since(startTime)
		result.Info = "data retention check skipped: retention is set to 0"
		c.conn.Log(logger.LevelDebug, "retention check skipped for folder %q, retention is set to 0", folderPath)
		return nil
	}
	c.conn.Log(logger.LevelDebug, "start retention check for folder %q, retention: %v hours, delete empty dirs? %v",
		folderPath, folderRetention.Retention, folderRetention.DeleteEmptyDirs)
	lister, err := c.conn.ListDir(folderPath)
	if err != nil {
		result.Elapsed = time.Since(startTime)
		// errors.Is also matches a wrapped not-exist error
		if errors.Is(err, c.conn.GetNotExistError()) {
			result.Info = "data retention check skipped, folder does not exist"
			c.conn.Log(logger.LevelDebug, "folder %q does not exist, retention check skipped", folderPath)
			return nil
		}
		result.Error = fmt.Sprintf("unable to get lister for directory %q", folderPath)
		c.conn.Log(logger.LevelError, "%s", result.Error)
		return err
	}

	// processEntries consumes the whole listing. The lister is closed exactly
	// once, when processEntries returns, so the directory is no longer open
	// when checkEmptyDirRemoval lists and possibly removes it.
	processEntries := func() error {
		defer lister.Close()

		for {
			files, err := lister.Next(vfs.ListerBatchSize)
			finished := errors.Is(err, io.EOF)
			if err := lister.convertError(err); err != nil {
				result.Elapsed = time.Since(startTime)
				result.Error = fmt.Sprintf("unable to list directory %q", folderPath)
				c.conn.Log(logger.LevelError, "unable to list dir %q: %v", folderPath, err)
				return err
			}
			for _, info := range files {
				virtualPath := path.Join(folderPath, info.Name())
				if info.IsDir() {
					if err := c.cleanupFolder(virtualPath, recursion); err != nil {
						result.Elapsed = time.Since(startTime)
						result.Error = fmt.Sprintf("unable to check folder: %v", err)
						c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err)
						return err
					}
					continue
				}
				retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour)
				if !retentionTime.Before(time.Now()) {
					continue
				}
				if err := c.removeFile(virtualPath, info); err != nil {
					result.Elapsed = time.Since(startTime)
					result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err)
					c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v",
						virtualPath, retentionTime, err)
					return err
				}
				c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v",
					virtualPath, info.ModTime(), folderRetention.Retention, retentionTime)
				result.DeletedFiles++
				result.DeletedSize += info.Size()
			}
			if finished {
				return nil
			}
		}
	}
	if err := processEntries(); err != nil {
		return err
	}

	c.checkEmptyDirRemoval(folderPath, folderRetention.DeleteEmptyDirs)
	result.Elapsed = time.Since(startTime)
	c.conn.Log(logger.LevelDebug, "retention check completed for folder %q, deleted files: %v, deleted size: %v bytes",
		folderPath, result.DeletedFiles, result.DeletedSize)

	return nil
}

// checkEmptyDirRemoval removes folderPath only if all of these hold:
//   - checkVal is true;
//   - folderPath is neither "/" nor one of the configured retention folders;
//   - the user has a delete permission on the parent directory;
//   - listing folderPath returns no entries.
//
// The removal outcome is only logged.
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string, checkVal bool) {
	if folderPath == "/" || !checkVal {
		return
	}
	for _, folder := range c.Folders {
		if folderPath == folder.Path {
			return
		}
	}
	if c.conn.User.HasAnyPerm([]string{
		dataprovider.PermDelete,
		dataprovider.PermDeleteDirs,
	}, path.Dir(folderPath),
	) {
		lister, err := c.conn.ListDir(folderPath)
		if err == nil {
			files, err := lister.Next(1)
			lister.Close()
			if len(files) == 0 && errors.Is(err, io.EOF) {
				err = c.conn.RemoveDir(folderPath)
				c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err)
			}
		}
	}
}

// Start starts the retention check. Every folder with a positive retention
// is processed in order, stopping at the first error. On return, the check is
// removed from RetentionChecks and the connection filesystems are closed.
func (c *RetentionCheck) Start() error {
	c.conn.Log(logger.LevelInfo, "retention check started")
	defer RetentionChecks.remove(c.conn.User.Username)
	defer c.conn.CloseFS() //nolint:errcheck

	startTime := time.Now()
	for _, folder := range c.Folders {
		if folder.Retention > 0 {
			if err := c.cleanupFolder(folder.Path, 0); err != nil {
				c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %q, elapsed: %s",
					folder.Path, time.Since(startTime))
				return err
			}
		}
	}

	c.conn.Log(logger.LevelInfo, "retention check completed, elapsed: %s", time.Since(startTime))
	return nil
}

================================================ FILE: internal/common/dataretention_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package common

import (
	"fmt"
	"testing"

	"github.com/sftpgo/sdk"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// TestRetentionPermissionsAndGetFolder checks that updateUserPermissions grants
// PermAny on every permission path of the user, and that getFolderRetention
// resolves a virtual path to the most specific configured retention folder.
func TestRetentionPermissionsAndGetFolder(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user1",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDelete}
	user.Permissions["/dir1"] = []string{dataprovider.PermListItems}
	user.Permissions["/dir2/sub1"] = []string{dataprovider.PermCreateDirs}
	user.Permissions["/dir2/sub2"] = []string{dataprovider.PermDelete}

	check := RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/dir2",
				Retention: 24 * 7,
			},
			{
				Path:      "/dir3",
				Retention: 24 * 7,
			},
			{
				Path:      "/dir2/sub1/sub",
				Retention: 24,
			},
		},
	}

	conn := NewBaseConnection("", "", "", "", user)
	conn.SetProtocol(ProtocolDataRetention)
	conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
	check.conn = conn
	check.updateUserPermissions()
	// every pre-existing permission path must now allow everything
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/"])
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir1"])
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2/sub1"])
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2/sub2"])

	// "/" is not covered by any configured retention folder
	_, err := check.getFolderRetention("/")
	assert.Error(t, err)
	folder, err := check.getFolderRetention("/dir3")
	assert.NoError(t, err)
	assert.Equal(t, "/dir3", folder.Path)
	// sub paths without their own retention fall back to the nearest configured parent
	folder, err = check.getFolderRetention("/dir2/sub3")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2", folder.Path)
	folder, err = check.getFolderRetention("/dir2/sub2")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2", folder.Path)
	folder, err = check.getFolderRetention("/dir2/sub1")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2", folder.Path)
	// the deepest configured folder wins over "/dir2"
	folder, err = check.getFolderRetention("/dir2/sub1/sub/sub")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2/sub1/sub", folder.Path)
}

// TestRetentionCheckAddRemove checks registering, listing and removing a
// retention check in the global RetentionChecks registry.
func TestRetentionCheckAddRemove(t *testing.T) {
	username := "username"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	check := RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/",
				Retention: 48,
			},
		},
	}
	assert.NotNil(t, RetentionChecks.Add(check, &user))
	checks := RetentionChecks.Get("")
	require.Len(t, checks, 1)
	assert.Equal(t, username, checks[0].Username)
	assert.Greater(t, checks[0].StartTime, int64(0))
	require.Len(t, checks[0].Folders, 1)
	assert.Equal(t, check.Folders[0].Path, checks[0].Folders[0].Path)
	assert.Equal(t, check.Folders[0].Retention, checks[0].Folders[0].Retention)

	// a second Add for the same user is rejected while the first one is registered
	assert.Nil(t, RetentionChecks.Add(check, &user))
	assert.True(t, RetentionChecks.remove(username))
	require.Len(t, RetentionChecks.Get(""), 0)
	// removing an unknown user reports false
	assert.False(t, RetentionChecks.remove(username))
}

// TestRetentionCheckRole checks that RetentionChecks.Get filters by role.
func TestRetentionCheckRole(t *testing.T) {
	username := "retuser"
	role1 := "retrole1"
	role2 := "retrole2"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Role:     role1,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	check := RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/",
				Retention: 48,
			},
		},
	}
	assert.NotNil(t, RetentionChecks.Add(check, &user))
	// with an empty role filter the returned entry is expected to have an empty Role
	checks := RetentionChecks.Get("")
	require.Len(t, checks, 1)
	assert.Empty(t, checks[0].Role)
	checks = RetentionChecks.Get(role1)
	require.Len(t, checks, 1)
	checks = RetentionChecks.Get(role2)
	require.Len(t, checks, 0)
	// the check is still registered for this username, so Add returns nil
	user.Role = ""
	assert.Nil(t, RetentionChecks.Add(check, &user))
	assert.True(t, RetentionChecks.remove(username))
	require.Len(t, RetentionChecks.Get(""), 0)
}

// TestCleanupErrors exercises the error paths of removeFile and cleanupFolder.
func TestCleanupErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "u",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	check := &RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/path",
				Retention: 48,
			},
		},
	}
	check = RetentionChecks.Add(*check, &user)
	require.NotNil(t, check)

	err := check.removeFile("missing file", nil)
	assert.Error(t, err)

	// presumably fails because "/" is not among the configured retention folders
	err = check.cleanupFolder("/", 0)
	assert.Error(t, err)

	// a recursion level this high must be rejected
	err = check.cleanupFolder("/", 1000)
	assert.ErrorIs(t, err, util.ErrRecursionTooDeep)

	assert.True(t, RetentionChecks.remove(user.Username))
}

================================================
FILE: internal/common/defender.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package common

import (
	"fmt"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
)

// HostEvent is the enumerable for the supported host events
type HostEvent string

// Supported host events
const (
	// a login attempt failed for an existing account (scored with ScoreValid)
	HostEventLoginFailed HostEvent = "LoginFailed"
	// a login attempt used a non-existent account (scored with ScoreInvalid)
	HostEventUserNotFound HostEvent = "UserNotFound"
	// the client disconnected without attempting to log in (scored with ScoreNoAuth)
	HostEventNoLoginTried HostEvent = "NoLoginTried"
	// a rate or connection limit was exceeded (scored with ScoreLimitExceeded)
	HostEventLimitExceeded HostEvent = "LimitExceeded"
)

// Supported defender drivers
const (
	DefenderDriverMemory   = "memory"
	DefenderDriverProvider = "provider"
)

var (
	// all the values accepted for DefenderConfig.Driver
	supportedDefenderDrivers = []string{DefenderDriverMemory, DefenderDriverProvider}
)

// Defender defines the interface that a defender must implement
type Defender interface {
	// GetHosts returns the hosts that are banned or have a non-zero score
	GetHosts() ([]dataprovider.DefenderEntry, error)
	// GetHost returns the entry for a single IP, or an error if it is not tracked
	GetHost(ip string) (dataprovider.DefenderEntry, error)
	// AddEvent records event for ip; the returned bool reports whether ip is
	// in the defender safe list (see dbDefender.AddEvent)
	AddEvent(ip, protocol string, event HostEvent) bool
	// IsBanned reports whether ip is currently banned for protocol
	IsBanned(ip, protocol string) bool
	// IsSafe reports whether ip is in the defender allow list for protocol
	IsSafe(ip, protocol string) bool
	// GetBanTime returns the ban expiration for ip, or nil if it is not banned
	GetBanTime(ip string) (*time.Time, error)
	// GetScore returns the current score for ip
	GetScore(ip string) (int, error)
	// DeleteHost removes ip from the defender lists, reporting success
	DeleteHost(ip string) bool
	// DelayLogin sleeps for the configured LoginDelay; a nil err selects the
	// success delay, a non-nil err the failure delay
	DelayLogin(err error)
}

// DefenderConfig defines the "defender" configuration
type DefenderConfig struct {
	// Set to true to enable the defender
	Enabled bool `json:"enabled" mapstructure:"enabled"`
	// Defender implementation to use, we support "memory" and "provider".
	// Using "provider" as driver you can share the defender events among
	// multiple SFTPGo instances. For a single instance "memory" provider will
	// be much faster
	Driver string `json:"driver" mapstructure:"driver"`
	// BanTime is the number of minutes that a host is banned
	BanTime int `json:"ban_time" mapstructure:"ban_time"`
	// Percentage increase of the ban time if a banned host tries to connect again
	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
	// Threshold value for banning a client. validate requires every
	// single-event score to be strictly lower than this value
	Threshold int `json:"threshold" mapstructure:"threshold"`
	// Score for invalid login attempts, e.g. non-existent user accounts
	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
	// Score for valid login attempts, e.g. user accounts that exist
	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
	// Score for limit exceeded events, generated from the rate limiters or for max connections
	// per-host exceeded
	ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
	// ScoreNoAuth defines the score for clients disconnected without authentication
	// attempts
	ScoreNoAuth int `json:"score_no_auth" mapstructure:"score_no_auth"`
	// Defines the time window, in minutes, for tracking client errors.
	// A host is banned if it has exceeded the defined threshold during
	// the last observation time minutes
	ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
	// The number of banned IPs and host scores kept in memory will vary between the
	// soft and hard limit for the "memory" driver. For the "provider" driver the
	// soft limit is ignored and the hard limit is used to limit the number of entries
	// to return when you request for the entire host list from the defender
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
	// Configuration to impose a delay between login attempts
	LoginDelay LoginDelay `json:"login_delay" mapstructure:"login_delay"`
}

// LoginDelay defines the delays to impose between login attempts.
type LoginDelay struct { // The number of milliseconds to pause prior to allowing a successful login Success int `json:"success" mapstructure:"success"` // The number of milliseconds to pause prior to reporting a failed login PasswordFailed int `json:"password_failed" mapstructure:"password_failed"` } type baseDefender struct { config *DefenderConfig ipList *dataprovider.IPList } func (d *baseDefender) isBanned(ip, protocol string) bool { isListed, mode, err := d.ipList.IsListed(ip, protocol) if err != nil { return false } if isListed && mode == dataprovider.ListModeDeny { return true } return false } func (d *baseDefender) IsSafe(ip, protocol string) bool { isListed, mode, err := d.ipList.IsListed(ip, protocol) if err == nil && isListed && mode == dataprovider.ListModeAllow { return true } return false } func (d *baseDefender) getScore(event HostEvent) int { var score int switch event { case HostEventLoginFailed: score = d.config.ScoreValid case HostEventLimitExceeded: score = d.config.ScoreLimitExceeded case HostEventUserNotFound: score = d.config.ScoreInvalid case HostEventNoLoginTried: score = d.config.ScoreNoAuth } return score } // logEvent logs a defender event that changes a host's score func (d *baseDefender) logEvent(ip, protocol string, event HostEvent, totalScore int) { // ignore events which do not change the host score eventScore := d.getScore(event) if eventScore == 0 { return } logger.GetLogger().Debug(). Timestamp(). Str("sender", "defender"). Str("client_ip", ip). Str("protocol", protocol). Str("event", string(event)). Int("increase_score_by", eventScore). Int("score", totalScore). Send() } // logBan logs a host's ban due to a too high host score func (d *baseDefender) logBan(ip, protocol string) { logger.GetLogger().Info(). Timestamp(). Str("sender", "defender"). Str("client_ip", ip). Str("protocol", protocol). Str("event", "banned"). Send() } // DelayLogin applies the configured login delay. 
func (d *baseDefender) DelayLogin(err error) { if err == nil { if d.config.LoginDelay.Success > 0 { time.Sleep(time.Duration(d.config.LoginDelay.Success) * time.Millisecond) } return } if d.config.LoginDelay.PasswordFailed > 0 { time.Sleep(time.Duration(d.config.LoginDelay.PasswordFailed) * time.Millisecond) } } type hostEvent struct { dateTime time.Time score int } type hostScore struct { TotalScore int Events []hostEvent } func (c *DefenderConfig) checkScores() error { if c.ScoreInvalid < 0 { c.ScoreInvalid = 0 } if c.ScoreValid < 0 { c.ScoreValid = 0 } if c.ScoreLimitExceeded < 0 { c.ScoreLimitExceeded = 0 } if c.ScoreNoAuth < 0 { c.ScoreNoAuth = 0 } if c.ScoreInvalid == 0 && c.ScoreValid == 0 && c.ScoreLimitExceeded == 0 && c.ScoreNoAuth == 0 { return fmt.Errorf("invalid defender configuration: all scores are disabled") } return nil } // validate returns an error if the configuration is invalid func (c *DefenderConfig) validate() error { if !c.Enabled { return nil } if err := c.checkScores(); err != nil { return err } if c.ScoreInvalid >= c.Threshold { return fmt.Errorf("score_invalid %d cannot be greater than threshold %d", c.ScoreInvalid, c.Threshold) } if c.ScoreValid >= c.Threshold { return fmt.Errorf("score_valid %d cannot be greater than threshold %d", c.ScoreValid, c.Threshold) } if c.ScoreLimitExceeded >= c.Threshold { return fmt.Errorf("score_limit_exceeded %d cannot be greater than threshold %d", c.ScoreLimitExceeded, c.Threshold) } if c.ScoreNoAuth >= c.Threshold { return fmt.Errorf("score_no_auth %d cannot be greater than threshold %d", c.ScoreNoAuth, c.Threshold) } if c.BanTime <= 0 { return fmt.Errorf("invalid ban_time %v", c.BanTime) } if c.BanTimeIncrement <= 0 { return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement) } if c.ObservationTime <= 0 { return fmt.Errorf("invalid observation_time %v", c.ObservationTime) } if c.EntriesSoftLimit <= 0 { return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit) } if 
c.EntriesHardLimit <= c.EntriesSoftLimit { return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit) } return nil } ================================================ FILE: internal/common/defender_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "encoding/hex" "fmt" "net" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yl2chen/cidranger" "github.com/drakkan/sftpgo/v2/internal/dataprovider" ) func TestBasicDefender(t *testing.T) { entries := []dataprovider.IPListEntry{ { IPOrNet: "172.16.1.1/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, }, { IPOrNet: "172.16.1.2/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, }, { IPOrNet: "10.8.0.0/24", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, }, { IPOrNet: "192.168.1.1/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, }, { IPOrNet: "192.168.1.2/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, }, { IPOrNet: "10.8.9.0/24", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, }, { IPOrNet: "172.16.1.3/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeAllow, }, { IPOrNet: "172.16.1.4/32", Type: 
dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeAllow, }, { IPOrNet: "192.168.8.0/24", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeAllow, }, { IPOrNet: "192.168.1.3/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeAllow, }, { IPOrNet: "192.168.1.4/32", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeAllow, }, { IPOrNet: "192.168.9.0/24", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeAllow, }, } for idx := range entries { e := entries[idx] err := dataprovider.AddIPListEntry(&e, "", "", "") assert.NoError(t, err) } config := &DefenderConfig{ Enabled: true, BanTime: 10, BanTimeIncrement: 2, Threshold: 5, ScoreInvalid: 2, ScoreValid: 1, ScoreNoAuth: 2, ScoreLimitExceeded: 3, ObservationTime: 15, EntriesSoftLimit: 1, EntriesHardLimit: 2, } d, err := newInMemoryDefender(config) assert.NoError(t, err) defender := d.(*memoryDefender) assert.True(t, defender.IsBanned("172.16.1.1", ProtocolSSH)) assert.True(t, defender.IsBanned("192.168.1.1", ProtocolFTP)) assert.False(t, defender.IsBanned("172.16.1.10", ProtocolSSH)) assert.False(t, defender.IsBanned("192.168.1.10", ProtocolSSH)) assert.False(t, defender.IsBanned("10.8.2.3", ProtocolSSH)) assert.False(t, defender.IsBanned("10.9.2.3", ProtocolSSH)) assert.True(t, defender.IsBanned("10.8.0.3", ProtocolSSH)) assert.True(t, defender.IsBanned("10.8.9.3", ProtocolSSH)) assert.False(t, defender.IsBanned("invalid ip", ProtocolSSH)) assert.Equal(t, 0, defender.countBanned()) assert.Equal(t, 0, defender.countHosts()) hosts, err := defender.GetHosts() assert.NoError(t, err) assert.Len(t, hosts, 0) _, err = defender.GetHost("10.8.0.4") assert.Error(t, err) defender.AddEvent("172.16.1.4", ProtocolSSH, HostEventLoginFailed) defender.AddEvent("192.168.1.4", ProtocolSSH, HostEventLoginFailed) defender.AddEvent("192.168.8.4", ProtocolSSH, HostEventUserNotFound) defender.AddEvent("172.16.1.3", ProtocolSSH, HostEventLimitExceeded) 
defender.AddEvent("192.168.1.3", ProtocolSSH, HostEventLimitExceeded) assert.Equal(t, 0, defender.countHosts()) testIP := "12.34.56.78" defender.AddEvent(testIP, ProtocolSSH, HostEventLoginFailed) assert.Equal(t, 1, defender.countHosts()) assert.Equal(t, 0, defender.countBanned()) score, err := defender.GetScore(testIP) assert.NoError(t, err) assert.Equal(t, 1, score) hosts, err = defender.GetHosts() assert.NoError(t, err) if assert.Len(t, hosts, 1) { assert.Equal(t, 1, hosts[0].Score) assert.True(t, hosts[0].BanTime.IsZero()) assert.Empty(t, hosts[0].GetBanTime()) } host, err := defender.GetHost(testIP) assert.NoError(t, err) assert.Equal(t, 1, host.Score) assert.Empty(t, host.GetBanTime()) banTime, err := defender.GetBanTime(testIP) assert.NoError(t, err) assert.Nil(t, banTime) defender.AddEvent(testIP, ProtocolSSH, HostEventLimitExceeded) assert.Equal(t, 1, defender.countHosts()) assert.Equal(t, 0, defender.countBanned()) score, err = defender.GetScore(testIP) assert.NoError(t, err) assert.Equal(t, 4, score) hosts, err = defender.GetHosts() assert.NoError(t, err) if assert.Len(t, hosts, 1) { assert.Equal(t, 4, hosts[0].Score) assert.True(t, hosts[0].BanTime.IsZero()) assert.Empty(t, hosts[0].GetBanTime()) } defender.AddEvent(testIP, ProtocolSSH, HostEventUserNotFound) defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried) assert.Equal(t, 0, defender.countHosts()) assert.Equal(t, 1, defender.countBanned()) score, err = defender.GetScore(testIP) assert.NoError(t, err) assert.Equal(t, 0, score) banTime, err = defender.GetBanTime(testIP) assert.NoError(t, err) assert.NotNil(t, banTime) hosts, err = defender.GetHosts() assert.NoError(t, err) if assert.Len(t, hosts, 1) { assert.Equal(t, 0, hosts[0].Score) assert.False(t, hosts[0].BanTime.IsZero()) assert.NotEmpty(t, hosts[0].GetBanTime()) assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID()) } host, err = defender.GetHost(testIP) assert.NoError(t, err) assert.Equal(t, 0, host.Score) 
assert.NotEmpty(t, host.GetBanTime()) // now test cleanup, testIP is already banned testIP1 := "12.34.56.79" testIP2 := "12.34.56.80" testIP3 := "12.34.56.81" defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried) defender.AddEvent(testIP2, ProtocolSSH, HostEventNoLoginTried) assert.Equal(t, 2, defender.countHosts()) time.Sleep(20 * time.Millisecond) defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts()) // testIP1 and testIP2 should be removed assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts()) score, err = defender.GetScore(testIP1) assert.NoError(t, err) assert.Equal(t, 0, score) score, err = defender.GetScore(testIP2) assert.NoError(t, err) assert.Equal(t, 0, score) score, err = defender.GetScore(testIP3) assert.NoError(t, err) assert.Equal(t, 2, score) defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) // IP3 is now banned banTime, err = defender.GetBanTime(testIP3) assert.NoError(t, err) assert.NotNil(t, banTime) assert.Equal(t, 0, defender.countHosts()) time.Sleep(20 * time.Millisecond) for i := 0; i < 3; i++ { defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried) } assert.Equal(t, 0, defender.countHosts()) assert.Equal(t, config.EntriesSoftLimit, defender.countBanned()) banTime, err = defender.GetBanTime(testIP) assert.NoError(t, err) assert.Nil(t, banTime) banTime, err = defender.GetBanTime(testIP3) assert.NoError(t, err) assert.Nil(t, banTime) banTime, err = defender.GetBanTime(testIP1) assert.NoError(t, err) assert.NotNil(t, banTime) for i := 0; i < 3; i++ { defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried) time.Sleep(10 * time.Millisecond) defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) } assert.Equal(t, 0, defender.countHosts()) assert.Equal(t, defender.config.EntriesSoftLimit, defender.countBanned()) banTime, err = 
defender.GetBanTime(testIP3) assert.NoError(t, err) if assert.NotNil(t, banTime) { assert.True(t, defender.IsBanned(testIP3, ProtocolFTP)) // ban time should increase newBanTime, err := defender.GetBanTime(testIP3) assert.NoError(t, err) assert.True(t, newBanTime.After(*banTime)) } assert.True(t, defender.DeleteHost(testIP3)) assert.False(t, defender.DeleteHost(testIP3)) for _, e := range entries { err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "") assert.NoError(t, err) } } func TestExpiredHostBans(t *testing.T) { config := &DefenderConfig{ Enabled: true, BanTime: 10, BanTimeIncrement: 2, Threshold: 5, ScoreInvalid: 2, ScoreValid: 1, ScoreLimitExceeded: 3, ObservationTime: 15, EntriesSoftLimit: 1, EntriesHardLimit: 2, } d, err := newInMemoryDefender(config) assert.NoError(t, err) defender := d.(*memoryDefender) testIP := "1.2.3.4" defender.banned[testIP] = time.Now().Add(-24 * time.Hour) // the ban is expired testIP should not be listed res, err := defender.GetHosts() assert.NoError(t, err) assert.Len(t, res, 0) assert.False(t, defender.IsBanned(testIP, ProtocolFTP)) _, err = defender.GetHost(testIP) assert.Error(t, err) _, ok := defender.banned[testIP] assert.True(t, ok) // now add an event for an expired banned ip, it should be removed defender.AddEvent(testIP, ProtocolFTP, HostEventLoginFailed) assert.False(t, defender.IsBanned(testIP, ProtocolFTP)) entry, err := defender.GetHost(testIP) assert.NoError(t, err) assert.Equal(t, testIP, entry.IP) assert.Empty(t, entry.GetBanTime()) assert.Equal(t, 1, entry.Score) res, err = defender.GetHosts() assert.NoError(t, err) if assert.Len(t, res, 1) { assert.Equal(t, testIP, res[0].IP) assert.Empty(t, res[0].GetBanTime()) assert.Equal(t, 1, res[0].Score) } events := []hostEvent{ { dateTime: time.Now().Add(-24 * time.Hour), score: 2, }, { dateTime: time.Now().Add(-24 * time.Hour), score: 3, }, } hs := hostScore{ Events: events, TotalScore: 5, } defender.hosts[testIP] = hs // the recorded scored are too old 
res, err = defender.GetHosts() assert.NoError(t, err) assert.Len(t, res, 0) _, err = defender.GetHost(testIP) assert.Error(t, err) _, ok = defender.hosts[testIP] assert.True(t, ok) } func TestDefenderCleanup(t *testing.T) { d := memoryDefender{ baseDefender: baseDefender{ config: &DefenderConfig{ ObservationTime: 1, EntriesSoftLimit: 2, EntriesHardLimit: 3, }, }, banned: make(map[string]time.Time), hosts: make(map[string]hostScore), } d.banned["1.1.1.1"] = time.Now().Add(-24 * time.Hour) d.banned["1.1.1.2"] = time.Now().Add(-24 * time.Hour) d.banned["1.1.1.3"] = time.Now().Add(-24 * time.Hour) d.banned["1.1.1.4"] = time.Now().Add(-24 * time.Hour) d.cleanupBanned() assert.Equal(t, 0, d.countBanned()) d.banned["2.2.2.2"] = time.Now().Add(2 * time.Minute) d.banned["2.2.2.3"] = time.Now().Add(1 * time.Minute) d.banned["2.2.2.4"] = time.Now().Add(3 * time.Minute) d.banned["2.2.2.5"] = time.Now().Add(4 * time.Minute) d.cleanupBanned() assert.Equal(t, d.config.EntriesSoftLimit, d.countBanned()) banTime, err := d.GetBanTime("2.2.2.3") assert.NoError(t, err) assert.Nil(t, banTime) d.hosts["3.3.3.3"] = hostScore{ TotalScore: 0, Events: []hostEvent{ { dateTime: time.Now().Add(-5 * time.Minute), score: 1, }, { dateTime: time.Now().Add(-3 * time.Minute), score: 1, }, { dateTime: time.Now(), score: 1, }, }, } d.hosts["3.3.3.4"] = hostScore{ TotalScore: 1, Events: []hostEvent{ { dateTime: time.Now().Add(-3 * time.Minute), score: 1, }, }, } d.hosts["3.3.3.5"] = hostScore{ TotalScore: 1, Events: []hostEvent{ { dateTime: time.Now().Add(-2 * time.Minute), score: 1, }, }, } d.hosts["3.3.3.6"] = hostScore{ TotalScore: 1, Events: []hostEvent{ { dateTime: time.Now().Add(-1 * time.Minute), score: 1, }, }, } score, err := d.GetScore("3.3.3.3") assert.NoError(t, err) assert.Equal(t, 1, score) d.cleanupHosts() assert.Equal(t, d.config.EntriesSoftLimit, d.countHosts()) score, err = d.GetScore("3.3.3.4") assert.NoError(t, err) assert.Equal(t, 0, score) } func TestDefenderDelay(t *testing.T) { 
d := memoryDefender{ baseDefender: baseDefender{ config: &DefenderConfig{ ObservationTime: 1, EntriesSoftLimit: 2, EntriesHardLimit: 3, LoginDelay: LoginDelay{ Success: 50, PasswordFailed: 200, }, }, }, } startTime := time.Now() d.DelayLogin(nil) elapsed := time.Since(startTime) assert.Less(t, elapsed, time.Millisecond*100) startTime = time.Now() d.DelayLogin(ErrInternalFailure) elapsed = time.Since(startTime) assert.Greater(t, elapsed, time.Millisecond*150) } func TestDefenderConfig(t *testing.T) { c := DefenderConfig{} err := c.validate() require.NoError(t, err) c.Enabled = true c.Threshold = 10 c.ScoreInvalid = 10 err = c.validate() require.Error(t, err) c.ScoreInvalid = 2 c.ScoreLimitExceeded = 10 err = c.validate() require.Error(t, err) c.ScoreLimitExceeded = 2 c.ScoreValid = 10 err = c.validate() require.Error(t, err) c.ScoreValid = 1 c.ScoreNoAuth = 10 err = c.validate() require.Error(t, err) c.ScoreNoAuth = 2 c.BanTime = 0 err = c.validate() require.Error(t, err) c.BanTime = 30 c.BanTimeIncrement = 0 err = c.validate() require.Error(t, err) c.BanTimeIncrement = 50 c.ObservationTime = 0 err = c.validate() require.Error(t, err) c.ObservationTime = 30 err = c.validate() require.Error(t, err) c.EntriesSoftLimit = 10 err = c.validate() require.Error(t, err) c.EntriesHardLimit = 10 err = c.validate() require.Error(t, err) c.EntriesHardLimit = 20 err = c.validate() require.NoError(t, err) c = DefenderConfig{ Enabled: true, ScoreInvalid: -1, ScoreLimitExceeded: -1, ScoreNoAuth: -1, ScoreValid: -1, } err = c.validate() require.Error(t, err) assert.Equal(t, 0, c.ScoreInvalid) assert.Equal(t, 0, c.ScoreValid) assert.Equal(t, 0, c.ScoreLimitExceeded) assert.Equal(t, 0, c.ScoreNoAuth) } func BenchmarkDefenderBannedSearch(b *testing.B) { d := getDefenderForBench() ip, ipnet, err := net.ParseCIDR("10.8.0.0/12") // 1048574 ip addresses if err != nil { panic(err) } for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { d.banned[ip.String()] = time.Now().Add(10 * 
time.Minute) } b.ResetTimer() for i := 0; i < b.N; i++ { d.IsBanned("192.168.1.1", ProtocolSSH) } } func BenchmarkCleanup(b *testing.B) { d := getDefenderForBench() ip, ipnet, err := net.ParseCIDR("192.168.4.0/24") if err != nil { panic(err) } b.ResetTimer() for i := 0; i < b.N; i++ { for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { d.AddEvent(ip.String(), ProtocolSSH, HostEventLoginFailed) if d.countHosts() > d.config.EntriesHardLimit { panic("too many hosts") } if d.countBanned() > d.config.EntriesSoftLimit { panic("too many ip banned") } } } } func BenchmarkCIDRanger(b *testing.B) { ranger := cidranger.NewPCTrieRanger() for i := 0; i < 255; i++ { cidr := fmt.Sprintf("192.168.%d.1/24", i) _, network, _ := net.ParseCIDR(cidr) if err := ranger.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil { panic(err) } } ipToMatch := net.ParseIP("192.167.1.2") b.ResetTimer() for i := 0; i < b.N; i++ { if _, err := ranger.Contains(ipToMatch); err != nil { panic(err) } } } func BenchmarkNetContains(b *testing.B) { var nets []*net.IPNet for i := 0; i < 255; i++ { cidr := fmt.Sprintf("192.168.%d.1/24", i) _, network, _ := net.ParseCIDR(cidr) nets = append(nets, network) } ipToMatch := net.ParseIP("192.167.1.1") b.ResetTimer() for i := 0; i < b.N; i++ { for _, n := range nets { n.Contains(ipToMatch) } } } func getDefenderForBench() *memoryDefender { config := &DefenderConfig{ Enabled: true, BanTime: 30, BanTimeIncrement: 50, Threshold: 10, ScoreInvalid: 2, ScoreValid: 2, ObservationTime: 30, EntriesSoftLimit: 50, EntriesHardLimit: 100, } return &memoryDefender{ baseDefender: baseDefender{ config: config, }, hosts: make(map[string]hostScore), banned: make(map[string]time.Time), } } func inc(ip net.IP) { for j := len(ip) - 1; j >= 0; j-- { ip[j]++ if ip[j] > 0 { break } } } ================================================ FILE: internal/common/defenderdb.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is 
free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "sync/atomic" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) type dbDefender struct { baseDefender lastCleanup atomic.Int64 } func newDBDefender(config *DefenderConfig) (Defender, error) { err := config.validate() if err != nil { return nil, err } ipList, err := dataprovider.NewIPList(dataprovider.IPListTypeDefender) if err != nil { return nil, err } defender := &dbDefender{ baseDefender: baseDefender{ config: config, ipList: ipList, }, } defender.lastCleanup.Store(0) return defender, nil } // GetHosts returns hosts that are banned or for which some violations have been detected func (d *dbDefender) GetHosts() ([]dataprovider.DefenderEntry, error) { return dataprovider.GetDefenderHosts(d.getStartObservationTime(), d.config.EntriesHardLimit) } // GetHost returns a defender host by ip, if any func (d *dbDefender) GetHost(ip string) (dataprovider.DefenderEntry, error) { return dataprovider.GetDefenderHostByIP(ip, d.getStartObservationTime()) } // IsBanned returns true if the specified IP is banned // and increase ban time if the IP is found. 
// This method must be called as soon as the client connects func (d *dbDefender) IsBanned(ip, protocol string) bool { if d.isBanned(ip, protocol) { return true } _, err := dataprovider.IsDefenderHostBanned(ip) if err != nil { // not found or another error, we allow this host return false } increment := d.config.BanTime * d.config.BanTimeIncrement / 100 if increment == 0 { increment++ } dataprovider.UpdateDefenderBanTime(ip, increment) //nolint:errcheck return true } // DeleteHost removes the specified IP from the defender lists func (d *dbDefender) DeleteHost(ip string) bool { if _, err := d.GetHost(ip); err != nil { return false } return dataprovider.DeleteDefenderHost(ip) == nil } // AddEvent adds an event for the given IP. // This method must be called for clients not yet banned. // Returns true if the IP is in the defender's safe list. func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) bool { if d.IsSafe(ip, protocol) { return true } score := d.getScore(event) host, err := dataprovider.AddDefenderEvent(ip, score, d.getStartObservationTime()) if err != nil { return false } d.logEvent(ip, protocol, event, host.Score) if host.Score > d.config.Threshold { d.logBan(ip, protocol) banTime := time.Now().Add(time.Duration(d.config.BanTime) * time.Minute) err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(banTime)) if err == nil { eventManager.handleIPBlockedEvent(EventParams{ Event: ipBlockedEventName, IP: ip, Timestamp: time.Now(), Status: 1, }) } } if err == nil { d.cleanup() } return false } // GetBanTime returns the ban time for the given IP or nil if the IP is not banned func (d *dbDefender) GetBanTime(ip string) (*time.Time, error) { host, err := d.GetHost(ip) if err != nil { return nil, err } if host.BanTime.IsZero() { return nil, nil } return &host.BanTime, nil } // GetScore returns the score for the given IP func (d *dbDefender) GetScore(ip string) (int, error) { host, err := d.GetHost(ip) if err != nil { return 0, err } 
return host.Score, nil } func (d *dbDefender) cleanup() { lastCleanup := d.getLastCleanup() if lastCleanup.IsZero() || lastCleanup.Add(time.Duration(d.config.ObservationTime)*time.Minute*3).Before(time.Now()) { // FIXME: this could be racy in rare cases but it is better than acquire the lock for the cleanup duration // or to always acquire a read/write lock. // Concurrent cleanups could happen anyway from multiple SFTPGo instances and should not cause any issues d.setLastCleanup(time.Now()) expireTime := time.Now().Add(-time.Duration(d.config.ObservationTime+1) * time.Minute) logger.Debug(logSender, "", "cleanup defender hosts before %v, last cleanup %v", expireTime, lastCleanup) if err := dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(expireTime)); err != nil { logger.Error(logSender, "", "defender cleanup error, reset last cleanup to %v", lastCleanup) d.setLastCleanup(lastCleanup) } } } func (d *dbDefender) getStartObservationTime() int64 { t := time.Now().Add(-time.Duration(d.config.ObservationTime) * time.Minute) return util.GetTimeAsMsSinceEpoch(t) } func (d *dbDefender) getLastCleanup() time.Time { val := d.lastCleanup.Load() if val == 0 { return time.Time{} } return util.GetTimeFromMsecSinceEpoch(val) } func (d *dbDefender) setLastCleanup(when time.Time) { if when.IsZero() { d.lastCleanup.Store(0) return } d.lastCleanup.Store(util.GetTimeAsMsSinceEpoch(when)) } ================================================ FILE: internal/common/defenderdb_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"encoding/hex"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// TestBasicDbDefender exercises the database backed defender end to end:
// IP list deny/allow entries, score accumulation, banning, ban time increase,
// host deletion and the data provider cleanup of expired entries.
func TestBasicDbDefender(t *testing.T) {
	if !isDbDefenderSupported() {
		t.Skip("this test is not supported with the current database provider")
	}
	// deny entries are always banned, allow entries are safe and never scored
	entries := []dataprovider.IPListEntry{
		{
			IPOrNet: "172.16.1.1/32",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeDeny,
		},
		{
			IPOrNet: "172.16.1.2/32",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeDeny,
		},
		{
			IPOrNet: "10.8.0.0/24",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeDeny,
		},
		{
			IPOrNet: "172.16.1.3/32",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeAllow,
		},
		{
			IPOrNet: "172.16.1.4/32",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeAllow,
		},
		{
			IPOrNet: "192.168.8.0/24",
			Type:    dataprovider.IPListTypeDefender,
			Mode:    dataprovider.ListModeAllow,
		},
	}

	for idx := range entries {
		e := entries[idx]
		err := dataprovider.AddIPListEntry(&e, "", "", "")
		assert.NoError(t, err)
	}

	config := &DefenderConfig{
		Enabled:            true,
		BanTime:            10,
		BanTimeIncrement:   2,
		Threshold:          5,
		ScoreInvalid:       2,
		ScoreValid:         1,
		ScoreNoAuth:        2,
		ScoreLimitExceeded: 3,
		ObservationTime:    15,
		EntriesSoftLimit:   1,
		EntriesHardLimit:   10,
	}
	d, err := newDBDefender(config)
	assert.NoError(t, err)
	defender := d.(*dbDefender)
	assert.True(t, defender.IsBanned("172.16.1.1", ProtocolFTP))
	assert.False(t, defender.IsBanned("172.16.1.10", ProtocolSSH))
	assert.False(t, defender.IsBanned("10.8.1.3", ProtocolHTTP))
	assert.True(t, defender.IsBanned("10.8.0.4", ProtocolWebDAV))
	assert.False(t, defender.IsBanned("invalid ip", ProtocolSSH))
	hosts, err := defender.GetHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 0)
	_, err = defender.GetHost("10.8.0.3")
	assert.Error(t, err)

	// events for safe-listed hosts are ignored and must not trigger a cleanup
	defender.AddEvent("172.16.1.4", ProtocolSSH, HostEventLoginFailed)
	defender.AddEvent("192.168.8.4", ProtocolSSH, HostEventUserNotFound)
	defender.AddEvent("172.16.1.3", ProtocolSSH, HostEventLimitExceeded)
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 0)
	assert.True(t, defender.getLastCleanup().IsZero())

	testIP := "123.45.67.89"
	// the first recorded event also runs the first cleanup
	defender.AddEvent(testIP, ProtocolSSH, HostEventLoginFailed)
	lastCleanup := defender.getLastCleanup()
	assert.False(t, lastCleanup.IsZero())
	score, err := defender.GetScore(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 1, score)
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		assert.Equal(t, 1, hosts[0].Score)
		assert.True(t, hosts[0].BanTime.IsZero())
		assert.Empty(t, hosts[0].GetBanTime())
	}
	host, err := defender.GetHost(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 1, host.Score)
	assert.Empty(t, host.GetBanTime())
	banTime, err := defender.GetBanTime(testIP)
	assert.NoError(t, err)
	assert.Nil(t, banTime)
	defender.AddEvent(testIP, ProtocolSSH, HostEventLimitExceeded)
	score, err = defender.GetScore(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 4, score)
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		assert.Equal(t, 4, hosts[0].Score)
		assert.True(t, hosts[0].BanTime.IsZero())
		assert.Empty(t, hosts[0].GetBanTime())
	}
	// these events push the score over the threshold: the host gets banned
	// and, as asserted below, its reported score is 0 afterwards
	defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried)
	defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried)
	score, err = defender.GetScore(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 0, score)
	banTime, err = defender.GetBanTime(testIP)
	assert.NoError(t, err)
	assert.NotNil(t, banTime)
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		assert.Equal(t, 0, hosts[0].Score)
		assert.False(t, hosts[0].BanTime.IsZero())
		assert.NotEmpty(t, hosts[0].GetBanTime())
		assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
	}
	host, err = defender.GetHost(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 0, host.Score)
	assert.NotEmpty(t, host.GetBanTime())
	// ban time should increase
	assert.True(t, defender.IsBanned(testIP, ProtocolSSH))
	newBanTime, err := defender.GetBanTime(testIP)
	assert.NoError(t, err)
	assert.True(t, newBanTime.After(*banTime))

	assert.True(t, defender.DeleteHost(testIP))
	assert.False(t, defender.DeleteHost(testIP))
	// test cleanup
	testIP1 := "123.45.67.90"
	testIP2 := "123.45.67.91"
	testIP3 := "123.45.67.92"
	// three events for each host are enough to ban all of them
	for i := 0; i < 3; i++ {
		defender.AddEvent(testIP, ProtocolSSH, HostEventUserNotFound)
		defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried)
		defender.AddEvent(testIP2, ProtocolSSH, HostEventUserNotFound)
	}
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 3)
	for _, host := range hosts {
		assert.Equal(t, 0, host.Score)
		assert.False(t, host.BanTime.IsZero())
		assert.NotEmpty(t, host.GetBanTime())
	}
	defender.AddEvent(testIP3, ProtocolSSH, HostEventLoginFailed)
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 4)
	// now set a ban time in the past, so the host will be cleaned up
	for _, ip := range []string{testIP1, testIP2} {
		err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
		assert.NoError(t, err)
	}
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 4)
	for _, host := range hosts {
		switch host.IP {
		case testIP:
			assert.Equal(t, 0, host.Score)
			assert.False(t, host.BanTime.IsZero())
			assert.NotEmpty(t, host.GetBanTime())
		case testIP3:
			assert.Equal(t, 1, host.Score)
			assert.True(t, host.BanTime.IsZero())
			assert.Empty(t, host.GetBanTime())
		default:
			// hosts with an expired ban are reported again with their score
			assert.Equal(t, 6, host.Score)
			assert.True(t, host.BanTime.IsZero())
			assert.Empty(t, host.GetBanTime())
		}
	}
	host, err = defender.GetHost(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 0, host.Score)
	assert.False(t, host.BanTime.IsZero())
	assert.NotEmpty(t, host.GetBanTime())
	host, err = defender.GetHost(testIP3)
	assert.NoError(t, err)
	assert.Equal(t, 1, host.Score)
	assert.True(t, host.BanTime.IsZero())
	assert.Empty(t, host.GetBanTime())
	// set a negative observation time so the from field in the queries will be in the future
	// we still should get the banned hosts
	defender.config.ObservationTime = -2
	assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		assert.Equal(t, testIP, hosts[0].IP)
		assert.Equal(t, 0, hosts[0].Score)
		assert.False(t, hosts[0].BanTime.IsZero())
		assert.NotEmpty(t, hosts[0].GetBanTime())
	}
	_, err = defender.GetHost(testIP)
	assert.NoError(t, err)
	// cleanup db
	err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
	assert.NoError(t, err)
	// the banned host must still be there
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		assert.Equal(t, testIP, hosts[0].IP)
		assert.Equal(t, 0, hosts[0].Score)
		assert.False(t, hosts[0].BanTime.IsZero())
		assert.NotEmpty(t, hosts[0].GetBanTime())
	}
	_, err = defender.GetHost(testIP)
	assert.NoError(t, err)
	// once its ban is expired the last host is removed by the cleanup too
	err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
	assert.NoError(t, err)
	err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
	assert.NoError(t, err)
	hosts, err = defender.GetHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 0)

	for _, e := range entries {
		err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "")
		assert.NoError(t, err)
	}
}

// TestDbDefenderCleanup checks the cleanup rate limiting based on the last
// cleanup timestamp and that the timestamp is restored if the cleanup fails.
func TestDbDefenderCleanup(t *testing.T) {
	if !isDbDefenderSupported() {
		t.Skip("this test is not supported with the current database provider")
	}
	config := &DefenderConfig{
		Enabled:            true,
		BanTime:            10,
		BanTimeIncrement:   2,
		Threshold:          5,
		ScoreInvalid:       2,
		ScoreValid:         1,
		ScoreLimitExceeded: 3,
		ObservationTime:    15,
		EntriesSoftLimit:   1,
		EntriesHardLimit:   10,
	}
	d, err := newDBDefender(config)
	assert.NoError(t, err)
	defender := d.(*dbDefender)
	lastCleanup := defender.getLastCleanup()
	assert.True(t, lastCleanup.IsZero())
	defender.cleanup()
	lastCleanup = defender.getLastCleanup()
	assert.False(t, lastCleanup.IsZero())
	// a second cleanup right after the first one is skipped
	defender.cleanup()
	assert.Equal(t, lastCleanup, defender.getLastCleanup())
	defender.setLastCleanup(time.Time{})
	assert.True(t, defender.getLastCleanup().IsZero())
	// a last cleanup older than 3 observation periods allows a new cleanup
	defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
	time.Sleep(20 * time.Millisecond)
	defender.cleanup()
	assert.True(t, lastCleanup.Before(defender.getLastCleanup()))

	// close the provider so the next cleanup fails
	providerConf := dataprovider.GetProviderConfig()
	err = dataprovider.Close()
	assert.NoError(t, err)

	lastCleanup = util.GetTimeFromMsecSinceEpoch(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4).UnixMilli())
	defender.setLastCleanup(lastCleanup)
	defender.cleanup()
	// cleanup will fail and so last cleanup should be reset to the previous value
	assert.Equal(t, lastCleanup, defender.getLastCleanup())

	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// isDbDefenderSupported reports whether the configured data provider driver
// can back the database defender.
func isDbDefenderSupported() bool {
	// SQLite shares the implementation with other SQL-based providers but it makes no sense
	// to use it outside test cases
	switch dataprovider.GetProviderStatus().Driver {
	case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
		dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
		return true
	default:
		return false
	}
}

================================================
FILE: internal/common/defendermem.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "sort" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/util" ) type memoryDefender struct { baseDefender sync.RWMutex // IP addresses of the clients trying to connected are stored inside hosts, // they are added to banned once the thresold is reached. // A violation from a banned host will increase the ban time // based on the configured BanTimeIncrement hosts map[string]hostScore // the key is the host IP banned map[string]time.Time // the key is the host IP } func newInMemoryDefender(config *DefenderConfig) (Defender, error) { err := config.validate() if err != nil { return nil, err } ipList, err := dataprovider.NewIPList(dataprovider.IPListTypeDefender) if err != nil { return nil, err } defender := &memoryDefender{ baseDefender: baseDefender{ config: config, ipList: ipList, }, hosts: make(map[string]hostScore), banned: make(map[string]time.Time), } return defender, nil } // GetHosts returns hosts that are banned or for which some violations have been detected func (d *memoryDefender) GetHosts() ([]dataprovider.DefenderEntry, error) { d.RLock() defer d.RUnlock() var result []dataprovider.DefenderEntry for k, v := range d.banned { if v.After(time.Now()) { result = append(result, dataprovider.DefenderEntry{ IP: k, BanTime: v, }) } } for k, v := range d.hosts { score := 0 for _, event := range v.Events { if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { score += event.score } } if score > 0 { result = append(result, dataprovider.DefenderEntry{ IP: k, Score: score, }) } } 
return result, nil } // GetHost returns a defender host by ip, if any func (d *memoryDefender) GetHost(ip string) (dataprovider.DefenderEntry, error) { d.RLock() defer d.RUnlock() if banTime, ok := d.banned[ip]; ok { if banTime.After(time.Now()) { return dataprovider.DefenderEntry{ IP: ip, BanTime: banTime, }, nil } } if hs, ok := d.hosts[ip]; ok { score := 0 for _, event := range hs.Events { if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { score += event.score } } if score > 0 { return dataprovider.DefenderEntry{ IP: ip, Score: score, }, nil } } return dataprovider.DefenderEntry{}, util.NewRecordNotFoundError("host not found") } // IsBanned returns true if the specified IP is banned // and increase ban time if the IP is found. // This method must be called as soon as the client connects func (d *memoryDefender) IsBanned(ip, protocol string) bool { d.RLock() if banTime, ok := d.banned[ip]; ok { if banTime.After(time.Now()) { increment := d.config.BanTime * d.config.BanTimeIncrement / 100 if increment == 0 { increment++ } d.RUnlock() // we can save an earlier ban time if there are contemporary updates // but this should not make much difference. I prefer to hold a read lock // until possible for performance reasons, this method is called each // time a new client connects and it must be as fast as possible d.Lock() d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute) d.Unlock() return true } } defer d.RUnlock() return d.isBanned(ip, protocol) } // DeleteHost removes the specified IP from the defender lists func (d *memoryDefender) DeleteHost(ip string) bool { d.Lock() defer d.Unlock() if _, ok := d.banned[ip]; ok { delete(d.banned, ip) return true } if _, ok := d.hosts[ip]; ok { delete(d.hosts, ip) return true } return false } // AddEvent adds an event for the given IP. // This method must be called for clients not yet banned. // Returns true if the IP is in the defender's safe list. 
func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) bool { if d.IsSafe(ip, protocol) { return true } d.Lock() defer d.Unlock() // ignore events for already banned hosts if v, ok := d.banned[ip]; ok { if v.After(time.Now()) { return false } delete(d.banned, ip) } score := d.getScore(event) ev := hostEvent{ dateTime: time.Now(), score: score, } if hs, ok := d.hosts[ip]; ok { hs.Events = append(hs.Events, ev) hs.TotalScore = 0 idx := 0 for _, event := range hs.Events { if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { hs.Events[idx] = event hs.TotalScore += event.score idx++ } } d.logEvent(ip, protocol, event, hs.TotalScore) hs.Events = hs.Events[:idx] if hs.TotalScore >= d.config.Threshold { d.logBan(ip, protocol) d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute) delete(d.hosts, ip) d.cleanupBanned() eventManager.handleIPBlockedEvent(EventParams{ Event: ipBlockedEventName, IP: ip, Timestamp: time.Now(), Status: 1, }) } else { d.hosts[ip] = hs } } else { d.logEvent(ip, protocol, event, ev.score) d.hosts[ip] = hostScore{ TotalScore: ev.score, Events: []hostEvent{ev}, } d.cleanupHosts() } return false } func (d *memoryDefender) countBanned() int { d.RLock() defer d.RUnlock() return len(d.banned) } func (d *memoryDefender) countHosts() int { d.RLock() defer d.RUnlock() return len(d.hosts) } // GetBanTime returns the ban time for the given IP or nil if the IP is not banned func (d *memoryDefender) GetBanTime(ip string) (*time.Time, error) { d.RLock() defer d.RUnlock() if banTime, ok := d.banned[ip]; ok { return &banTime, nil } return nil, nil } // GetScore returns the score for the given IP func (d *memoryDefender) GetScore(ip string) (int, error) { d.RLock() defer d.RUnlock() score := 0 if hs, ok := d.hosts[ip]; ok { for _, event := range hs.Events { if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { score += event.score } } } return 
score, nil } func (d *memoryDefender) cleanupBanned() { if len(d.banned) > d.config.EntriesHardLimit { kvList := make(kvList, 0, len(d.banned)) for k, v := range d.banned { if v.Before(time.Now()) { delete(d.banned, k) } kvList = append(kvList, kv{ Key: k, Value: v.UnixNano(), }) } // we removed expired ip addresses, if any, above, this could be enough numToRemove := len(d.banned) - d.config.EntriesSoftLimit if numToRemove <= 0 { return } sort.Sort(kvList) for idx, kv := range kvList { if idx >= numToRemove { break } delete(d.banned, kv.Key) } } } func (d *memoryDefender) cleanupHosts() { if len(d.hosts) > d.config.EntriesHardLimit { kvList := make(kvList, 0, len(d.hosts)) for k, v := range d.hosts { value := int64(0) if len(v.Events) > 0 { value = v.Events[len(v.Events)-1].dateTime.UnixNano() } kvList = append(kvList, kv{ Key: k, Value: value, }) } sort.Sort(kvList) numToRemove := len(d.hosts) - d.config.EntriesSoftLimit for idx, kv := range kvList { if idx >= numToRemove { break } delete(d.hosts, kv.Key) } } } type kv struct { Key string Value int64 } type kvList []kv func (p kvList) Len() int { return len(p) } func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value } func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] } ================================================ FILE: internal/common/eventmanager.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. 
If not, see . package common import ( "bytes" "context" "encoding/csv" "encoding/json" "errors" "fmt" "html" "io" "mime" "mime/multipart" "net/http" "net/textproto" "net/url" "os" "os/exec" "path" "path/filepath" "slices" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/bmatcuk/doublestar/v4" "github.com/klauspost/compress/zip" "github.com/robfig/cron/v3" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/wneessen/go-mail" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( ipBlockedEventName = "IP Blocked" maxAttachmentsSize = int64(10 * 1024 * 1024) objDataPlaceholder = "{{.ObjectData}}" objDataPlaceholderString = "{{.ObjectDataString}}" dateTimeMillisFormat = "2006-01-02T15:04:05.000" ) // Supported IDP login events const ( IDPLoginUser = "IDP login user" IDPLoginAdmin = "IDP login admin" ) var ( // eventManager handle the supported event rules actions eventManager eventRulesContainer multipartQuoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") fsEventsWithSize = []string{operationPreDelete, OperationPreUpload, operationDelete, operationCopy, operationDownload, operationFirstUpload, operationFirstDownload, operationUpload} ) func init() { eventManager = eventRulesContainer{ schedulesMapping: make(map[string][]cron.EntryID), // arbitrary maximum number of concurrent asynchronous tasks, // each task could execute multiple actions concurrencyGuard: make(chan struct{}, 200), } dataprovider.SetEventRulesCallbacks(eventManager.loadRules, eventManager.RemoveRule, func(operation, executor, ip, objectType, objectName, role string, object plugin.Renderer) { p := EventParams{ Name: executor, ObjectName: objectName, Event: operation, Status: 1, ObjectType: objectType, IP: ip, Role: role, Timestamp: time.Now(), 
Object: object, } if u, ok := object.(*dataprovider.User); ok { p.Email = u.Email p.Groups = u.Groups } else if a, ok := object.(*dataprovider.Admin); ok { p.Email = a.Email } eventManager.handleProviderEvent(p) }) } // HandleCertificateEvent checks and executes action rules for certificate events func HandleCertificateEvent(params EventParams) { eventManager.handleCertificateEvent(params) } // HandleIDPLoginEvent executes actions defined for a successful login from an Identity Provider func HandleIDPLoginEvent(params EventParams, customFields *map[string]any) (*dataprovider.User, *dataprovider.Admin, error) { return eventManager.handleIDPLoginEvent(params, customFields) } // eventRulesContainer stores event rules by trigger type eventRulesContainer struct { sync.RWMutex lastLoad atomic.Int64 FsEvents []dataprovider.EventRule ProviderEvents []dataprovider.EventRule Schedules []dataprovider.EventRule IPBlockedEvents []dataprovider.EventRule CertificateEvents []dataprovider.EventRule IPDLoginEvents []dataprovider.EventRule schedulesMapping map[string][]cron.EntryID concurrencyGuard chan struct{} } func (r *eventRulesContainer) addAsyncTask() { activeHooks.Add(1) r.concurrencyGuard <- struct{}{} } func (r *eventRulesContainer) removeAsyncTask() { activeHooks.Add(-1) <-r.concurrencyGuard } func (r *eventRulesContainer) getLastLoadTime() int64 { return r.lastLoad.Load() } func (r *eventRulesContainer) setLastLoadTime(modTime int64) { r.lastLoad.Store(modTime) } // RemoveRule deletes the rule with the specified name func (r *eventRulesContainer) RemoveRule(name string) { r.Lock() defer r.Unlock() r.removeRuleInternal(name) eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d", len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules)) } func (r *eventRulesContainer) removeRuleInternal(name string) { for idx := range r.FsEvents { if r.FsEvents[idx].Name == name { lastIdx := len(r.FsEvents) - 1 
r.FsEvents[idx] = r.FsEvents[lastIdx] r.FsEvents = r.FsEvents[:lastIdx] eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name) return } } for idx := range r.ProviderEvents { if r.ProviderEvents[idx].Name == name { lastIdx := len(r.ProviderEvents) - 1 r.ProviderEvents[idx] = r.ProviderEvents[lastIdx] r.ProviderEvents = r.ProviderEvents[:lastIdx] eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name) return } } for idx := range r.IPBlockedEvents { if r.IPBlockedEvents[idx].Name == name { lastIdx := len(r.IPBlockedEvents) - 1 r.IPBlockedEvents[idx] = r.IPBlockedEvents[lastIdx] r.IPBlockedEvents = r.IPBlockedEvents[:lastIdx] eventManagerLog(logger.LevelDebug, "removed rule %q from IP blocked events", name) return } } for idx := range r.CertificateEvents { if r.CertificateEvents[idx].Name == name { lastIdx := len(r.CertificateEvents) - 1 r.CertificateEvents[idx] = r.CertificateEvents[lastIdx] r.CertificateEvents = r.CertificateEvents[:lastIdx] eventManagerLog(logger.LevelDebug, "removed rule %q from certificate events", name) return } } for idx := range r.IPDLoginEvents { if r.IPDLoginEvents[idx].Name == name { lastIdx := len(r.IPDLoginEvents) - 1 r.IPDLoginEvents[idx] = r.IPDLoginEvents[lastIdx] r.IPDLoginEvents = r.IPDLoginEvents[:lastIdx] eventManagerLog(logger.LevelDebug, "removed rule %q from IDP login events", name) return } } for idx := range r.Schedules { if r.Schedules[idx].Name == name { if schedules, ok := r.schedulesMapping[name]; ok { for _, entryID := range schedules { eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name) eventScheduler.Remove(entryID) } delete(r.schedulesMapping, name) } lastIdx := len(r.Schedules) - 1 r.Schedules[idx] = r.Schedules[lastIdx] r.Schedules = r.Schedules[:lastIdx] eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name) return } } } func (r *eventRulesContainer) addUpdateRuleInternal(rule dataprovider.EventRule) { 
r.removeRuleInternal(rule.Name) if rule.DeletedAt > 0 { deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt) if deletedAt.Add(30 * time.Minute).Before(time.Now()) { eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt) go dataprovider.RemoveEventRule(rule) //nolint:errcheck } return } if rule.Status != 1 || rule.Trigger == dataprovider.EventTriggerOnDemand { return } switch rule.Trigger { case dataprovider.EventTriggerFsEvent: r.FsEvents = append(r.FsEvents, rule) eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name) case dataprovider.EventTriggerProviderEvent: r.ProviderEvents = append(r.ProviderEvents, rule) eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name) case dataprovider.EventTriggerIPBlocked: r.IPBlockedEvents = append(r.IPBlockedEvents, rule) eventManagerLog(logger.LevelDebug, "added rule %q to IP blocked events", rule.Name) case dataprovider.EventTriggerCertificate: r.CertificateEvents = append(r.CertificateEvents, rule) eventManagerLog(logger.LevelDebug, "added rule %q to certificate events", rule.Name) case dataprovider.EventTriggerIDPLogin: r.IPDLoginEvents = append(r.IPDLoginEvents, rule) eventManagerLog(logger.LevelDebug, "added rule %q to IDP login events", rule.Name) case dataprovider.EventTriggerSchedule: for _, schedule := range rule.Conditions.Schedules { cronSpec := schedule.GetCronSpec() job := &eventCronJob{ ruleName: dataprovider.ConvertName(rule.Name), } entryID, err := eventScheduler.AddJob(cronSpec, job) if err != nil { eventManagerLog(logger.LevelError, "unable to add scheduled rule %q, cron string %q: %v", rule.Name, cronSpec, err) return } r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID) eventManagerLog(logger.LevelDebug, "schedule for rule %q added, id: %d, cron string %q, active scheduling rules: %d", rule.Name, entryID, cronSpec, len(r.schedulesMapping)) } r.Schedules = append(r.Schedules, rule) 
eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name) default: eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger) } } func (r *eventRulesContainer) loadRules() { eventManagerLog(logger.LevelDebug, "loading updated rules") modTime := util.GetTimeAsMsSinceEpoch(time.Now()) lastLoadTime := r.getLastLoadTime() rules, err := dataprovider.GetRecentlyUpdatedRules(lastLoadTime) if err != nil { eventManagerLog(logger.LevelError, "unable to load event rules: %v", err) return } eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules)) if len(rules) > 0 { r.Lock() defer r.Unlock() for _, rule := range rules { r.addUpdateRuleInternal(rule) } } eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d, ip blocked events: %d, certificate events: %d, IDP login events: %d", len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules), len(r.IPBlockedEvents), len(r.CertificateEvents), len(r.IPDLoginEvents)) r.setLastLoadTime(modTime) } func (*eventRulesContainer) checkIPDLoginEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool { switch conditions.IDPLoginEvent { case dataprovider.IDPLoginUser: if params.Event != IDPLoginUser { return false } case dataprovider.IDPLoginAdmin: if params.Event != IDPLoginAdmin { return false } } return checkEventConditionPatterns(params.Name, conditions.Options.Names) } func (*eventRulesContainer) checkProviderEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool { if !slices.Contains(conditions.ProviderEvents, params.Event) { return false } if !checkEventConditionPatterns(params.Name, conditions.Options.Names) { return false } if !checkEventGroupConditionPatterns(params.Groups, conditions.Options.GroupNames) { return false } if !checkEventConditionPatterns(params.Role, conditions.Options.RoleNames) { return false } if len(conditions.Options.ProviderObjects) > 0 && 
!slices.Contains(conditions.Options.ProviderObjects, params.ObjectType) { return false } return true } func (*eventRulesContainer) checkFsEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool { if !slices.Contains(conditions.FsEvents, params.Event) { return false } if !checkEventConditionPatterns(params.Name, conditions.Options.Names) { return false } if !checkEventConditionPatterns(params.Role, conditions.Options.RoleNames) { return false } if !checkEventGroupConditionPatterns(params.Groups, conditions.Options.GroupNames) { return false } if !checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) { return false } if len(conditions.Options.Protocols) > 0 && !slices.Contains(conditions.Options.Protocols, params.Protocol) { return false } if slices.Contains(fsEventsWithSize, params.Event) { if conditions.Options.MinFileSize > 0 { if params.FileSize < conditions.Options.MinFileSize { return false } } if conditions.Options.MaxFileSize > 0 { if params.FileSize > conditions.Options.MaxFileSize { return false } } } return true } // hasFsRules returns true if there are any rules for filesystem event triggers func (r *eventRulesContainer) hasFsRules() bool { r.RLock() defer r.RUnlock() return len(r.FsEvents) > 0 } // handleFsEvent executes the rules actions defined for the specified event. 
// The boolean parameter indicates whether a sync action was executed func (r *eventRulesContainer) handleFsEvent(params EventParams) (bool, error) { if params.Protocol == protocolEventAction { return false, nil } r.RLock() var rulesWithSyncActions, rulesAsync []dataprovider.EventRule for _, rule := range r.FsEvents { if r.checkFsEventMatch(&rule.Conditions, ¶ms) { if err := rule.CheckActionsConsistency(""); err != nil { eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", rule.Name, err, params.Event) continue } hasSyncActions := false for _, action := range rule.Actions { if action.Options.ExecuteSync { hasSyncActions = true break } } if hasSyncActions { rulesWithSyncActions = append(rulesWithSyncActions, rule) } else { rulesAsync = append(rulesAsync, rule) } } } r.RUnlock() params.sender = params.Name params.addUID() if len(rulesAsync) > 0 { go executeAsyncRulesActions(rulesAsync, params) } if len(rulesWithSyncActions) > 0 { return true, executeSyncRulesActions(rulesWithSyncActions, params) } return false, nil } func (r *eventRulesContainer) handleIDPLoginEvent(params EventParams, customFields *map[string]any) (*dataprovider.User, *dataprovider.Admin, error, ) { r.RLock() var rulesWithSyncActions, rulesAsync []dataprovider.EventRule for _, rule := range r.IPDLoginEvents { if r.checkIPDLoginEventMatch(&rule.Conditions, ¶ms) { if err := rule.CheckActionsConsistency(""); err != nil { eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", rule.Name, err, params.Event) continue } hasSyncActions := false for _, action := range rule.Actions { if action.Options.ExecuteSync { hasSyncActions = true break } } if hasSyncActions { rulesWithSyncActions = append(rulesWithSyncActions, rule) } else { rulesAsync = append(rulesAsync, rule) } } } r.RUnlock() if len(rulesAsync) == 0 && len(rulesWithSyncActions) == 0 { return nil, nil, nil } params.addIDPCustomFields(customFields) if len(rulesWithSyncActions) > 1 { var ruleNames []string for _, r := range 
rulesWithSyncActions { ruleNames = append(ruleNames, r.Name) } return nil, nil, fmt.Errorf("more than one account check action rules matches: %q", strings.Join(ruleNames, ",")) } params.addUID() if len(rulesAsync) > 0 { go executeAsyncRulesActions(rulesAsync, params) } if len(rulesWithSyncActions) > 0 { return executeIDPAccountCheckRule(rulesWithSyncActions[0], params) } return nil, nil, nil } // username is populated for user objects func (r *eventRulesContainer) handleProviderEvent(params EventParams) { r.RLock() defer r.RUnlock() var rules []dataprovider.EventRule for _, rule := range r.ProviderEvents { if r.checkProviderEventMatch(&rule.Conditions, ¶ms) { if err := rule.CheckActionsConsistency(params.ObjectType); err == nil { rules = append(rules, rule) } else { eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q object type %q", rule.Name, err, params.Event, params.ObjectType) } } } if len(rules) > 0 { params.sender = params.ObjectName go executeAsyncRulesActions(rules, params) } } func (r *eventRulesContainer) handleIPBlockedEvent(params EventParams) { r.RLock() defer r.RUnlock() if len(r.IPBlockedEvents) == 0 { return } var rules []dataprovider.EventRule for _, rule := range r.IPBlockedEvents { if err := rule.CheckActionsConsistency(""); err == nil { rules = append(rules, rule) } else { eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", rule.Name, err, params.Event) } } if len(rules) > 0 { go executeAsyncRulesActions(rules, params) } } func (r *eventRulesContainer) handleCertificateEvent(params EventParams) { r.RLock() defer r.RUnlock() if len(r.CertificateEvents) == 0 { return } var rules []dataprovider.EventRule for _, rule := range r.CertificateEvents { if err := rule.CheckActionsConsistency(""); err == nil { rules = append(rules, rule) } else { eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", rule.Name, err, params.Event) } } if len(rules) > 0 { go executeAsyncRulesActions(rules, params) } } type 
executedRetentionCheck struct { Username string ActionName string Results []folderRetentionCheckResult } // EventParams defines the supported event parameters type EventParams struct { Name string Groups []sdk.GroupMapping Event string Status int VirtualPath string FsPath string VirtualTargetPath string FsTargetPath string ObjectName string Extension string ObjectType string FileSize int64 Elapsed int64 Protocol string IP string Role string Email string Timestamp time.Time UID string IDPCustomFields *map[string]string Object plugin.Renderer Metadata map[string]string sender string updateStatusFromError bool errors []string retentionChecks []executedRetentionCheck } func (p *EventParams) getACopy() *EventParams { params := *p params.errors = make([]string, len(p.errors)) copy(params.errors, p.errors) retentionChecks := make([]executedRetentionCheck, 0, len(p.retentionChecks)) for _, c := range p.retentionChecks { executedCheck := executedRetentionCheck{ Username: c.Username, ActionName: c.ActionName, } executedCheck.Results = make([]folderRetentionCheckResult, len(c.Results)) copy(executedCheck.Results, c.Results) retentionChecks = append(retentionChecks, executedCheck) } params.retentionChecks = retentionChecks if p.IDPCustomFields != nil { fields := make(map[string]string) for k, v := range *p.IDPCustomFields { fields[k] = v } params.IDPCustomFields = &fields } if len(params.Metadata) > 0 { metadata := make(map[string]string) for k, v := range p.Metadata { metadata[k] = v } params.Metadata = metadata } return ¶ms } func (p *EventParams) addIDPCustomFields(customFields *map[string]any) { if customFields == nil || len(*customFields) == 0 { return } fields := make(map[string]string) for k, v := range *customFields { switch val := v.(type) { case string: fields[k] = val } } p.IDPCustomFields = &fields } // AddError adds a new error to the event params and update the status if needed func (p *EventParams) AddError(err error) { if err == nil { return } if 
p.updateStatusFromError && p.Status == 1 { p.Status = 2 } p.errors = append(p.errors, err.Error()) } func (p *EventParams) addUID() { if p.UID == "" { p.UID = util.GenerateUniqueID() } } func (p *EventParams) setBackupParams(backupPath string) { if p.sender != "" { return } p.sender = dataprovider.ActionExecutorSystem p.FsPath = backupPath p.ObjectName = filepath.Base(backupPath) p.VirtualPath = "/" + p.ObjectName p.Timestamp = time.Now() info, err := os.Stat(backupPath) if err == nil { p.FileSize = info.Size() } } func (p *EventParams) getStatusString() string { switch p.Status { case 1: return "OK" default: return "KO" } } // getUsers returns users with group settings not applied func (p *EventParams) getUsers() ([]dataprovider.User, error) { if p.sender == "" { dump, err := dataprovider.DumpData([]string{dataprovider.DumpScopeUsers}) if err != nil { eventManagerLog(logger.LevelError, "unable to get users: %+v", err) return nil, errors.New("unable to get users") } return dump.Users, nil } user, err := p.getUserFromSender() if err != nil { return nil, err } return []dataprovider.User{user}, nil } func (p *EventParams) getUserFromSender() (dataprovider.User, error) { if p.sender == dataprovider.ActionExecutorSystem { return dataprovider.User{ BaseUser: sdk.BaseUser{ Status: 1, Username: p.sender, HomeDir: dataprovider.GetBackupsPath(), Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, }, nil } user, err := dataprovider.UserExists(p.sender, "") if err != nil { eventManagerLog(logger.LevelError, "unable to get user %q: %+v", p.sender, err) return user, fmt.Errorf("error getting user %q", p.sender) } return user, nil } func (p *EventParams) getFolders() ([]vfs.BaseVirtualFolder, error) { if p.sender == "" { dump, err := dataprovider.DumpData([]string{dataprovider.DumpScopeFolders}) return dump.Folders, err } folder, err := dataprovider.GetFolderByName(p.sender) if err != nil { return nil, fmt.Errorf("error getting folder %q: %w", p.sender, err) } 
return []vfs.BaseVirtualFolder{folder}, nil } func (p *EventParams) getCompressedDataRetentionReport() ([]byte, error) { if len(p.retentionChecks) == 0 { return nil, errors.New("no data retention report available") } var b bytes.Buffer if _, err := p.writeCompressedDataRetentionReports(&b); err != nil { return nil, err } return b.Bytes(), nil } func (p *EventParams) writeCompressedDataRetentionReports(w io.Writer) (int64, error) { var n int64 wr := zip.NewWriter(w) for _, check := range p.retentionChecks { data, err := getCSVRetentionReport(check.Results) if err != nil { return n, fmt.Errorf("unable to get CSV report: %w", err) } dataSize := int64(len(data)) n += dataSize // we suppose a 3:1 compression ratio if n > (maxAttachmentsSize * 3) { eventManagerLog(logger.LevelError, "unable to get retention report, size too large: %s", util.ByteCountIEC(n)) return n, fmt.Errorf("unable to get retention report, size too large: %s", util.ByteCountIEC(n)) } fh := &zip.FileHeader{ Name: fmt.Sprintf("%s-%s.csv", check.ActionName, check.Username), Method: zip.Deflate, Modified: time.Now().UTC(), } f, err := wr.CreateHeader(fh) if err != nil { return n, fmt.Errorf("unable to create zip header for file %q: %w", fh.Name, err) } _, err = io.CopyN(f, bytes.NewBuffer(data), dataSize) if err != nil { return n, fmt.Errorf("unable to write content to zip file %q: %w", fh.Name, err) } } if err := wr.Close(); err != nil { return n, fmt.Errorf("unable to close zip writer: %w", err) } return n, nil } func (p *EventParams) getRetentionReportsAsMailAttachment() (*mail.File, error) { if len(p.retentionChecks) == 0 { return nil, errors.New("no data retention report available") } return &mail.File{ Name: "retention-reports.zip", Header: make(map[string][]string), Writer: p.writeCompressedDataRetentionReports, }, nil } func (*EventParams) getStringReplacement(val string, escapeMode int) string { switch escapeMode { case 1: return util.JSONEscape(val) case 2: return html.EscapeString(val) 
default: return val } } func (p *EventParams) getStringReplacements(addObjectData bool, escapeMode int) []string { var dateTimeString string if Config.TZ == "local" { dateTimeString = p.Timestamp.Local().Format(dateTimeMillisFormat) } else { dateTimeString = p.Timestamp.UTC().Format(dateTimeMillisFormat) } year := dateTimeString[0:4] month := dateTimeString[5:7] day := dateTimeString[8:10] hour := dateTimeString[11:13] minute := dateTimeString[14:16] replacements := []string{ "{{.Name}}", p.getStringReplacement(p.Name, escapeMode), "{{.Event}}", p.Event, "{{.Status}}", fmt.Sprintf("%d", p.Status), "{{.VirtualPath}}", p.getStringReplacement(p.VirtualPath, escapeMode), "{{.EscapedVirtualPath}}", p.getStringReplacement(url.QueryEscape(p.VirtualPath), escapeMode), "{{.FsPath}}", p.getStringReplacement(p.FsPath, escapeMode), "{{.VirtualTargetPath}}", p.getStringReplacement(p.VirtualTargetPath, escapeMode), "{{.FsTargetPath}}", p.getStringReplacement(p.FsTargetPath, escapeMode), "{{.ObjectName}}", p.getStringReplacement(p.ObjectName, escapeMode), "{{.ObjectBaseName}}", p.getStringReplacement(strings.TrimSuffix(p.ObjectName, p.Extension), escapeMode), "{{.ObjectType}}", p.ObjectType, "{{.FileSize}}", strconv.FormatInt(p.FileSize, 10), "{{.Elapsed}}", strconv.FormatInt(p.Elapsed, 10), "{{.Protocol}}", p.Protocol, "{{.IP}}", p.IP, "{{.Role}}", p.getStringReplacement(p.Role, escapeMode), "{{.Email}}", p.getStringReplacement(p.Email, escapeMode), "{{.Timestamp}}", strconv.FormatInt(p.Timestamp.UnixNano(), 10), "{{.DateTime}}", dateTimeString, "{{.Year}}", year, "{{.Month}}", month, "{{.Day}}", day, "{{.Hour}}", hour, "{{.Minute}}", minute, "{{.StatusString}}", p.getStatusString(), "{{.UID}}", p.getStringReplacement(p.UID, escapeMode), "{{.Ext}}", p.getStringReplacement(p.Extension, escapeMode), } if p.VirtualPath != "" { replacements = append(replacements, "{{.VirtualDirPath}}", p.getStringReplacement(path.Dir(p.VirtualPath), escapeMode)) } if p.VirtualTargetPath != "" { 
replacements = append(replacements, "{{.VirtualTargetDirPath}}", p.getStringReplacement(path.Dir(p.VirtualTargetPath), escapeMode)) replacements = append(replacements, "{{.TargetName}}", p.getStringReplacement(path.Base(p.VirtualTargetPath), escapeMode)) } if len(p.errors) > 0 { replacements = append(replacements, "{{.ErrorString}}", p.getStringReplacement(strings.Join(p.errors, ", "), escapeMode)) } else { replacements = append(replacements, "{{.ErrorString}}", "") } replacements = append(replacements, objDataPlaceholder, "{}") replacements = append(replacements, objDataPlaceholderString, "") if addObjectData { data, err := p.Object.RenderAsJSON(p.Event != operationDelete) if err == nil { dataString := util.BytesToString(data) replacements[len(replacements)-3] = p.getStringReplacement(dataString, 0) replacements[len(replacements)-1] = p.getStringReplacement(dataString, 1) } } if p.IDPCustomFields != nil { for k, v := range *p.IDPCustomFields { replacements = append(replacements, fmt.Sprintf("{{.IDPField%s}}", k), p.getStringReplacement(v, escapeMode)) } } replacements = append(replacements, "{{.Metadata}}", "{}") replacements = append(replacements, "{{.MetadataString}}", "") if len(p.Metadata) > 0 { data, err := json.Marshal(p.Metadata) if err == nil { dataString := util.BytesToString(data) replacements[len(replacements)-3] = p.getStringReplacement(dataString, 0) replacements[len(replacements)-1] = p.getStringReplacement(dataString, 1) } } return replacements } func getCSVRetentionReport(results []folderRetentionCheckResult) ([]byte, error) { var b bytes.Buffer csvWriter := csv.NewWriter(&b) err := csvWriter.Write([]string{"path", "retention (hours)", "deleted files", "deleted size (bytes)", "elapsed (ms)", "info", "error"}) if err != nil { return nil, err } for _, result := range results { err = csvWriter.Write([]string{result.Path, strconv.Itoa(result.Retention), strconv.Itoa(result.DeletedFiles), strconv.FormatInt(result.DeletedSize, 10), 
strconv.FormatInt(result.Elapsed.Milliseconds(), 10), result.Info, result.Error}) if err != nil { return nil, err } } csvWriter.Flush() err = csvWriter.Error() return b.Bytes(), err } func closeWriterAndUpdateQuota(w io.WriteCloser, conn *BaseConnection, virtualSourcePath, virtualTargetPath string, numFiles int, truncatedSize int64, errTransfer error, operation string, startTime time.Time, ) error { var fsDstPath string var errDstFs error errWrite := w.Close() targetPath := virtualSourcePath if virtualTargetPath != "" { targetPath = virtualTargetPath var fsDst vfs.Fs fsDst, fsDstPath, errDstFs = conn.GetFsAndResolvedPath(virtualTargetPath) if errTransfer != nil && errDstFs == nil { // try to remove a partial file on error. If this fails, we can't do anything errRemove := fsDst.Remove(fsDstPath, false) conn.Log(logger.LevelDebug, "removing partial file %q after write error, result: %v", virtualTargetPath, errRemove) } } info, err := conn.doStatInternal(targetPath, 0, false, false) if err == nil { updateUserQuotaAfterFileWrite(conn, targetPath, numFiles, info.Size()-truncatedSize) var fsSrcPath string var errSrcFs error if virtualSourcePath != "" { _, fsSrcPath, errSrcFs = conn.GetFsAndResolvedPath(virtualSourcePath) } if errSrcFs == nil && errDstFs == nil { elapsed := time.Since(startTime).Nanoseconds() / 1000000 if errTransfer == nil { errTransfer = errWrite } if operation == operationCopy { logger.CommandLog(copyLogSender, fsSrcPath, fsDstPath, conn.User.Username, "", conn.ID, conn.protocol, -1, -1, "", "", "", info.Size(), conn.localAddr, conn.remoteAddr, elapsed) } ExecuteActionNotification(conn, operation, fsSrcPath, virtualSourcePath, fsDstPath, virtualTargetPath, "", info.Size(), errTransfer, elapsed, nil) //nolint:errcheck } } else { eventManagerLog(logger.LevelWarn, "unable to update quota after writing %q: %v", targetPath, err) } if errTransfer != nil { return errTransfer } return errWrite } func updateUserQuotaAfterFileWrite(conn *BaseConnection, 
virtualPath string, numFiles int, fileSize int64) { vfolder, err := conn.User.GetVirtualFolderForPath(path.Dir(virtualPath)) if err != nil { dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck return } dataprovider.UpdateUserFolderQuota(&vfolder, &conn.User, numFiles, fileSize, false) } func checkWriterPermsAndQuota(conn *BaseConnection, virtualPath string, numFiles int, expectedSize, truncatedSize int64) error { if numFiles == 0 { if !conn.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualPath)) { return conn.GetPermissionDeniedError() } } else { if !conn.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualPath)) { return conn.GetPermissionDeniedError() } } q, _ := conn.HasSpace(numFiles > 0, false, virtualPath) if !q.HasSpace { return conn.GetQuotaExceededError() } if expectedSize != -1 { sizeDiff := expectedSize - truncatedSize if sizeDiff > 0 { remainingSize := q.GetRemainingSize() if remainingSize > 0 && remainingSize < sizeDiff { return conn.GetQuotaExceededError() } } } return nil } func getFileWriter(conn *BaseConnection, virtualPath string, expectedSize int64) (io.WriteCloser, int, int64, func(), error) { fs, fsPath, err := conn.GetFsAndResolvedPath(virtualPath) if err != nil { return nil, 0, 0, nil, err } var truncatedSize, fileSize int64 numFiles := 1 isFileOverwrite := false info, err := fs.Lstat(fsPath) if err == nil { fileSize = info.Size() if info.IsDir() { return nil, numFiles, truncatedSize, nil, fmt.Errorf("cannot write to a directory: %q", virtualPath) } if info.Mode().IsRegular() { isFileOverwrite = true truncatedSize = fileSize } numFiles = 0 } if err != nil && !fs.IsNotExist(err) { return nil, numFiles, truncatedSize, nil, conn.GetFsError(fs, err) } if err := checkWriterPermsAndQuota(conn, virtualPath, numFiles, expectedSize, truncatedSize); err != nil { return nil, numFiles, truncatedSize, nil, err } f, w, cancelFn, err := fs.Create(fsPath, 0, conn.GetCreateChecks(virtualPath, numFiles == 1, false)) 
if err != nil { return nil, numFiles, truncatedSize, nil, conn.GetFsError(fs, err) } vfs.SetPathPermissions(fs, fsPath, conn.User.GetUID(), conn.User.GetGID()) if isFileOverwrite { if vfs.HasTruncateSupport(fs) || vfs.IsCryptOsFs(fs) { updateUserQuotaAfterFileWrite(conn, virtualPath, numFiles, -fileSize) truncatedSize = 0 } } if cancelFn == nil { cancelFn = func() {} } if f != nil { return f, numFiles, truncatedSize, cancelFn, nil } return w, numFiles, truncatedSize, cancelFn, nil } func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string, info os.FileInfo, recursion int) error { //nolint:gocyclo if entryPath == wr.Name { // skip the archive itself return nil } if recursion >= util.MaxRecursion { eventManagerLog(logger.LevelError, "unable to add zip entry %q, recursion too deep: %v", entryPath, recursion) return util.ErrRecursionTooDeep } recursion++ var err error if info == nil { info, err = conn.DoStat(entryPath, 1, false) if err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry %q, stat error: %v", entryPath, err) return err } } entryName, err := getZipEntryName(entryPath, baseDir) if err != nil { eventManagerLog(logger.LevelError, "unable to get zip entry name: %v", err) return err } if _, ok := wr.Entries[entryName]; ok { eventManagerLog(logger.LevelInfo, "skipping duplicate zip entry %q, is dir %t", entryPath, info.IsDir()) return nil } wr.Entries[entryName] = true if info.IsDir() { _, err = wr.Writer.CreateHeader(&zip.FileHeader{ Name: entryName + "/", Method: zip.Deflate, Modified: info.ModTime(), }) if err != nil { eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err) } lister, err := conn.ListDir(entryPath) if err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry %q, get dir lister error: %v", entryPath, err) return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) } defer 
lister.Close() for { contents, err := lister.Next(vfs.ListerBatchSize) finished := errors.Is(err, io.EOF) if err := lister.convertError(err); err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err) return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) } for _, info := range contents { fullPath := util.CleanPath(path.Join(entryPath, info.Name())) if err := addZipEntry(wr, conn, fullPath, baseDir, info, recursion); err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err) return err } } if finished { return nil } } } if !info.Mode().IsRegular() { // we only allow regular files eventManagerLog(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) return nil } return addFileToZip(wr, conn, entryPath, entryName, info.ModTime()) } func addFileToZip(wr *zipWriterWrapper, conn *BaseConnection, entryPath, entryName string, modTime time.Time) error { reader, cancelFn, err := getFileReader(conn, entryPath) if err != nil { eventManagerLog(logger.LevelError, "unable to add zip entry %q, cannot open file: %v", entryPath, err) return fmt.Errorf("unable to open %q: %w", entryPath, err) } defer cancelFn() defer reader.Close() f, err := wr.Writer.CreateHeader(&zip.FileHeader{ Name: entryName, Method: zip.Deflate, Modified: modTime, }) if err != nil { eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err) } _, err = io.Copy(f, reader) return err } func getZipEntryName(entryPath, baseDir string) (string, error) { if !strings.HasPrefix(entryPath, baseDir) { return "", fmt.Errorf("entry path %q is outside base dir %q", entryPath, baseDir) } entryPath = strings.TrimPrefix(entryPath, baseDir) return strings.TrimPrefix(entryPath, "/"), nil } func getFileReader(conn *BaseConnection, virtualPath string) (io.ReadCloser, func(), error) { if 
!conn.User.HasPerm(dataprovider.PermDownload, path.Dir(virtualPath)) { return nil, nil, conn.GetPermissionDeniedError() } fs, fsPath, err := conn.GetFsAndResolvedPath(virtualPath) if err != nil { return nil, nil, err } f, r, cancelFn, err := fs.Open(fsPath, 0) if err != nil { return nil, nil, conn.GetFsError(fs, err) } if cancelFn == nil { cancelFn = func() {} } if f != nil { return f, cancelFn, nil } return r, cancelFn, nil } func writeFileContent(conn *BaseConnection, virtualPath string, w io.Writer) error { reader, cancelFn, err := getFileReader(conn, virtualPath) if err != nil { return err } defer cancelFn() defer reader.Close() _, err = io.Copy(w, reader) return err } func getFileContentFn(conn *BaseConnection, virtualPath string, size int64) func(w io.Writer) (int64, error) { return func(w io.Writer) (int64, error) { reader, cancelFn, err := getFileReader(conn, virtualPath) if err != nil { return 0, err } defer cancelFn() defer reader.Close() return io.CopyN(w, reader, size) } } func getMailAttachments(conn *BaseConnection, attachments []string, replacer *strings.Replacer) ([]*mail.File, error) { var files []*mail.File totalSize := int64(0) for _, virtualPath := range replacePathsPlaceholders(attachments, replacer) { info, err := conn.DoStat(virtualPath, 0, false) if err != nil { return nil, fmt.Errorf("unable to get info for file %q, user %q: %w", virtualPath, conn.User.Username, err) } if !info.Mode().IsRegular() { return nil, fmt.Errorf("cannot attach non regular file %q", virtualPath) } totalSize += info.Size() if totalSize > maxAttachmentsSize { return nil, fmt.Errorf("unable to send files as attachment, size too large: %s", util.ByteCountIEC(totalSize)) } files = append(files, &mail.File{ Name: path.Base(virtualPath), Header: make(map[string][]string), Writer: getFileContentFn(conn, virtualPath, info.Size()), }) } return files, nil } func replaceWithReplacer(input string, replacer *strings.Replacer) string { if !strings.Contains(input, "{{.") { return 
input } return replacer.Replace(input) } func checkEventConditionPattern(p dataprovider.ConditionPattern, name string) bool { var matched bool var err error if strings.Contains(p.Pattern, "**") { matched, err = doublestar.Match(p.Pattern, name) } else { matched, err = path.Match(p.Pattern, name) } if err != nil { eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err) return false } if p.InverseMatch { return !matched } return matched } func checkUserConditionOptions(user *dataprovider.User, conditions *dataprovider.ConditionOptions) bool { if !checkEventConditionPatterns(user.Username, conditions.Names) { return false } if !checkEventConditionPatterns(user.Role, conditions.RoleNames) { return false } if !checkEventGroupConditionPatterns(user.Groups, conditions.GroupNames) { return false } return true } // checkEventConditionPatterns returns false if patterns are defined and no match is found func checkEventConditionPatterns(name string, patterns []dataprovider.ConditionPattern) bool { if len(patterns) == 0 { return true } matches := false for _, p := range patterns { // assume, that multiple InverseMatches are set if p.InverseMatch { if checkEventConditionPattern(p, name) { matches = true } else { return false } } else if checkEventConditionPattern(p, name) { return true } } return matches } func checkEventGroupConditionPatterns(groups []sdk.GroupMapping, patterns []dataprovider.ConditionPattern) bool { if len(patterns) == 0 { return true } matches := false for _, group := range groups { for _, p := range patterns { // assume, that multiple InverseMatches are set if p.InverseMatch { if checkEventConditionPattern(p, group.Name) { matches = true } else { return false } } else { if checkEventConditionPattern(p, group.Name) { return true } } } } return matches } func getHTTPRuleActionEndpoint(c *dataprovider.EventActionHTTPConfig, replacer *strings.Replacer) (string, error) { u, err := url.Parse(c.Endpoint) if err != nil { return "", 
fmt.Errorf("invalid endpoint: %w", err) } if strings.Contains(u.Path, "{{.") { pathComponents := strings.Split(u.Path, "/") for idx := range pathComponents { part := replaceWithReplacer(pathComponents[idx], replacer) if part != pathComponents[idx] { pathComponents[idx] = url.PathEscape(part) } } u.Path = "" u = u.JoinPath(pathComponents...) } if len(c.QueryParameters) > 0 { q := u.Query() for _, keyVal := range c.QueryParameters { q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) } u.RawQuery = q.Encode() } return u.String(), nil } func writeHTTPPart(m *multipart.Writer, part dataprovider.HTTPPart, h textproto.MIMEHeader, conn *BaseConnection, replacer *strings.Replacer, params *EventParams, addObjectData bool, ) error { partWriter, err := m.CreatePart(h) if err != nil { eventManagerLog(logger.LevelError, "unable to create part %q, err: %v", part.Name, err) return err } if part.Body != "" { cType := h.Get("Content-Type") if strings.Contains(strings.ToLower(cType), "application/json") { replacements := params.getStringReplacements(addObjectData, 1) jsonReplacer := strings.NewReplacer(replacements...) 
_, err = partWriter.Write(util.StringToBytes(replaceWithReplacer(part.Body, jsonReplacer))) } else { _, err = partWriter.Write(util.StringToBytes(replaceWithReplacer(part.Body, replacer))) } if err != nil { eventManagerLog(logger.LevelError, "unable to write part %q, err: %v", part.Name, err) return err } return nil } if part.Filepath == dataprovider.RetentionReportPlaceHolder { data, err := params.getCompressedDataRetentionReport() if err != nil { return err } _, err = partWriter.Write(data) if err != nil { eventManagerLog(logger.LevelError, "unable to write part %q, err: %v", part.Name, err) return err } return nil } err = writeFileContent(conn, util.CleanPath(replacer.Replace(part.Filepath)), partWriter) if err != nil { eventManagerLog(logger.LevelError, "unable to write file part %q, err: %v", part.Name, err) return err } return nil } func getHTTPRuleActionBody(c *dataprovider.EventActionHTTPConfig, replacer *strings.Replacer, //nolint:gocyclo cancel context.CancelFunc, user dataprovider.User, params *EventParams, addObjectData bool, ) (io.Reader, string, error) { var body io.Reader if c.Method == http.MethodGet { return body, "", nil } if c.Body != "" { if c.Body == dataprovider.RetentionReportPlaceHolder { data, err := params.getCompressedDataRetentionReport() if err != nil { return body, "", err } return bytes.NewBuffer(data), "", nil } if c.HasJSONBody() { replacements := params.getStringReplacements(addObjectData, 1) jsonReplacer := strings.NewReplacer(replacements...) 
return bytes.NewBufferString(replaceWithReplacer(c.Body, jsonReplacer)), "", nil } return bytes.NewBufferString(replaceWithReplacer(c.Body, replacer)), "", nil } if len(c.Parts) > 0 { r, w := io.Pipe() m := multipart.NewWriter(w) var conn *BaseConnection if user.Username != "" { var err error if err := getUserForEventAction(&user); err != nil { return body, "", err } connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) err = user.CheckFsRoot(connectionID) if err != nil { user.CloseFs() //nolint:errcheck return body, "", fmt.Errorf("error getting multipart file/s, unable to check root fs for user %q: %w", user.Username, err) } conn = NewBaseConnection(connectionID, protocolEventAction, "", "", user) } go func() { defer w.Close() defer user.CloseFs() //nolint:errcheck if conn != nil { defer conn.CloseFS() //nolint:errcheck } for _, part := range c.Parts { h := make(textproto.MIMEHeader) if part.Body != "" { h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"`, multipartQuoteEscaper.Replace(part.Name))) } else { h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, multipartQuoteEscaper.Replace(part.Name), multipartQuoteEscaper.Replace((path.Base(replaceWithReplacer(part.Filepath, replacer)))))) contentType := mime.TypeByExtension(path.Ext(part.Filepath)) if contentType == "" { contentType = "application/octet-stream" } h.Set("Content-Type", contentType) } for _, keyVal := range part.Headers { h.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) } if err := writeHTTPPart(m, part, h, conn, replacer, params, addObjectData); err != nil { cancel() return } } m.Close() }() return r, m.FormDataContentType(), nil } return body, "", nil } func setHTTPReqHeaders(req *http.Request, c *dataprovider.EventActionHTTPConfig, replacer *strings.Replacer, contentType string, ) { if contentType != "" { req.Header.Set("Content-Type", contentType) } if c.Username != "" || c.Password.GetPayload() != "" { 
req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetPayload()) } for _, keyVal := range c.Headers { req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) } } func executeHTTPRuleAction(c dataprovider.EventActionHTTPConfig, params *EventParams) error { if err := c.TryDecryptPassword(); err != nil { return err } addObjectData := false if params.Object != nil { addObjectData = c.HasObjectData() } replacements := params.getStringReplacements(addObjectData, 0) replacer := strings.NewReplacer(replacements...) endpoint, err := getHTTPRuleActionEndpoint(&c, replacer) if err != nil { return err } ctx, cancel := c.GetContext() defer cancel() var user dataprovider.User if c.HasMultipartFiles() { user, err = params.getUserFromSender() if err != nil { return err } } body, contentType, err := getHTTPRuleActionBody(&c, replacer, cancel, user, params, addObjectData) if err != nil { return err } if body != nil { rc, ok := body.(io.ReadCloser) if ok { defer rc.Close() } } req, err := http.NewRequestWithContext(ctx, c.Method, endpoint, body) if err != nil { return err } setHTTPReqHeaders(req, &c, replacer, contentType) client := c.GetHTTPClient() defer client.CloseIdleConnections() startTime := time.Now() resp, err := client.Do(req) if err != nil { eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v", endpoint, time.Since(startTime), err) return fmt.Errorf("error sending HTTP request: %w", err) } defer resp.Body.Close() eventManagerLog(logger.LevelDebug, "http notification sent, endpoint: %s, elapsed: %s, status code: %d", endpoint, time.Since(startTime), resp.StatusCode) if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent { if rb, err := io.ReadAll(io.LimitReader(resp.Body, 2048)); err == nil { eventManagerLog(logger.LevelDebug, "error notification response from endpoint %q: %s", endpoint, rb) } return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } 
return nil
}

// executeCommandRuleAction runs the external command configured in c.
//
// The command must be accepted by dataprovider.IsActionCommandAllowed.
// Placeholders in the arguments and environment values are expanded from
// params. An environment entry whose value is exactly "$" takes its value from
// SFTPGo's own environment instead, unless its key starts with "SFTPGO_"
// (case-insensitive).
//
// The child process receives only the configured variables. It runs under a
// context that expires after c.Timeout seconds. The error from cmd.Run is
// returned as is.
func executeCommandRuleAction(c dataprovider.EventActionCommandConfig, params *EventParams) error {
	if !dataprovider.IsActionCommandAllowed(c.Cmd) {
		return fmt.Errorf("command %q is not allowed", c.Cmd)
	}

	// Object data is only needed if at least one env var references it.
	needsObjectData := false
	if params.Object != nil {
		for idx := 0; idx < len(c.EnvVars) && !needsObjectData; idx++ {
			value := c.EnvVars[idx].Value
			needsObjectData = strings.Contains(value, objDataPlaceholder) ||
				strings.Contains(value, objDataPlaceholderString)
		}
	}
	r := strings.NewReplacer(params.getStringReplacements(needsObjectData, 0)...)

	cmdArgs := make([]string, len(c.Args))
	for idx, arg := range c.Args {
		cmdArgs[idx] = replaceWithReplacer(arg, r)
	}

	timeout := time.Duration(c.Timeout) * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	cmd := exec.CommandContext(ctx, c.Cmd, cmdArgs...)

	// A non-nil (possibly empty) Env keeps the child from inheriting our environment.
	env := make([]string, 0, len(c.EnvVars))
	for _, kv := range c.EnvVars {
		inherit := kv.Value == "$" && !strings.HasPrefix(strings.ToUpper(kv.Key), "SFTPGO_")
		if !inherit {
			env = append(env, fmt.Sprintf("%s=%s", kv.Key, replaceWithReplacer(kv.Value, r)))
			continue
		}
		inherited := os.Getenv(kv.Key)
		if inherited == "" {
			eventManagerLog(logger.LevelDebug, "empty value for environment variable %q", kv.Key)
		}
		env = append(env, fmt.Sprintf("%s=%s", kv.Key, inherited))
	}
	cmd.Env = env

	started := time.Now()
	err := cmd.Run()
	eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v", c.Cmd, time.Since(started), err)
	return err
}

// getEmailAddressesWithReplacer expands placeholders in every address and
// discards the ones that become empty. It returns nil for an empty input.
func getEmailAddressesWithReplacer(addrs []string, replacer *strings.Replacer) []string {
	if len(addrs) == 0 {
		return nil
	}
	expanded := make([]string, 0, len(addrs))
	for idx := range addrs {
		if addr := replaceWithReplacer(addrs[idx], replacer); addr != "" {
			expanded = append(expanded, addr)
		}
	}
	return expanded
}

// executeEmailRuleAction sends the email notification configured in c.
//
// Subject, recipients, Bcc and body have their placeholders expanded from
// params. When c.ContentType is 1, the body uses a separate replacer built with
// getStringReplacements(addObjectData, 2).
// NOTE(review): mode 2 is presumably an HTML-safe variant — confirm.
//
// Each attachment is either the retention report placeholder or a file path
// read from the sender user's filesystem.
func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *EventParams) error {
	addObjectData := false
	if params.Object != nil {
		if strings.Contains(c.Body, objDataPlaceholder) ||
strings.Contains(c.Body, objDataPlaceholderString) { addObjectData = true } } replacements := params.getStringReplacements(addObjectData, 0) replacer := strings.NewReplacer(replacements...) var body string if c.ContentType == 1 { replacements := params.getStringReplacements(addObjectData, 2) bodyReplacer := strings.NewReplacer(replacements...) body = replaceWithReplacer(c.Body, bodyReplacer) } else { body = replaceWithReplacer(c.Body, replacer) } subject := replaceWithReplacer(c.Subject, replacer) recipients := getEmailAddressesWithReplacer(c.Recipients, replacer) bcc := getEmailAddressesWithReplacer(c.Bcc, replacer) startTime := time.Now() var files []*mail.File fileAttachments := make([]string, 0, len(c.Attachments)) for _, attachment := range c.Attachments { if attachment == dataprovider.RetentionReportPlaceHolder { f, err := params.getRetentionReportsAsMailAttachment() if err != nil { return err } files = append(files, f) continue } fileAttachments = append(fileAttachments, attachment) } if len(fileAttachments) > 0 { user, err := params.getUserFromSender() if err != nil { return err } if err := getUserForEventAction(&user); err != nil { return err } connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) err = user.CheckFsRoot(connectionID) defer user.CloseFs() //nolint:errcheck if err != nil { return fmt.Errorf("error getting email attachments, unable to check root fs for user %q: %w", user.Username, err) } conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) defer conn.CloseFS() //nolint:errcheck res, err := getMailAttachments(conn, fileAttachments, replacer) if err != nil { return err } files = append(files, res...) } err := smtp.SendEmail(recipients, bcc, subject, body, smtp.EmailContentType(c.ContentType), files...) 
eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v", time.Since(startTime), err)
	if err != nil {
		return fmt.Errorf("unable to send email: %w", err)
	}
	return nil
}

// getUserForEventAction prepares user for filesystem operations run by event
// actions. It applies the user's group settings, then:
//   - clears the upload data transfer limit and the upload/download bandwidth
//     limits,
//   - re-enables filesystem checks,
//   - drops file pattern and bandwidth limit filters,
//   - replaces every entry of user.Permissions with dataprovider.PermAny.
//
// NOTE(review): Permissions is a map, so the new permission slices are also
// visible through any other copy of this User that shares the same map.
func getUserForEventAction(user *dataprovider.User) error {
	err := user.LoadAndApplyGroupSettings()
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to get group for user %q: %+v", user.Username, err)
		return fmt.Errorf("unable to get groups for user %q", user.Username)
	}
	user.UploadDataTransfer = 0
	user.UploadBandwidth = 0
	user.DownloadBandwidth = 0
	user.Filters.DisableFsChecks = false
	user.Filters.FilePatterns = nil
	user.Filters.BandwidthLimits = nil
	for k := range user.Permissions {
		user.Permissions[k] = []string{dataprovider.PermAny}
	}
	return nil
}

// replacePathsPlaceholders expands placeholders in each path, normalizes the
// results with util.CleanPath and removes duplicates.
func replacePathsPlaceholders(paths []string, replacer *strings.Replacer) []string {
	results := make([]string, 0, len(paths))
	for _, p := range paths {
		results = append(results, util.CleanPath(replaceWithReplacer(p, replacer)))
	}
	return util.RemoveDuplicates(results, false)
}

// executeDeleteFileFsAction removes the file at virtual path item, using the
// filesystem and resolved path that conn returns for it. info is the entry's
// already obtained file info.
func executeDeleteFileFsAction(conn *BaseConnection, item string, info os.FileInfo) error {
	fs, fsPath, err := conn.GetFsAndResolvedPath(item)
	if err != nil {
		return err
	}
	return conn.RemoveFile(fs, fsPath, item, info)
}

// executeDeleteFsActionForUser deletes the given paths (after placeholder
// expansion) on behalf of user. Paths that do not exist are skipped.
// Directories are removed with conn.RemoveDir and other entries as files.
// The first failure stops the loop and is returned.
func executeDeleteFsActionForUser(deletes []string, replacer *strings.Replacer, user dataprovider.User) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	// deferred before the error check so the fs is closed on every path
	defer user.CloseFs() //nolint:errcheck
	if err != nil {
		return fmt.Errorf("delete error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
	defer conn.CloseFS() //nolint:errcheck
	for _, item := range replacePathsPlaceholders(deletes, replacer) {
		info, err := conn.DoStat(item, 0, false)
		if err != nil {
			if conn.IsNotExistError(err) {
				continue
			}
return fmt.Errorf("unable to check item to delete %q, user %q: %w", item, user.Username, err)
		}
		if info.IsDir() {
			if err = conn.RemoveDir(item); err != nil {
				return fmt.Errorf("unable to remove dir %q, user %q: %w", item, user.Username, err)
			}
		} else {
			if err = executeDeleteFileFsAction(conn, item, info); err != nil {
				return fmt.Errorf("unable to remove file %q, user %q: %w", item, user.Username, err)
			}
		}
		eventManagerLog(logger.LevelDebug, "item %q removed for user %q", item, user.Username)
	}
	return nil
}

// executeDeleteFsRuleAction runs the delete filesystem action for every user
// returned by params.getUsers. When params.sender is empty, users are first
// filtered by conditions. Each per-user error is recorded with params.AddError.
// The returned error lists the failed users, or reports that no user was
// processed at all.
func executeDeleteFsRuleAction(deletes []string, replacer *strings.Replacer, conditions dataprovider.ConditionOptions,
	params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping fs delete for user %s, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if err = executeDeleteFsActionForUser(deletes, replacer, user); err != nil {
			params.AddError(err)
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("fs delete failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no delete executed")
		return errors.New("no delete executed")
	}
	return nil
}

// executeMkDirsFsActionForUser creates the given directories (after
// placeholder expansion) on behalf of user. For each item it first calls
// conn.CheckParentDirs on the parent directory, then creates the item if it is
// missing. The first failure is returned.
func executeMkDirsFsActionForUser(dirs []string, replacer *strings.Replacer, user dataprovider.User) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	defer user.CloseFs() //nolint:errcheck
	if err != nil {
		return fmt.Errorf("mkdir error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "",
user) defer conn.CloseFS() //nolint:errcheck for _, item := range replacePathsPlaceholders(dirs, replacer) { if err = conn.CheckParentDirs(path.Dir(item)); err != nil { return fmt.Errorf("unable to check parent dirs for %q, user %q: %w", item, user.Username, err) } if err = conn.createDirIfMissing(item); err != nil { return fmt.Errorf("unable to create dir %q, user %q: %w", item, user.Username, err) } eventManagerLog(logger.LevelDebug, "directory %q created for user %q", item, user.Username) } return nil } func executeMkdirFsRuleAction(dirs []string, replacer *strings.Replacer, conditions dataprovider.ConditionOptions, params *EventParams, ) error { users, err := params.getUsers() if err != nil { return fmt.Errorf("unable to get users: %w", err) } var failures []string executed := 0 for _, user := range users { // if sender is set, the conditions have already been evaluated if params.sender == "" { if !checkUserConditionOptions(&user, &conditions) { eventManagerLog(logger.LevelDebug, "skipping fs mkdir for user %s, condition options don't match", user.Username) continue } } executed++ if err = executeMkDirsFsActionForUser(dirs, replacer, user); err != nil { failures = append(failures, user.Username) } } if len(failures) > 0 { return fmt.Errorf("fs mkdir failed for users: %s", strings.Join(failures, ", ")) } if executed == 0 { eventManagerLog(logger.LevelError, "no mkdir executed") return errors.New("no mkdir executed") } return nil } func executeRenameFsActionForUser(renames []dataprovider.RenameConfig, replacer *strings.Replacer, user dataprovider.User, ) error { if err := getUserForEventAction(&user); err != nil { return err } connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) err := user.CheckFsRoot(connectionID) defer user.CloseFs() //nolint:errcheck if err != nil { return fmt.Errorf("rename error, unable to check root fs for user %q: %w", user.Username, err) } conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) 
defer conn.CloseFS() //nolint:errcheck
	for _, item := range renames {
		source := util.CleanPath(replaceWithReplacer(item.Key, replacer))
		target := util.CleanPath(replaceWithReplacer(item.Value, replacer))
		checks := 0
		if item.UpdateModTime {
			checks += vfs.CheckUpdateModTime
		}
		if err = conn.renameInternal(source, target, true, checks); err != nil {
			return fmt.Errorf("unable to rename %q->%q, user %q: %w", source, target, user.Username, err)
		}
		eventManagerLog(logger.LevelDebug, "rename %q->%q ok, user %q", source, target, user.Username)
	}
	return nil
}

// executeCopyFsActionForUser copies each Key path to its Value path on behalf
// of user. Both paths have placeholders expanded and are cleaned. The first
// failure is returned.
func executeCopyFsActionForUser(keyVals []dataprovider.KeyValue, replacer *strings.Replacer,
	user dataprovider.User,
) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	defer user.CloseFs() //nolint:errcheck
	if err != nil {
		return fmt.Errorf("copy error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
	defer conn.CloseFS() //nolint:errcheck
	for _, item := range keyVals {
		source := util.CleanPath(replaceWithReplacer(item.Key, replacer))
		target := util.CleanPath(replaceWithReplacer(item.Value, replacer))
		// Re-append a trailing slash when the configured (unexpanded) path had
		// one. CleanPath presumably strips it, and conn.Copy may rely on it —
		// confirm against Copy's semantics.
		if strings.HasSuffix(item.Key, "/") {
			source += "/"
		}
		if strings.HasSuffix(item.Value, "/") {
			target += "/"
		}
		if err = conn.Copy(source, target); err != nil {
			return fmt.Errorf("unable to copy %q->%q, user %q: %w", source, target, user.Username, err)
		}
		eventManagerLog(logger.LevelDebug, "copy %q->%q ok, user %q", source, target, user.Username)
	}
	return nil
}

// executeExistFsActionForUser checks, on behalf of user, that every given path
// (after placeholder expansion) can be stat'ed. The first failure is returned.
func executeExistFsActionForUser(exist []string, replacer *strings.Replacer,
	user dataprovider.User,
) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	defer user.CloseFs() //nolint:errcheck
	if err != nil
{
		return fmt.Errorf("existence check error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
	defer conn.CloseFS() //nolint:errcheck
	for _, item := range replacePathsPlaceholders(exist, replacer) {
		if _, err = conn.DoStat(item, 0, false); err != nil {
			return fmt.Errorf("error checking existence for path %q, user %q: %w", item, user.Username, err)
		}
		eventManagerLog(logger.LevelDebug, "path %q exists for user %q", item, user.Username)
	}
	return nil
}

// executeRenameFsRuleAction runs the rename filesystem action for every user
// returned by params.getUsers. When params.sender is empty, users are first
// filtered by conditions. Each per-user error is recorded with params.AddError.
// The returned error lists the failed users, or reports that no user was
// processed at all.
func executeRenameFsRuleAction(renames []dataprovider.RenameConfig, replacer *strings.Replacer,
	conditions dataprovider.ConditionOptions, params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping fs rename for user %s, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if err = executeRenameFsActionForUser(renames, replacer, user); err != nil {
			failures = append(failures, user.Username)
			params.AddError(err)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("fs rename failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no rename executed")
		return errors.New("no rename executed")
	}
	return nil
}

// executeCopyFsRuleAction runs the copy filesystem action for every user
// returned by params.getUsers. When params.sender is empty, users are first
// filtered by conditions. Each per-user error is recorded with params.AddError.
// The returned error lists the failed users, or reports that no user was
// processed at all.
func executeCopyFsRuleAction(keyVals []dataprovider.KeyValue, replacer *strings.Replacer,
	conditions dataprovider.ConditionOptions, params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	var executed int
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user,
&conditions) { eventManagerLog(logger.LevelDebug, "skipping fs copy for user %s, condition options don't match", user.Username) continue } } executed++ if err = executeCopyFsActionForUser(keyVals, replacer, user); err != nil { failures = append(failures, user.Username) params.AddError(err) } } if len(failures) > 0 { return fmt.Errorf("fs copy failed for users: %s", strings.Join(failures, ", ")) } if executed == 0 { eventManagerLog(logger.LevelError, "no copy executed") return errors.New("no copy executed") } return nil } func getArchiveBaseDir(paths []string) string { var parentDirs []string for _, p := range paths { parentDirs = append(parentDirs, path.Dir(p)) } parentDirs = util.RemoveDuplicates(parentDirs, false) baseDir := "/" if len(parentDirs) == 1 { baseDir = parentDirs[0] } return baseDir } func getSizeForPath(conn *BaseConnection, p string, info os.FileInfo) (int64, error) { if info.IsDir() { var dirSize int64 lister, err := conn.ListDir(p) if err != nil { return 0, err } defer lister.Close() for { entries, err := lister.Next(vfs.ListerBatchSize) finished := errors.Is(err, io.EOF) if err != nil && !finished { return 0, err } for _, entry := range entries { size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry) if err != nil { return 0, err } dirSize += size } if finished { return dirSize, nil } } } if info.Mode().IsRegular() { return info.Size(), nil } return 0, nil } func estimateZipSize(conn *BaseConnection, zipPath string, paths []string) (int64, error) { q, _ := conn.HasSpace(false, false, zipPath) if q.HasSpace && q.GetRemainingSize() > 0 { var size int64 for _, item := range paths { info, err := conn.DoStat(item, 1, false) if err != nil { return size, err } itemSize, err := getSizeForPath(conn, item, info) if err != nil { return size, err } size += itemSize } eventManagerLog(logger.LevelDebug, "archive paths %v, archive name %q, size: %d", paths, zipPath, size) // we assume the zip size will be half of the real size return size / 2, nil 
}
	return -1, nil
}

// executeCompressFsActionForUser creates, on behalf of user, a zip archive
// named c.Name (placeholders expanded, path cleaned). The archive contains
// c.Paths (placeholders expanded, cleaned, deduplicated), and the archive
// itself may not be one of those paths.
//
// The archive is written through getFileWriter, using the size estimate from
// estimateZipSize. closeWriterAndUpdateQuota finalizes the writer on both the
// failure and success paths.
func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replacer *strings.Replacer,
	user dataprovider.User,
) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	defer user.CloseFs() //nolint:errcheck
	if err != nil {
		return fmt.Errorf("compress error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
	defer conn.CloseFS() //nolint:errcheck
	name := util.CleanPath(replaceWithReplacer(c.Name, replacer))
	// the error is intentionally ignored here (see nolint)
	conn.CheckParentDirs(path.Dir(name)) //nolint:errcheck
	paths := make([]string, 0, len(c.Paths))
	for idx := range c.Paths {
		p := util.CleanPath(replaceWithReplacer(c.Paths[idx], replacer))
		if p == name {
			return fmt.Errorf("cannot compress the archive to create: %q", name)
		}
		paths = append(paths, p)
	}
	paths = util.RemoveDuplicates(paths, false)
	estimatedSize, err := estimateZipSize(conn, name, paths)
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to estimate size for archive %q: %v", name, err)
		return fmt.Errorf("unable to estimate archive size: %w", err)
	}
	writer, numFiles, truncatedSize, cancelFn, err := getFileWriter(conn, name, estimatedSize)
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to create archive %q: %v", name, err)
		return fmt.Errorf("unable to create archive: %w", err)
	}
	defer cancelFn()
	baseDir := getArchiveBaseDir(paths)
	eventManagerLog(logger.LevelDebug, "creating archive %q for paths %+v", name, paths)
	zipWriter := &zipWriterWrapper{
		Name:    name,
		Writer:  zip.NewWriter(writer),
		Entries: make(map[string]bool),
	}
	startTime := time.Now()
	for _, item := range paths {
		if err := addZipEntry(zipWriter, conn, item, baseDir, nil, 0); err != nil {
			closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck
			return err
}
	}
	if err := zipWriter.Writer.Close(); err != nil {
		eventManagerLog(logger.LevelError, "unable to close zip file %q: %v", name, err)
		closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck
		return fmt.Errorf("unable to close zip file %q: %w", name, err)
	}
	// err here is the (nil) error returned by getFileWriter: the archive was
	// written successfully.
	return closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime)
}

// executeExistFsRuleAction runs the existence check filesystem action for
// every user returned by params.getUsers. When params.sender is empty, users
// are first filtered by conditions. Each per-user error is recorded with
// params.AddError. The returned error lists the failed users, or reports that
// no check was executed.
func executeExistFsRuleAction(exist []string, replacer *strings.Replacer, conditions dataprovider.ConditionOptions,
	params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping fs exist for user %s, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if err = executeExistFsActionForUser(exist, replacer, user); err != nil {
			failures = append(failures, user.Username)
			params.AddError(err)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("fs existence check failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no existence check executed")
		return errors.New("no existence check executed")
	}
	return nil
}

// executeCompressFsRuleAction runs the compress filesystem action for every
// user returned by params.getUsers. When params.sender is empty, users are
// first filtered by conditions. Each per-user error is recorded with
// params.AddError. The returned error lists the failed users, or reports that
// nothing was compressed.
func executeCompressFsRuleAction(c dataprovider.EventActionFsCompress, replacer *strings.Replacer,
	conditions dataprovider.ConditionOptions, params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping fs compress for user %s, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if err = executeCompressFsActionForUser(c, replacer, user); err != nil {
			failures = append(failures, user.Username)
			params.AddError(err)
		}
	}
	if len(failures) > 0 {
		// use the same ", " separator as every other rule action
		return fmt.Errorf("fs compress failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no file/folder compressed")
		return errors.New("no file/folder compressed")
	}
	return nil
}

// executeFsRuleAction dispatches the filesystem action c to the handler for
// its type. Object data is never included in the placeholder replacements.
// Unknown action types return an error.
func executeFsRuleAction(c dataprovider.EventActionFilesystemConfig, conditions dataprovider.ConditionOptions,
	params *EventParams,
) error {
	addObjectData := false
	replacements := params.getStringReplacements(addObjectData, 0)
	replacer := strings.NewReplacer(replacements...)
	switch c.Type {
	case dataprovider.FilesystemActionRename:
		return executeRenameFsRuleAction(c.Renames, replacer, conditions, params)
	case dataprovider.FilesystemActionDelete:
		return executeDeleteFsRuleAction(c.Deletes, replacer, conditions, params)
	case dataprovider.FilesystemActionMkdirs:
		return executeMkdirFsRuleAction(c.MkDirs, replacer, conditions, params)
	case dataprovider.FilesystemActionExist:
		return executeExistFsRuleAction(c.Exist, replacer, conditions, params)
	case dataprovider.FilesystemActionCompress:
		return executeCompressFsRuleAction(c.Compress, replacer, conditions, params)
	case dataprovider.FilesystemActionCopy:
		return executeCopyFsRuleAction(c.Copy, replacer, conditions, params)
	default:
		return fmt.Errorf("unsupported filesystem action %d", c.Type)
	}
}

// executeQuotaResetForUser applies the user's group settings, rescans its used
// quota and stores the result with dataprovider.UpdateUserQuota. The scan is
// refused if another quota scan is already registered for the same user.
func executeQuotaResetForUser(user *dataprovider.User) error {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		eventManagerLog(logger.LevelError, "skipping scheduled quota reset for user %s, cannot apply group settings: %v", user.Username, err)
		return err
	}
	if !QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
		eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %q", user.Username)
		return fmt.Errorf("another quota scan is in progress
for user %q", user.Username)
	}
	defer QuotaScans.RemoveUserQuotaScan(user.Username)
	numFiles, size, err := user.ScanQuota()
	if err != nil {
		eventManagerLog(logger.LevelError, "error scanning quota for user %q: %v", user.Username, err)
		return fmt.Errorf("error scanning quota for user %q: %w", user.Username, err)
	}
	err = dataprovider.UpdateUserQuota(user, numFiles, size, true)
	if err != nil {
		eventManagerLog(logger.LevelError, "error updating quota for user %q: %v", user.Username, err)
		return fmt.Errorf("error updating quota for user %q: %w", user.Username, err)
	}
	return nil
}

// executeUsersQuotaResetRuleAction rescans the used quota for every user
// returned by params.getUsers. When params.sender is empty, users are first
// filtered by conditions. Each per-user error is recorded with params.AddError.
// The returned error lists the failed users, or reports that no reset was
// executed.
func executeUsersQuotaResetRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping quota reset for user %q, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if err = executeQuotaResetForUser(&user); err != nil {
			params.AddError(err)
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("quota reset failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no user quota reset executed")
		return errors.New("no user quota reset executed")
	}
	return nil
}

// executeFoldersQuotaResetRuleAction rescans and stores the used quota of every
// virtual folder returned by params.getFolders. When params.sender is empty,
// folders whose name does not match conditions.Names are skipped.
//
// A folder that already has a quota scan registered is recorded as a failure
// but does not count as executed. Errors are recorded with params.AddError.
func executeFoldersQuotaResetRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	folders, err := params.getFolders()
	if err != nil {
		return fmt.Errorf("unable to get folders: %w", err)
	}
	var failures []string
	executed := 0
	for _, folder := range folders {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" && !checkEventConditionPatterns(folder.Name, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match", folder.Name)
			continue
		}
		if !QuotaScans.AddVFolderQuotaScan(folder.Name) {
			eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %q", folder.Name)
			params.AddError(fmt.Errorf("another quota scan is already in progress for folder %q", folder.Name))
			failures = append(failures, folder.Name)
			continue
		}
		executed++
		// the base folder is scanned through a VirtualFolder mapped at "/"
		f := vfs.VirtualFolder{
			BaseVirtualFolder: folder,
			VirtualPath:       "/",
		}
		numFiles, size, err := f.ScanQuota()
		// the scan registration is released right after scanning, before the
		// result is stored
		QuotaScans.RemoveVFolderQuotaScan(folder.Name)
		if err != nil {
			eventManagerLog(logger.LevelError, "error scanning quota for folder %q: %v", folder.Name, err)
			params.AddError(fmt.Errorf("error scanning quota for folder %q: %w", folder.Name, err))
			failures = append(failures, folder.Name)
			continue
		}
		err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating quota for folder %q: %v", folder.Name, err)
			params.AddError(fmt.Errorf("error updating quota for folder %q: %w", folder.Name, err))
			failures = append(failures, folder.Name)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("quota reset failed for folders: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no folder quota reset executed")
		return errors.New("no folder quota reset executed")
	}
	return nil
}

// executeTransferQuotaResetRuleAction calls
// dataprovider.UpdateUserTransferQuota(&user, 0, 0, true) for every user
// returned by params.getUsers (presumably resetting the used transfer quota —
// confirm). When params.sender is empty, users are first filtered by
// conditions. Errors are recorded with params.AddError.
func executeTransferQuotaResetRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		err =
dataprovider.UpdateUserTransferQuota(&user, 0, 0, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating transfer quota for user %q: %v", user.Username, err)
			params.AddError(fmt.Errorf("error updating transfer quota for user %q: %w", user.Username, err))
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("transfer quota reset failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no transfer quota reset executed")
		return errors.New("no transfer quota reset executed")
	}
	return nil
}

// executeDataRetentionCheckForUser applies the user's group settings and runs
// a retention check on the given folders. The check is refused when
// RetentionChecks.Add returns nil, meaning another check is in progress for
// the user. Once the check has been registered, its results are appended to
// params.retentionChecks (tagged with actionName) even if Start fails.
func executeDataRetentionCheckForUser(user dataprovider.User, folders []dataprovider.FolderRetention,
	params *EventParams, actionName string,
) error {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		eventManagerLog(logger.LevelError, "skipping scheduled retention check for user %s, cannot apply group settings: %v", user.Username, err)
		return err
	}
	check := RetentionCheck{
		Folders: folders,
	}
	c := RetentionChecks.Add(check, &user)
	if c == nil {
		eventManagerLog(logger.LevelError, "another retention check is already in progress for user %q", user.Username)
		return fmt.Errorf("another retention check is in progress for user %q", user.Username)
	}
	// deferred so that c.results is read after Start returns
	defer func() {
		params.retentionChecks = append(params.retentionChecks, executedRetentionCheck{
			Username:   user.Username,
			ActionName: actionName,
			Results:    c.results,
		})
	}()
	if err := c.Start(); err != nil {
		eventManagerLog(logger.LevelError, "error checking retention for user %q: %v", user.Username, err)
		return fmt.Errorf("error checking retention for user %q: %w", user.Username, err)
	}
	return nil
}

// executeDataRetentionCheckRuleAction runs a retention check on config.Folders
// for every user returned by params.getUsers. When params.sender is empty,
// users are first filtered by conditions. Each per-user error is recorded with
// params.AddError. The returned error lists the failed users, or reports that
// no check was executed.
func executeDataRetentionCheckRuleAction(config dataprovider.EventActionDataRetentionConfig,
	conditions dataprovider.ConditionOptions, params *EventParams, actionName string,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range
users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping scheduled retention check for user %s, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if err = executeDataRetentionCheckForUser(user, config.Folders, params, actionName); err != nil {
			failures = append(failures, user.Username)
			params.AddError(err)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("retention check failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no retention check executed")
		return errors.New("no retention check executed")
	}
	return nil
}

// executeUserExpirationCheckRuleAction checks every user returned by
// params.getUsers (filtered by conditions when params.sender is empty). It
// returns an error listing the users whose expiration date is set and in the
// past, or an error when no user was checked at all.
func executeUserExpirationCheckRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	var executed int
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping expiration check for user %q, condition options don't match", user.Username)
				continue
			}
		}
		executed++
		if user.ExpirationDate > 0 {
			expDate := util.GetTimeFromMsecSinceEpoch(user.ExpirationDate)
			if expDate.Before(time.Now()) {
				failures = append(failures, user.Username)
			}
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("expired users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no user expiration check executed")
		return errors.New("no user expiration check executed")
	}
	return nil
}

// executeInactivityCheckForUser deletes or disables user when its number of
// inactive days, as of when, exceeds the configured thresholds.
//
// Deletion is considered only when DeleteThreshold is set and the user is
// already disabled (Status == 0) or no DisableThreshold is configured.
// Disabling is considered only for enabled users when DisableThreshold is set.
//
// A non-nil error is returned both when the operation fails and when it
// succeeds (describing the deletion/disabling), so that the caller can report
// the users that were acted on.
func executeInactivityCheckForUser(user *dataprovider.User, config dataprovider.EventActionUserInactivity, when time.Time) error {
	if config.DeleteThreshold > 0 && (user.Status == 0 || config.DisableThreshold == 0) {
		if inactivityDays := user.InactivityDays(when); inactivityDays > config.DeleteThreshold {
			err := dataprovider.DeleteUser(user.Username, dataprovider.ActionExecutorSystem, "", "")
			eventManagerLog(logger.LevelInfo, "deleting inactive user %q, days of inactivity: %d/%d, err: %v", user.Username, inactivityDays, config.DeleteThreshold, err)
			if err != nil {
				return fmt.Errorf("unable to delete inactive user %q", user.Username)
			}
			return fmt.Errorf("inactive user %q deleted. Number of days of inactivity: %d", user.Username, inactivityDays)
		}
	}
	if config.DisableThreshold > 0 && user.Status > 0 {
		if inactivityDays := user.InactivityDays(when); inactivityDays > config.DisableThreshold {
			user.Status = 0
			err := dataprovider.UpdateUser(user, dataprovider.ActionExecutorSystem, "", "")
			eventManagerLog(logger.LevelInfo, "disabling inactive user %q, days of inactivity: %d/%d, err: %v", user.Username, inactivityDays, config.DisableThreshold, err)
			if err != nil {
				return fmt.Errorf("unable to disable inactive user %q", user.Username)
			}
			return fmt.Errorf("inactive user %q disabled. Number of days of inactivity: %d", user.Username, inactivityDays)
		}
	}
	return nil
}

// executeUserInactivityCheckRuleAction runs executeInactivityCheckForUser for
// every user returned by params.getUsers (filtered by conditions when
// params.sender is empty). Every non-nil result is recorded with
// params.AddError. Unlike the other rule actions, no error is returned when no
// user matched.
func executeUserInactivityCheckRuleAction(config dataprovider.EventActionUserInactivity,
	conditions dataprovider.ConditionOptions,
	params *EventParams,
	when time.Time,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping inactivity check for user %q, condition options don't match", user.Username)
				continue
			}
		}
		if err = executeInactivityCheckForUser(&user, config, when); err != nil {
			params.AddError(err)
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("executed inactivity check actions for users: %s", strings.Join(failures, ", "))
	}
	return nil
}

// executePwdExpirationCheckForUser emails user when its password expires
// within config.Threshold days. Expired users, and users without password
// expiration, are skipped without error. The email body is rendered with
// smtp.RenderPasswordExpirationTemplate and sent as HTML to
// user.GetEmailAddresses().
func executePwdExpirationCheckForUser(user *dataprovider.User, config dataprovider.EventActionPasswordExpiration) error {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		eventManagerLog(logger.LevelError, "skipping password expiration check for user %q, cannot apply group settings: %v", user.Username, err)
		return err
	}
	if user.ExpirationDate > 0 {
		if expDate := util.GetTimeFromMsecSinceEpoch(user.ExpirationDate); expDate.Before(time.Now()) {
			eventManagerLog(logger.LevelDebug, "skipping password expiration check for expired user %q, expiration date: %s", user.Username, expDate)
			return nil
		}
	}
	if user.Filters.PasswordExpiration == 0 {
		eventManagerLog(logger.LevelDebug, "password expiration not set for user %q skipping check", user.Username)
		return nil
	}
	days := user.PasswordExpiresIn()
	if days > config.Threshold {
		eventManagerLog(logger.LevelDebug, "password for user %q expires in %d days, threshold %d, no need to notify", user.Username, days, config.Threshold)
		return nil
	}
	body := new(bytes.Buffer)
	data :=
make(map[string]any) data["Username"] = user.Username data["Days"] = days if err := smtp.RenderPasswordExpirationTemplate(body, data); err != nil { eventManagerLog(logger.LevelError, "unable to notify password expiration for user %s: %v", user.Username, err) return err } subject := "SFTPGo password expiration notification" startTime := time.Now() if err := smtp.SendEmail(user.GetEmailAddresses(), nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil { eventManagerLog(logger.LevelError, "unable to notify password expiration for user %s: %v, elapsed: %s", user.Username, err, time.Since(startTime)) return err } eventManagerLog(logger.LevelDebug, "password expiration email sent to user %s, days: %d, elapsed: %s", user.Username, days, time.Since(startTime)) return nil } func executePwdExpirationCheckRuleAction(config dataprovider.EventActionPasswordExpiration, conditions dataprovider.ConditionOptions, params *EventParams) error { users, err := params.getUsers() if err != nil { return fmt.Errorf("unable to get users: %w", err) } var failures []string for _, user := range users { // if sender is set, the conditions have already been evaluated if params.sender == "" { if !checkUserConditionOptions(&user, &conditions) { eventManagerLog(logger.LevelDebug, "skipping password check for user %q, condition options don't match", user.Username) continue } } if err = executePwdExpirationCheckForUser(&user, config); err != nil { params.AddError(err) failures = append(failures, user.Username) } } if len(failures) > 0 { return fmt.Errorf("password expiration check failed for users: %s", strings.Join(failures, ", ")) } return nil } func executeAdminCheckAction(c *dataprovider.EventActionIDPAccountCheck, params *EventParams) (*dataprovider.Admin, error) { admin, err := dataprovider.AdminExists(params.Name) exists := err == nil if exists && c.Mode == 1 { return &admin, nil } if err != nil && !errors.Is(err, util.ErrNotFound) { return nil, err } replacements := 
params.getStringReplacements(false, 1) replacer := strings.NewReplacer(replacements...) data := replaceWithReplacer(c.TemplateAdmin, replacer) var newAdmin dataprovider.Admin err = json.Unmarshal(util.StringToBytes(data), &newAdmin) if err != nil { return nil, err } if exists { eventManagerLog(logger.LevelDebug, "updating admin %q after IDP login", params.Name) // Not sure if this makes sense, but it shouldn't hurt. if newAdmin.Password == "" { newAdmin.Password = admin.Password } newAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig newAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes err = dataprovider.UpdateAdmin(&newAdmin, dataprovider.ActionExecutorSystem, "", "") } else { eventManagerLog(logger.LevelDebug, "creating admin %q after IDP login", params.Name) if newAdmin.Password == "" { newAdmin.Password = util.GenerateUniqueID() } err = dataprovider.AddAdmin(&newAdmin, dataprovider.ActionExecutorSystem, "", "") } return &newAdmin, err } func preserveUserProfile(user, newUser *dataprovider.User) { if newUser.CanChangePassword() && user.Password != "" { newUser.Password = user.Password } if newUser.CanManagePublicKeys() && len(user.PublicKeys) > 0 { newUser.PublicKeys = user.PublicKeys } if newUser.CanManageTLSCerts() { if len(user.Filters.TLSCerts) > 0 { newUser.Filters.TLSCerts = user.Filters.TLSCerts } } if newUser.CanChangeInfo() { if user.Description != "" { newUser.Description = user.Description } if user.Email != "" { newUser.Email = user.Email } if len(user.Filters.AdditionalEmails) > 0 { newUser.Filters.AdditionalEmails = user.Filters.AdditionalEmails } } if newUser.CanChangeAPIKeyAuth() { newUser.Filters.AllowAPIKeyAuth = user.Filters.AllowAPIKeyAuth } newUser.Filters.RecoveryCodes = user.Filters.RecoveryCodes newUser.Filters.TOTPConfig = user.Filters.TOTPConfig newUser.LastPasswordChange = user.LastPasswordChange newUser.SetEmptySecretsIfNil() } func executeUserCheckAction(c *dataprovider.EventActionIDPAccountCheck, params *EventParams) 
(*dataprovider.User, error) {
	user, err := dataprovider.UserExists(params.Name, "")
	exists := err == nil
	// with Mode 1 an existing user is returned as is, only group settings are applied
	if exists && c.Mode == 1 {
		err = user.LoadAndApplyGroupSettings()
		return &user, err
	}
	if err != nil && !errors.Is(err, util.ErrNotFound) {
		return nil, err
	}
	// render the user template replacing the event placeholders
	replacements := params.getStringReplacements(false, 1)
	replacer := strings.NewReplacer(replacements...)
	data := replaceWithReplacer(c.TemplateUser, replacer)
	var newUser dataprovider.User
	err = json.Unmarshal(util.StringToBytes(data), &newUser)
	if err != nil {
		return nil, err
	}
	if exists {
		eventManagerLog(logger.LevelDebug, "updating user %q after IDP login", params.Name)
		// carry over the profile data of the stored user
		preserveUserProfile(&user, &newUser)
		err = dataprovider.UpdateUser(&newUser, dataprovider.ActionExecutorSystem, "", "")
	} else {
		eventManagerLog(logger.LevelDebug, "creating user %q after IDP login", params.Name)
		err = dataprovider.AddUser(&newUser, dataprovider.ActionExecutorSystem, "", "")
	}
	if err != nil {
		return nil, err
	}
	// reload the stored user, with group settings applied, by params.Name
	// NOTE(review): assumes the template renders a username equal to params.Name — confirm
	u, err := dataprovider.GetUserWithGroupSettings(params.Name, "")
	return &u, err
}

// executeRuleAction executes a single action. It is a no-op when
// conditions.EventStatuses is set and does not contain params.Status.
// A non-nil error is wrapped with the action name; the result is recorded
// with params.AddError and returned.
func executeRuleAction(action dataprovider.BaseEventAction, params *EventParams, //nolint:gocyclo
	conditions dataprovider.ConditionOptions,
) error {
	if len(conditions.EventStatuses) > 0 && !slices.Contains(conditions.EventStatuses, params.Status) {
		eventManagerLog(logger.LevelDebug, "skipping action %s, event status %d does not match: %v",
			action.Name, params.Status, conditions.EventStatuses)
		return nil
	}
	var err error
	switch action.Type {
	case dataprovider.ActionTypeHTTP:
		err = executeHTTPRuleAction(action.Options.HTTPConfig, params)
	case dataprovider.ActionTypeCommand:
		err = executeCommandRuleAction(action.Options.CmdConfig, params)
	case dataprovider.ActionTypeEmail:
		err = executeEmailRuleAction(action.Options.EmailConfig, params)
	case dataprovider.ActionTypeBackup:
		var backupPath string
		backupPath, err = dataprovider.ExecuteBackup()
		// the backup path is exposed to the following actions only on success
		if err == nil {
			params.setBackupParams(backupPath)
		}
	case dataprovider.ActionTypeUserQuotaReset:
		err =
executeUsersQuotaResetRuleAction(conditions, params)
	case dataprovider.ActionTypeFolderQuotaReset:
		err = executeFoldersQuotaResetRuleAction(conditions, params)
	case dataprovider.ActionTypeTransferQuotaReset:
		err = executeTransferQuotaResetRuleAction(conditions, params)
	case dataprovider.ActionTypeDataRetentionCheck:
		err = executeDataRetentionCheckRuleAction(action.Options.RetentionConfig, conditions, params, action.Name)
	case dataprovider.ActionTypeFilesystem:
		err = executeFsRuleAction(action.Options.FsConfig, conditions, params)
	case dataprovider.ActionTypePasswordExpirationCheck:
		err = executePwdExpirationCheckRuleAction(action.Options.PwdExpirationConfig, conditions, params)
	case dataprovider.ActionTypeUserExpirationCheck:
		err = executeUserExpirationCheckRuleAction(conditions, params)
	case dataprovider.ActionTypeUserInactivityCheck:
		err = executeUserInactivityCheckRuleAction(action.Options.UserInactivityConfig, conditions, params, time.Now())
	case dataprovider.ActionTypeRotateLogs:
		err = logger.RotateLogFile()
	default:
		err = fmt.Errorf("unsupported action type: %d", action.Type)
	}
	if err != nil {
		err = fmt.Errorf("action %q failed: %w", action.Name, err)
	}
	// NOTE(review): AddError is also called with a nil error; presumably it ignores nil — confirm
	params.AddError(err)
	return err
}

// executeIDPAccountCheckRule executes the first action of type
// ActionTypeIDPAccountCheck found in rule for an IDP admin or user login
// event and returns the resulting admin or user. The remaining asynchronous
// and failure actions of the rule are started in a new goroutine.
func executeIDPAccountCheckRule(rule dataprovider.EventRule, params EventParams) (*dataprovider.User, *dataprovider.Admin,
	error,
) {
	for _, action := range rule.Actions {
		if action.Type == dataprovider.ActionTypeIDPAccountCheck {
			startTime := time.Now()
			var user *dataprovider.User
			var admin *dataprovider.Admin
			var err error
			var failedActions []string
			paramsCopy := params.getACopy()
			switch params.Event {
			case IDPLoginAdmin:
				admin, err = executeAdminCheckAction(&action.BaseEventAction.Options.IDPConfig, paramsCopy)
			case IDPLoginUser:
				user, err = executeUserCheckAction(&action.BaseEventAction.Options.IDPConfig, paramsCopy)
			default:
				err = fmt.Errorf("unsupported IDP login event: %q", params.Event)
			}
			if err != nil {
				paramsCopy.AddError(fmt.Errorf("unable to handle %q: %w",
params.Event, err)) eventManagerLog(logger.LevelError, "unable to handle IDP login event %q, err: %v", params.Event, err) failedActions = append(failedActions, action.Name) } else { eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s", action.Name, rule.Name, time.Since(startTime)) } // execute async actions if any, including failure actions go executeRuleAsyncActions(rule, paramsCopy, failedActions) return user, admin, err } } eventManagerLog(logger.LevelError, "no action executed for IDP login event %q, event rule: %q", params.Event, rule.Name) return nil, nil, errors.New("no action executed") } func executeSyncRulesActions(rules []dataprovider.EventRule, params EventParams) error { var errRes error for _, rule := range rules { var failedActions []string paramsCopy := params.getACopy() for _, action := range rule.Actions { if !action.Options.IsFailureAction && action.Options.ExecuteSync { startTime := time.Now() if err := executeRuleAction(action.BaseEventAction, paramsCopy, rule.Conditions.Options); err != nil { eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v", action.Name, rule.Name, time.Since(startTime), err) failedActions = append(failedActions, action.Name) // we return the last error, it is ok for now errRes = err if action.Options.StopOnFailure { break } } else { eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s", action.Name, rule.Name, time.Since(startTime)) } } } // execute async actions if any, including failure actions go executeRuleAsyncActions(rule, paramsCopy, failedActions) } return errRes } func executeAsyncRulesActions(rules []dataprovider.EventRule, params EventParams) { eventManager.addAsyncTask() defer eventManager.removeAsyncTask() params.addUID() for _, rule := range rules { executeRuleAsyncActions(rule, params.getACopy(), nil) } } func executeRuleAsyncActions(rule dataprovider.EventRule, params *EventParams, failedActions 
[]string) {
	// run the asynchronous (non-sync, non-failure) actions in order
	for _, action := range rule.Actions {
		if !action.Options.IsFailureAction && !action.Options.ExecuteSync {
			startTime := time.Now()
			if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
				eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v",
					action.Name, rule.Name, time.Since(startTime), err)
				failedActions = append(failedActions, action.Name)
				// stop running the remaining non-failure actions; failure actions still run below
				if action.Options.StopOnFailure {
					break
				}
			} else {
				eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
					action.Name, rule.Name, time.Since(startTime))
			}
		}
	}
	// failedActions may already contain failures reported by the caller (e.g. sync actions)
	if len(failedActions) > 0 {
		// NOTE(review): presumably keeps failure-action errors from updating the event status — confirm
		params.updateStatusFromError = false
		// execute failure actions
		for _, action := range rule.Actions {
			if action.Options.IsFailureAction {
				startTime := time.Now()
				if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
	}
}

// eventCronJob is the scheduled job for the event rule named ruleName.
type eventCronJob struct {
	ruleName string
}

// getTask returns the task used to coordinate rules guarded from concurrent
// execution, adding it to the provider when it does not exist yet. Other
// errors reading the task are logged and returned. For rules not guarded from
// concurrent execution an empty task and a nil error are returned.
func (j *eventCronJob) getTask(rule *dataprovider.EventRule) (dataprovider.Task, error) {
	if rule.GuardFromConcurrentExecution() {
		task, err := dataprovider.GetTaskByName(rule.Name)
		if err != nil {
			if errors.Is(err, util.ErrNotFound) {
				eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
				task = dataprovider.Task{
					Name:     rule.Name,
					UpdateAt: 0,
					Version:  0,
				}
				err = dataprovider.AddTask(rule.Name)
				if err != nil {
					eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err)
					return task, err
				}
			} else {
				eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err)
			}
		}
		return task, err
	}
	return
dataprovider.Task{}, nil
}

// getEventParams returns the event parameters for a scheduled execution of
// the rule: event "Schedule", the rule name as Name, status 1 and the current
// time as timestamp.
func (j *eventCronJob) getEventParams() EventParams {
	return EventParams{
		Event:                 "Schedule",
		Name:                  j.ruleName,
		Status:                1,
		Timestamp:             time.Now(),
		updateStatusFromError: true,
	}
}

// Run loads the rule named j.ruleName and executes its asynchronous actions.
//
// Rules with inconsistent actions are skipped. When getTask returns a named
// task, execution is coordinated through it:
//   - the run is skipped if the task timestamp was refreshed within about two
//     update intervals;
//   - otherwise the task is updated using the version read from the provider;
//   - a worker goroutine then refreshes the task timestamp every
//     updateInterval until the actions complete.
func (j *eventCronJob) Run() {
	eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName)
	rule, err := dataprovider.EventRuleExists(j.ruleName)
	if err != nil {
		// log the cause too, otherwise a provider failure looks like a missing rule
		eventManagerLog(logger.LevelError, "unable to load rule with name %q: %v", j.ruleName, err)
		return
	}
	if err := rule.CheckActionsConsistency(""); err != nil {
		eventManagerLog(logger.LevelWarn, "scheduled rule %q skipped: %v", rule.Name, err)
		return
	}
	task, err := j.getTask(&rule)
	if err != nil {
		return
	}
	if task.Name != "" {
		updateInterval := 5 * time.Minute
		updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt)
		// The worker below refreshes the timestamp every updateInterval while the
		// actions run, so a recent timestamp presumably means another execution
		// is still in progress.
		if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) {
			eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt)
			return
		}
		err = dataprovider.UpdateTask(rule.Name, task.Version)
		if err != nil {
			eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v",
				rule.Name, err)
			return
		}
		ticker := time.NewTicker(updateInterval)
		done := make(chan bool)
		// The deferred call runs after the actions below. Its unbuffered send
		// blocks until the worker receives it, so no timestamp refresh can happen
		// after Run returns.
		defer func() {
			done <- true
			ticker.Stop()
		}()
		go func(taskName string) {
			eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName)
			for {
				select {
				case <-done:
					eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName)
					return
				case <-ticker.C:
					err := dataprovider.UpdateTaskTimestamp(taskName)
					eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err)
				}
			}
		}(task.Name)
	}
	executeAsyncRulesActions([]dataprovider.EventRule{rule}, j.getEventParams())
	eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName)
}

// RunOnDemandRule executes actions for a
rule with on-demand trigger func RunOnDemandRule(name string) error { eventManagerLog(logger.LevelDebug, "executing on demand rule %q", name) rule, err := dataprovider.EventRuleExists(name) if err != nil { eventManagerLog(logger.LevelDebug, "unable to load rule with name %q", name) return util.NewRecordNotFoundError(fmt.Sprintf("rule %q does not exist", name)) } if rule.Trigger != dataprovider.EventTriggerOnDemand { eventManagerLog(logger.LevelDebug, "cannot run rule %q as on demand, trigger: %d", name, rule.Trigger) return util.NewValidationError(fmt.Sprintf("rule %q is not defined as on-demand", name)) } if rule.Status != 1 { eventManagerLog(logger.LevelDebug, "on-demand rule %q is inactive", name) return util.NewValidationError(fmt.Sprintf("rule %q is inactive", name)) } if err := rule.CheckActionsConsistency(""); err != nil { eventManagerLog(logger.LevelError, "on-demand rule %q has incompatible actions: %v", name, err) return util.NewValidationError(fmt.Sprintf("rule %q has incosistent actions", name)) } eventManagerLog(logger.LevelDebug, "on-demand rule %q started", name) go executeAsyncRulesActions([]dataprovider.EventRule{rule}, EventParams{Status: 1, updateStatusFromError: true}) return nil } type zipWriterWrapper struct { Name string Entries map[string]bool Writer *zip.Writer } func eventManagerLog(level logger.LogLevel, format string, v ...any) { logger.Log(level, "eventmanager", "", format, v...) } ================================================ FILE: internal/common/eventmanager_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/klauspost/compress/zip"
	"github.com/rs/xid"
	"github.com/sftpgo/sdk"
	sdkkms "github.com/sftpgo/sdk/kms"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// TestEventRuleMatch checks the provider, filesystem, user and IDP login
// condition matching helpers against a range of event parameters.
func TestEventRuleMatch(t *testing.T) {
	role := "role1"
	// provider events: "user1" is excluded by an inverse pattern and the role must be role1
	conditions := &dataprovider.EventConditions{
		ProviderEvents: []string{"add", "update"},
		Options: dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: "user1",
					InverseMatch: true,
				},
			},
			RoleNames: []dataprovider.ConditionPattern{
				{
					Pattern: role,
				},
			},
		},
	}
	res := eventManager.checkProviderEventMatch(conditions, &EventParams{
		Name: "user1",
		Role: role,
		Event: "add",
	})
	assert.False(t, res)
	res = eventManager.checkProviderEventMatch(conditions, &EventParams{
		Name: "user2",
		Role: role,
		Event: "update",
	})
	assert.True(t, res)
	// "delete" is not among the configured provider events
	res = eventManager.checkProviderEventMatch(conditions, &EventParams{
		Name: "user2",
		Role: role,
		Event: "delete",
	})
	assert.False(t, res)
	// from now on only api_key objects match
	conditions.Options.ProviderObjects = []string{"api_key"}
	res = eventManager.checkProviderEventMatch(conditions, &EventParams{
		Name: "user2",
		Event: "update",
		Role: role,
		ObjectType: "share",
	})
	assert.False(t, res)
	res = eventManager.checkProviderEventMatch(conditions, &EventParams{
		Name: "user2",
		Event: "update",
		Role:
role,
		ObjectType: "api_key",
	})
	assert.True(t, res)
	// the role does not match the configured role pattern
	res = eventManager.checkProviderEventMatch(conditions, &EventParams{
		Name: "user2",
		Event: "update",
		Role: role + "1",
		ObjectType: "api_key",
	})
	assert.False(t, res)
	// now test fs events
	conditions = &dataprovider.EventConditions{
		FsEvents: []string{operationUpload, operationDownload},
		Options: dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: "user*",
				},
				{
					Pattern: "tester*",
				},
			},
			RoleNames: []dataprovider.ConditionPattern{
				{
					Pattern: role,
					InverseMatch: true,
				},
			},
			FsPaths: []dataprovider.ConditionPattern{
				{
					Pattern: "/**/*.txt",
				},
			},
			Protocols: []string{ProtocolSFTP},
			MinFileSize: 10,
			MaxFileSize: 30,
		},
	}
	params := EventParams{
		Name: "tester4",
		Event: operationDelete,
		VirtualPath: "/path.txt",
		Protocol: ProtocolSFTP,
		ObjectName: "path.txt",
		FileSize: 20,
	}
	// delete is not among the configured fs events
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
	params.Event = operationDownload
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.True(t, res)
	// role1 is excluded by the inverse role pattern
	params.Role = role
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
	params.Role = ""
	// "name" matches neither user* nor tester*
	params.Name = "name"
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
	params.Name = "user5"
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.True(t, res)
	// only paths matching /**/*.txt are accepted
	params.VirtualPath = "/sub/f.jpg"
	params.ObjectName = path.Base(params.VirtualPath)
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
	params.VirtualPath = "/sub/f.txt"
	params.ObjectName = path.Base(params.VirtualPath)
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.True(t, res)
	params.Protocol = ProtocolHTTP
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
	params.Protocol = ProtocolSFTP
	// below MinFileSize
	params.FileSize = 5
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
	// above MaxFileSize
	params.FileSize = 50
	res = eventManager.checkFsEventMatch(conditions, &params)
	assert.False(t, res)
params.FileSize = 25 res = eventManager.checkFsEventMatch(conditions, ¶ms) assert.True(t, res) // bad pattern conditions.Options.Names = []dataprovider.ConditionPattern{ { Pattern: "[-]", }, } res = eventManager.checkFsEventMatch(conditions, ¶ms) assert.False(t, res) // check fs events with group name filters conditions = &dataprovider.EventConditions{ FsEvents: []string{operationUpload, operationDownload}, Options: dataprovider.ConditionOptions{ GroupNames: []dataprovider.ConditionPattern{ { Pattern: "group*", }, { Pattern: "testgroup*", }, }, }, } params = EventParams{ Name: "user1", Event: operationUpload, } res = eventManager.checkFsEventMatch(conditions, ¶ms) assert.False(t, res) params.Groups = []sdk.GroupMapping{ { Name: "g1", Type: sdk.GroupTypePrimary, }, { Name: "g2", Type: sdk.GroupTypeSecondary, }, } res = eventManager.checkFsEventMatch(conditions, ¶ms) assert.False(t, res) params.Groups = []sdk.GroupMapping{ { Name: "testgroup2", Type: sdk.GroupTypePrimary, }, { Name: "g2", Type: sdk.GroupTypeSecondary, }, } res = eventManager.checkFsEventMatch(conditions, ¶ms) assert.True(t, res) // check user conditions user := dataprovider.User{} user.Username = "u1" res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{}) assert.True(t, res) res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: "user", }, }, }) assert.False(t, res) res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ RoleNames: []dataprovider.ConditionPattern{ { Pattern: role, }, }, }) assert.False(t, res) user.Role = role res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ RoleNames: []dataprovider.ConditionPattern{ { Pattern: role, }, }, }) assert.True(t, res) res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ GroupNames: []dataprovider.ConditionPattern{ { Pattern: "group", }, }, RoleNames: []dataprovider.ConditionPattern{ { Pattern: role, }, }, }) 
assert.False(t, res)
	// IDP login events. Per the assertions below: IDPLoginEvent 0 matches any
	// login, 1 matches only user logins and 2 matches only admin logins.
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 0,
	}, &EventParams{
		Event: IDPLoginAdmin,
	})
	assert.True(t, res)
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 2,
	}, &EventParams{
		Event: IDPLoginAdmin,
	})
	assert.True(t, res)
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 1,
	}, &EventParams{
		Event: IDPLoginAdmin,
	})
	assert.False(t, res)
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 1,
	}, &EventParams{
		Event: IDPLoginUser,
	})
	assert.True(t, res)
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 1,
	}, &EventParams{
		Name: "user",
		Event: IDPLoginUser,
	})
	assert.True(t, res)
	// the name condition does not match
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 1,
		Options: dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: "abc",
				},
			},
		},
	}, &EventParams{
		Name: "user",
		Event: IDPLoginUser,
	})
	assert.False(t, res)
	res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{
		IDPLoginEvent: 2,
	}, &EventParams{
		Name: "user",
		Event: IDPLoginUser,
	})
	assert.False(t, res)
}

// TestDoubleStarMatching checks "**" path patterns, with and without inverse
// matching, against directory and file paths.
func TestDoubleStarMatching(t *testing.T) {
	c := dataprovider.ConditionPattern{
		Pattern: "/mydir/**",
	}
	// "/mydir/**" matches the directory itself and everything below it
	res := checkEventConditionPattern(c, "/mydir")
	assert.True(t, res)
	res = checkEventConditionPattern(c, "/mydirname")
	assert.False(t, res)
	res = checkEventConditionPattern(c, "/mydir/sub")
	assert.True(t, res)
	res = checkEventConditionPattern(c, "/mydir/sub/dir")
	assert.True(t, res)
	c.Pattern = "/**/*"
	res = checkEventConditionPattern(c, "/mydir")
	assert.True(t, res)
	res = checkEventConditionPattern(c, "/mydirname")
	assert.True(t, res)
	res = checkEventConditionPattern(c, "/mydir/sub/dir/file.txt")
	assert.True(t, res)
	c.Pattern = "/**/*.filepart"
	res = checkEventConditionPattern(c, "/file.filepart")
	assert.True(t, res)
	res =
checkEventConditionPattern(c, "/mydir/sub/file.filepart")
	assert.True(t, res)
	res = checkEventConditionPattern(c, "/file.txt")
	assert.False(t, res)
	res = checkEventConditionPattern(c, "/mydir/file.txt")
	assert.False(t, res)
	// "**" in the middle of the pattern: only .txt files below /mydir match
	c.Pattern = "/mydir/**/*.txt"
	res = checkEventConditionPattern(c, "/mydir")
	assert.False(t, res)
	res = checkEventConditionPattern(c, "/mydirname/f.txt")
	assert.False(t, res)
	res = checkEventConditionPattern(c, "/mydir/sub")
	assert.False(t, res)
	res = checkEventConditionPattern(c, "/mydir/sub/dir")
	assert.False(t, res)
	res = checkEventConditionPattern(c, "/mydir/sub/dir/a.txt")
	assert.True(t, res)
	// inverse matching flips every result above
	c.InverseMatch = true
	assert.True(t, checkEventConditionPattern(c, "/mydir"))
	assert.True(t, checkEventConditionPattern(c, "/mydirname/f.txt"))
	assert.True(t, checkEventConditionPattern(c, "/mydir/sub"))
	assert.True(t, checkEventConditionPattern(c, "/mydir/sub/dir"))
	assert.False(t, checkEventConditionPattern(c, "/mydir/sub/dir/a.txt"))
}

// TestMultipleDoubleStarMatching checks that a path matches a list of
// non-inverse patterns when it matches at least one of them.
func TestMultipleDoubleStarMatching(t *testing.T) {
	patterns := []dataprovider.ConditionPattern{
		{
			Pattern: "/**/*.txt",
			InverseMatch: false,
		},
		{
			Pattern: "/**/*.tmp",
			InverseMatch: false,
		},
	}
	assert.False(t, checkEventConditionPatterns("/mydir", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/test.tmp", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/test.txt", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/test.csv", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/sub", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/sub/test.tmp", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/sub/test.txt", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/sub/test.csv", patterns))
}

// TestMultipleDoubleStarMatchingInverse checks that a path matches a list of
// inverse patterns only when it matches none of them.
func TestMultipleDoubleStarMatchingInverse(t *testing.T) {
	patterns := []dataprovider.ConditionPattern{
		{
			Pattern: "/**/*.txt",
			InverseMatch: true,
		},
		{
			Pattern: "/**/*.tmp",
			InverseMatch: true,
		},
	}
	assert.True(t,
checkEventConditionPatterns("/mydir", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/test.tmp", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/test.txt", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/test.csv", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/sub", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/sub/test.tmp", patterns))
	assert.False(t, checkEventConditionPatterns("/mydir/sub/test.txt", patterns))
	assert.True(t, checkEventConditionPatterns("/mydir/sub/test.csv", patterns))
}

// TestGroupConditionPatterns checks group name patterns: a plain pattern list
// matches when any group matches it, and the inverse list gives the opposite
// result in every case below.
func TestGroupConditionPatterns(t *testing.T) {
	group1 := "group1"
	group2 := "group2"
	patterns := []dataprovider.ConditionPattern{
		{
			Pattern: group1,
		},
		{
			Pattern: group2,
		},
	}
	inversePatterns := []dataprovider.ConditionPattern{
		{
			Pattern: group1,
			InverseMatch: true,
		},
		{
			Pattern: group2,
			InverseMatch: true,
		},
	}
	groups := []sdk.GroupMapping{
		{
			Name: "group3",
			Type: sdk.GroupTypePrimary,
		},
	}
	assert.False(t, checkEventGroupConditionPatterns(groups, patterns))
	assert.True(t, checkEventGroupConditionPatterns(groups, inversePatterns))
	groups = []sdk.GroupMapping{
		{
			Name: group1,
			Type: sdk.GroupTypePrimary,
		},
		{
			Name: "group4",
			Type: sdk.GroupTypePrimary,
		},
	}
	assert.True(t, checkEventGroupConditionPatterns(groups, patterns))
	assert.False(t, checkEventGroupConditionPatterns(groups, inversePatterns))
	groups = []sdk.GroupMapping{
		{
			Name: group1,
			Type: sdk.GroupTypePrimary,
		},
	}
	assert.True(t, checkEventGroupConditionPatterns(groups, patterns))
	assert.False(t, checkEventGroupConditionPatterns(groups, inversePatterns))
	// "group11" must not match the exact pattern "group1"
	groups = []sdk.GroupMapping{
		{
			Name: "group11",
			Type: sdk.GroupTypePrimary,
		},
	}
	assert.False(t, checkEventGroupConditionPatterns(groups, patterns))
	assert.True(t, checkEventGroupConditionPatterns(groups, inversePatterns))
}

// TestEventManager checks that adding, updating and deleting a rule keeps the
// eventManager fs/provider/schedule registrations in sync with its trigger.
func TestEventManager(t *testing.T) {
	startEventScheduler()
	action := &dataprovider.BaseEventAction{
		Name: "test_action",
		Type: dataprovider.ActionTypeHTTP,
		Options:
dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint: "http://localhost",
				Timeout: 20,
				Method: http.MethodGet,
			},
		},
	}
	err := dataprovider.AddEventAction(action, "", "", "")
	assert.NoError(t, err)
	rule := &dataprovider.EventRule{
		Name: "rule",
		Status: 1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{operationUpload},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	// an fs-event rule is registered only in FsEvents
	err = dataprovider.AddEventRule(rule, "", "", "")
	assert.NoError(t, err)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 1)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()
	// switching to a provider-event trigger moves the registration
	rule.Trigger = dataprovider.EventTriggerProviderEvent
	rule.Conditions = dataprovider.EventConditions{
		ProviderEvents: []string{"add"},
	}
	err = dataprovider.UpdateEventRule(rule, "", "", "")
	assert.NoError(t, err)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 1)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()
	// a rule with DeletedAt in the past is not registered and is expected to
	// be removed from the provider shortly after
	rule.Trigger = dataprovider.EventTriggerSchedule
	rule.Conditions = dataprovider.EventConditions{
		Schedules: []dataprovider.Schedule{
			{
				Hours: "0",
				DayOfWeek: "*",
				DayOfMonth: "*",
				Month: "*",
			},
		},
	}
	rule.DeletedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour))
	eventManager.addUpdateRuleInternal(*rule)

	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()

	assert.Eventually(t, func() bool {
		_, err = dataprovider.EventRuleExists(rule.Name)
		ok := errors.Is(err, util.ErrNotFound)
		return ok
	}, 2*time.Second, 100*time.Millisecond)

	rule.DeletedAt =
0
	// a scheduled rule is registered in Schedules and schedulesMapping
	err = dataprovider.AddEventRule(rule, "", "", "")
	assert.NoError(t, err)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 1)
	assert.Len(t, eventManager.schedulesMapping, 1)
	eventManager.RUnlock()
	// deleting the rule removes every registration
	err = dataprovider.DeleteEventRule(rule.Name, "", "", "")
	assert.NoError(t, err)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()

	err = dataprovider.DeleteEventAction(action.Name, "", "", "")
	assert.NoError(t, err)
	stopEventScheduler()
}

// TestEventManagerErrors closes the data provider and checks that the event
// actions and helpers that need it return errors. The provider is
// re-initialized with the saved configuration at the end of the test.
func TestEventManagerErrors(t *testing.T) {
	startEventScheduler()
	providerConf := dataprovider.GetProviderConfig()
	err := dataprovider.Close()
	assert.NoError(t, err)

	params := EventParams{
		sender: "sender",
	}
	_, err = params.getUsers()
	assert.Error(t, err)
	_, err = params.getFolders()
	assert.Error(t, err)

	err = executeUsersQuotaResetRuleAction(dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeFoldersQuotaResetRuleAction(dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeTransferQuotaResetRuleAction(dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeUserExpirationCheckRuleAction(dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{},
		dataprovider.ConditionOptions{}, &EventParams{}, time.Time{})
	assert.Error(t, err)
	err = executeDeleteFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeMkdirFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeRenameFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeExistFsRuleAction(nil, nil,
dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeCopyFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executeCompressFsRuleAction(dataprovider.EventActionFsCompress{}, nil, dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	err = executePwdExpirationCheckRuleAction(dataprovider.EventActionPasswordExpiration{},
		dataprovider.ConditionOptions{}, &EventParams{})
	assert.Error(t, err)
	_, err = executeAdminCheckAction(&dataprovider.EventActionIDPAccountCheck{}, &EventParams{})
	assert.Error(t, err)
	_, err = executeUserCheckAction(&dataprovider.EventActionIDPAccountCheck{}, &EventParams{})
	assert.Error(t, err)

	// Per-user helpers with a user belonging to a group. With the provider
	// closed they are expected to fail (presumably while loading the group
	// settings).
	groupName := "agroup"
	err = executeQuotaResetForUser(&dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	})
	assert.Error(t, err)
	err = executeDataRetentionCheckForUser(dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	}, nil, &EventParams{}, "")
	assert.Error(t, err)
	err = executeDeleteFsActionForUser(nil, nil, dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	})
	assert.Error(t, err)
	err = executeMkDirsFsActionForUser(nil, nil, dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	})
	assert.Error(t, err)
	err = executeRenameFsActionForUser(nil, nil, dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	})
	assert.Error(t, err)
	err = executeExistFsActionForUser(nil, nil, dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	})
	assert.Error(t, err)
	err = executeCopyFsActionForUser(nil, nil, dataprovider.User{
		Groups: []sdk.GroupMapping{
			{
				Name: groupName,
				Type: sdk.GroupTypePrimary,
			},
		},
	})
	assert.Error(t, err)
	err = executeCompressFsActionForUser(dataprovider.EventActionFsCompress{}, nil,
dataprovider.User{ Groups: []sdk.GroupMapping{ { Name: groupName, Type: sdk.GroupTypePrimary, }, }, }) assert.Error(t, err) err = executePwdExpirationCheckForUser(&dataprovider.User{ Groups: []sdk.GroupMapping{ { Name: groupName, Type: sdk.GroupTypePrimary, }, }}, dataprovider.EventActionPasswordExpiration{}) assert.Error(t, err) _, _, err = getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{ Method: http.MethodPost, Parts: []dataprovider.HTTPPart{ { Name: "p1", }, }, }, nil, nil, dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "u", }, Groups: []sdk.GroupMapping{ { Name: groupName, Type: sdk.GroupTypePrimary, }, }, }, &EventParams{}, false) assert.Error(t, err) dataRetentionAction := dataprovider.BaseEventAction{ Type: dataprovider.ActionTypeDataRetentionCheck, Options: dataprovider.BaseEventActionOptions{ RetentionConfig: dataprovider.EventActionDataRetentionConfig{ Folders: []dataprovider.FolderRetention{ { Path: "/", Retention: 24, }, }, }, }, } err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: "username1", }, }, }) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to get users") } eventManager.loadRules() eventManager.RLock() assert.Len(t, eventManager.FsEvents, 0) assert.Len(t, eventManager.ProviderEvents, 0) assert.Len(t, eventManager.Schedules, 0) eventManager.RUnlock() // rule with invalid trigger eventManager.addUpdateRuleInternal(dataprovider.EventRule{ Name: "test rule", Status: 1, Trigger: -1, }) eventManager.RLock() assert.Len(t, eventManager.FsEvents, 0) assert.Len(t, eventManager.ProviderEvents, 0) assert.Len(t, eventManager.Schedules, 0) eventManager.RUnlock() // rule with invalid cronspec eventManager.addUpdateRuleInternal(dataprovider.EventRule{ Name: "test rule", Status: 1, Trigger: dataprovider.EventTriggerSchedule, Conditions: dataprovider.EventConditions{ Schedules: []dataprovider.Schedule{ { Hours: "1000", }, }, }, }) 
eventManager.RLock()
	// The event manager must still hold no rule after the invalid rules above.
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	eventManager.RUnlock()

	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	stopEventScheduler()
}

// TestDateTimePlaceholder verifies that the date/time placeholders built by
// EventParams.getStringReplacements honor Config.TZ: with an empty value the
// timestamp is rendered in UTC, with "local" in the local time zone.
func TestDateTimePlaceholder(t *testing.T) {
	oldTZ := Config.TZ
	// Config is package-level state shared with the other tests: restore it in
	// a deferred call so the original value comes back even if the test panics
	// or is stopped early.
	defer func() {
		Config.TZ = oldTZ
	}()

	Config.TZ = ""
	dateTime := time.Now()
	params := EventParams{
		Timestamp: dateTime,
	}
	replacements := params.getStringReplacements(false, 0)
	r := strings.NewReplacer(replacements...)
	res := r.Replace("{{.DateTime}}")
	assert.Equal(t, dateTime.UTC().Format(dateTimeMillisFormat), res)
	// The assertion below relies on the first 16 characters of the layout being
	// year-month-dayThour:minute, rebuilt here from the single placeholders.
	res = r.Replace("{{.Year}}-{{.Month}}-{{.Day}}T{{.Hour}}:{{.Minute}}")
	assert.Equal(t, dateTime.UTC().Format(dateTimeMillisFormat)[:16], res)

	Config.TZ = "local"
	replacements = params.getStringReplacements(false, 0)
	r = strings.NewReplacer(replacements...)
	res = r.Replace("{{.DateTime}}")
	assert.Equal(t, dateTime.Local().Format(dateTimeMillisFormat), res)
	res = r.Replace("{{.Year}}-{{.Month}}-{{.Day}}T{{.Hour}}:{{.Minute}}")
	assert.Equal(t, dateTime.Local().Format(dateTimeMillisFormat)[:16], res)
}

// TestEventRuleActions runs executeRuleAction (and a few related helpers)
// against backup, HTTP, quota reset, expiration, data retention and filesystem
// actions, covering both success and error paths.
func TestEventRuleActions(t *testing.T) {
	actionName := "test rule action"
	action := dataprovider.BaseEventAction{
		Name: actionName,
		Type: dataprovider.ActionTypeBackup,
	}
	err := executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{})
	assert.NoError(t, err)
	// An unknown action type must be rejected.
	action.Type = -1
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{})
	assert.Error(t, err)

	action = dataprovider.BaseEventAction{
		Name: actionName,
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint:      "http://foo\x7f.com/", // invalid URL: contains the DEL control character
				SkipTLSVerify: true,
				Body:          `"data": "{{.ObjectDataString}}"`,
				Method:        http.MethodPost,
				QueryParameters: []dataprovider.KeyValue{
					{
						Key:   "param",
						Value: "value",
					},
				},
				Timeout: 5,
				Headers: []dataprovider.KeyValue{
					{
						Key:
"Content-Type",
						Value: "application/json",
					},
				},
				Username: "httpuser",
			},
		},
	}
	// HTTP action checks (TestEventRuleActions, continued).
	action.Options.SetEmptySecretsIfNil()
	// The endpoint configured above is not a valid URL.
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid endpoint")
	}
	// Point the action to httpAddr: both POST and GET requests must succeed.
	action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr)
	params := &EventParams{
		Name: "a",
		Object: &dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username: "test user",
			},
		},
	}
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.NoError(t, err)
	action.Options.HTTPConfig.Method = http.MethodGet
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.NoError(t, err)
	// A 404 response must be reported as an error.
	action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v/404", httpAddr)
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unexpected status code: 404")
	}
	// Unresolvable host.
	action.Options.HTTPConfig.Endpoint = "http://invalid:1234"
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.Error(t, err)
	// Invalid URL again, this time without query parameters.
	action.Options.HTTPConfig.QueryParameters = nil
	action.Options.HTTPConfig.Endpoint = "http://bar\x7f.com/"
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.Error(t, err)
	// A password secret that cannot be decrypted must be reported.
	action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", "data")
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to decrypt HTTP password")
	}
	// A multipart body with a file part needs a user lookup, which fails here.
	action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr)
	action.Options.HTTPConfig.Password = kms.NewEmptySecret()
	action.Options.HTTPConfig.Body = ""
	action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{
		{
			Name:     "p1",
			Filepath: "path",
		},
	}
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.Contains(t, getErrorString(err), "error getting user")
	action.Options.HTTPConfig.Parts = nil
	action.Options.HTTPConfig.Body = "{{.ObjectData}}"
	// test disk and transfer quota reset
	username1 := "user1"
	username2 := "user2"
	user1 := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username1,
			HomeDir:  filepath.Join(os.TempDir(), username1),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	user2 := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username2,
			HomeDir:  filepath.Join(os.TempDir(), username2),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	user2.Filters.PasswordExpiration = 10
	err = dataprovider.AddUser(&user1, "", "", "")
	assert.NoError(t, err)
	err = dataprovider.AddUser(&user2, "", "", "")
	assert.NoError(t, err)
	// user2's PasswordExpiration (10) is below the threshold (20), so a
	// notification is presumably attempted and fails without SMTP.
	err = executePwdExpirationCheckRuleAction(dataprovider.EventActionPasswordExpiration{
		Threshold: 20,
	}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: user2.Username,
			},
		},
	}, &EventParams{}) // smtp not configured
	assert.Error(t, err)

	action = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeUserQuotaReset,
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err) // no home dir
	// create the home dir
	err = os.MkdirAll(user1.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(user1.GetHomeDir(), "file.txt"), []byte("user"), 0666)
	assert.NoError(t, err)
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)
	// The scan must account for the single 4-byte file written above.
	userGet, err := dataprovider.UserExists(username1, "")
	assert.NoError(t, err)
	assert.Equal(t, 1, userGet.UsedQuotaFiles)
	assert.Equal(t, int64(4), userGet.UsedQuotaSize)
	// simulate another quota scan in progress
	assert.True(t, QuotaScans.AddUserQuotaScan(username1, ""))
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err)
	assert.True(t, QuotaScans.RemoveUserQuotaScan(username1))
	// non matching pattern
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "don't match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no user quota reset executed")

	// User expiration check: no match first, then the existing user1.
	action = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeUserExpirationCheck,
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "don't match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no user expiration check executed")
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)

	dataRetentionAction := dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeDataRetentionCheck,
		Options: dataprovider.BaseEventActionOptions{
			RetentionConfig: dataprovider.EventActionDataRetentionConfig{
				Folders: []dataprovider.FolderRetention{
					{
						Path:      "",
						Retention: 24,
					},
				},
			},
		},
	}
	err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err) // invalid config, no folder path specified
	retentionDir := "testretention"
	dataRetentionAction = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeDataRetentionCheck,
		Options: dataprovider.BaseEventActionOptions{
			RetentionConfig: dataprovider.EventActionDataRetentionConfig{
				Folders: []dataprovider.FolderRetention{
					{
						Path:            path.Join("/", retentionDir),
						Retention:       24,
						DeleteEmptyDirs: true,
					},
				},
			},
		},
	}
	// create some test files
	file1 := filepath.Join(user1.GetHomeDir(), "file1.txt")
	file2 := filepath.Join(user1.GetHomeDir(), retentionDir, "file2.txt")
	file3 := filepath.Join(user1.GetHomeDir(), retentionDir, "file3.txt")
	file4 := filepath.Join(user1.GetHomeDir(), retentionDir, "sub", "file4.txt")
	err = os.MkdirAll(filepath.Dir(file4), os.ModePerm)
	assert.NoError(t, err)
	for _, f := range []string{file1, file2, file3, file4} {
		err = os.WriteFile(f, []byte(""), 0666)
		assert.NoError(t, err)
	}
	// Make file1, file2 and file4 48 hours old; file3 keeps its current mtime.
	timeBeforeRetention := time.Now().Add(-48 * time.Hour)
	err = os.Chtimes(file1, timeBeforeRetention, timeBeforeRetention)
	assert.NoError(t, err)
	err = os.Chtimes(file2, timeBeforeRetention, timeBeforeRetention)
	assert.NoError(t, err)
	err = os.Chtimes(file4, timeBeforeRetention, timeBeforeRetention)
	assert.NoError(t, err)
	err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)
	// file1 is outside the configured folder and survives; the old files inside
	// it are removed and, with DeleteEmptyDirs, so is the emptied "sub" dir.
	assert.FileExists(t, file1)
	assert.NoFileExists(t, file2)
	assert.FileExists(t, file3)
	assert.NoDirExists(t, filepath.Dir(file4))
	// simulate another check in progress
	c := RetentionChecks.Add(RetentionCheck{}, &user1)
	assert.NotNil(t, c)
	err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err)
	RetentionChecks.remove(user1.Username)
	err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no retention check executed")
	// test file exists action
	action = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type:  dataprovider.FilesystemActionExist,
				Exist: []string{"/file1.txt", path.Join("/", retentionDir, "file3.txt")},
			},
		},
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no existence check executed")
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)
	// file2 was removed by the retention check, so this existence check fails.
	action.Options.FsConfig.Exist = []string{"/file1.txt", path.Join("/", retentionDir, "file2.txt")}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)

	// Transfer quota reset: record some usage, then expect it to be zeroed.
	err = dataprovider.UpdateUserTransferQuota(&user1, 100, 100, true)
	assert.NoError(t, err)
	action.Type = dataprovider.ActionTypeTransferQuotaReset
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)
	userGet, err = dataprovider.UserExists(username1, "")
	assert.NoError(t, err)
	assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer)
	assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer)
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no transfer quota reset executed")

	// Filesystem actions whose conditions match no user must report that
	// nothing was executed.
	action.Type = dataprovider.ActionTypeFilesystem
	action.Options = dataprovider.BaseEventActionOptions{
		FsConfig: dataprovider.EventActionFilesystemConfig{
			Type: dataprovider.FilesystemActionRename,
			Renames: []dataprovider.RenameConfig{
				{
					KeyValue: dataprovider.KeyValue{
						Key:   "/source",
						Value: "/target",
					},
				},
			},
		},
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no rename executed")
	action.Options =
dataprovider.BaseEventActionOptions{
		FsConfig: dataprovider.EventActionFilesystemConfig{
			Type:    dataprovider.FilesystemActionDelete,
			Deletes: []string{"/dir1"},
		},
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no delete executed")
	// A mkdirs action takes its paths from MkDirs (as every other mkdirs action
	// in this file does), not from Deletes.
	action.Options = dataprovider.BaseEventActionOptions{
		FsConfig: dataprovider.EventActionFilesystemConfig{
			Type:   dataprovider.FilesystemActionMkdirs,
			MkDirs: []string{"/dir1"},
		},
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no mkdir executed")
	action.Options = dataprovider.BaseEventActionOptions{
		FsConfig: dataprovider.EventActionFilesystemConfig{
			Type: dataprovider.FilesystemActionCompress,
			Compress: dataprovider.EventActionFsCompress{
				Name:  "test.zip",
				Paths: []string{"/{{.VirtualPath}}"},
			},
		},
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no file/folder compressed")
	// The same applies when filtering by group name.
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		GroupNames: []dataprovider.ConditionPattern{
			{
				Pattern: "no match",
			},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, getErrorString(err), "no file/folder compressed")
	err = dataprovider.DeleteUser(username1, "", "", "")
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username2, "", "", "")
	assert.NoError(t, err)

	// test folder quota reset
	foldername1 := "f1"
	foldername2 := "f2"
	folder1 := vfs.BaseVirtualFolder{
		Name:       foldername1,
		MappedPath: filepath.Join(os.TempDir(), foldername1),
	}
	folder2 := vfs.BaseVirtualFolder{
		Name:       foldername2,
		MappedPath: filepath.Join(os.TempDir(), foldername2),
	}
	err = dataprovider.AddFolder(&folder1, "", "", "")
	assert.NoError(t, err)
	err = dataprovider.AddFolder(&folder2, "", "", "")
	assert.NoError(t, err)
	action = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeFolderQuotaReset,
	}
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: foldername1,
			},
		},
	})
	assert.Error(t, err) // the mapped path does not exist yet
	err = os.MkdirAll(folder1.MappedPath, os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(folder1.MappedPath, "file.txt"), []byte("folder"), 0666)
	assert.NoError(t, err)
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: foldername1,
			},
		},
	})
	assert.NoError(t, err)
	// The scan must account for the single 6-byte file written above.
	folderGet, err := dataprovider.GetFolderByName(foldername1)
	assert.NoError(t, err)
	assert.Equal(t, 1, folderGet.UsedQuotaFiles)
	assert.Equal(t, int64(6), folderGet.UsedQuotaSize)
	// simulate another quota scan in progress
	assert.True(t, QuotaScans.AddVFolderQuotaScan(foldername1))
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: foldername1,
			},
		},
	})
	assert.Error(t, err)
	assert.True(t, QuotaScans.RemoveVFolderQuotaScan(foldername1))
	err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: "no folder match",
			},
		},
	})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no folder quota reset executed")
	}

	// Without a configured body no request body is returned...
	body, _, err := getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{
		Method: http.MethodPost,
	}, nil, nil, dataprovider.User{}, &EventParams{}, true)
	assert.NoError(t, err)
	assert.Nil(t, body)
	// ...while a static body produces one.
	body, _, err = getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{
		Method: http.MethodPost,
		Body:   "test body",
	}, nil, nil, dataprovider.User{}, &EventParams{}, false)
	assert.NoError(t, err)
	assert.NotNil(t, body)
	err = os.RemoveAll(folder1.MappedPath)
	assert.NoError(t, err)
	err = dataprovider.DeleteFolder(foldername1, "", "", "")
	assert.NoError(t, err)
	err = dataprovider.DeleteFolder(foldername2, "", "", "")
	assert.NoError(t, err)
}

// TestIDPAccountCheckRule covers executeIDPAccountCheckRule, the admin/user
// template handling of the IDP account check actions (including the Mode
// switch between "keep existing" and "update") and the rule consistency
// checks for this action type.
func TestIDPAccountCheckRule(t *testing.T) {
	_, _, err := executeIDPAccountCheckRule(dataprovider.EventRule{}, EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no action executed")
	}
	_, _, err = executeIDPAccountCheckRule(dataprovider.EventRule{
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: "n",
					Type: dataprovider.ActionTypeIDPAccountCheck,
				},
			},
		},
	}, EventParams{Event: "invalid"})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unsupported IDP login event")
	}
	// invalid json
	_, err = executeAdminCheckAction(&dataprovider.EventActionIDPAccountCheck{TemplateAdmin: "{"}, &EventParams{Name: "missing admin"})
	assert.Error(t, err)
	_, err = executeUserCheckAction(&dataprovider.EventActionIDPAccountCheck{TemplateUser: "["}, &EventParams{Name: "missing user"})
	assert.Error(t, err)
	// Valid JSON that does not describe a valid user.
	_, err = executeUserCheckAction(&dataprovider.EventActionIDPAccountCheck{TemplateUser: "{}"}, &EventParams{Name: "invalid user template"})
	assert.ErrorIs(t, err, util.ErrValidation)
	username := "u"
	c := &dataprovider.EventActionIDPAccountCheck{
		Mode:         1,
		TemplateUser: `{"username":"` + username + `","status":1,"home_dir":"` + util.JSONEscape(filepath.Join(os.TempDir())) + `","permissions":{"/":["*"]}}`,
	}
	params := &EventParams{
		Name:  username,
		Event: IDPLoginUser,
	}
	// The first check stores the user built from the template.
	user, err := executeUserCheckAction(c, params)
	assert.NoError(t, err)
	assert.Equal(t, username, user.Username)
	assert.Equal(t, 1, user.Status)
	user.Status = 0
	err = dataprovider.UpdateUser(user, "", "", "")
	assert.NoError(t, err)
	// the user is not changed
	user, err = executeUserCheckAction(c, params)
	assert.NoError(t, err)
	assert.Equal(t, username, user.Username)
	assert.Equal(t, 0, user.Status)
	// change the mode, the user is now updated
c.Mode = 0
	user, err = executeUserCheckAction(c, params)
	assert.NoError(t, err)
	assert.Equal(t, username, user.Username)
	assert.Equal(t, 1, user.Status)
	assert.Empty(t, user.Password)
	assert.Len(t, user.PublicKeys, 0)
	assert.Len(t, user.Filters.TLSCerts, 0)
	assert.Empty(t, user.Email)
	assert.Empty(t, user.Description)
	// Update the profile attributes and make sure they are preserved
	user.Password = "secret"
	user.Email = "example@example.com"
	user.Filters.AdditionalEmails = []string{"alias@example.com"}
	user.Description = "some desc"
	user.Filters.TLSCerts = []string{serverCert}
	user.PublicKeys = []string{"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"}
	err = dataprovider.UpdateUser(user, "", "", "")
	assert.NoError(t, err)
	// Re-applying the template must keep the profile attributes set above.
	user, err = executeUserCheckAction(c, params)
	assert.NoError(t, err)
	assert.Equal(t, username, user.Username)
	assert.Equal(t, 1, user.Status)
	assert.NotEmpty(t, user.Password)
	assert.Len(t, user.PublicKeys, 1)
	assert.Len(t, user.Filters.TLSCerts, 1)
	assert.NotEmpty(t, user.Email)
	assert.Len(t, user.Filters.AdditionalEmails, 1)
	assert.NotEmpty(t, user.Description)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	// check rule consistency
	r := dataprovider.EventRule{
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Type: dataprovider.ActionTypeIDPAccountCheck,
				},
				Order: 1,
			},
		},
	}
	// An IDP account check requires the IDP login trigger...
	err = r.CheckActionsConsistency("")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "IDP account check action is only supported for IDP login trigger")
	}
	// ...must be executed synchronously...
	r.Trigger = dataprovider.EventTriggerIDPLogin
	err = r.CheckActionsConsistency("")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "IDP account check must be a sync action")
	}
	r.Actions[0].Options.ExecuteSync = true
	err = r.CheckActionsConsistency("")
	assert.NoError(t, err)
	// ...and must be the only sync action of the rule.
	r.Actions = append(r.Actions, dataprovider.EventAction{
		BaseEventAction: dataprovider.BaseEventAction{
			Type: dataprovider.ActionTypeCommand,
		},
		Options: dataprovider.EventActionOptions{
			ExecuteSync: true,
		},
		Order: 2,
	})
	err = r.CheckActionsConsistency("")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "IDP account check must be the only sync action")
	}
}

// TestUserExpirationCheck uses a user whose ExpirationDate is 24 hours in the
// past. The user expiration check must report it. The password expiration
// check must skip it and succeed.
func TestUserExpirationCheck(t *testing.T) {
	username := "test_user_expiration_check"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
			HomeDir:        filepath.Join(os.TempDir(), username),
			ExpirationDate: util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)),
		},
	}
	user.Filters.PasswordExpiration = 5
	err := dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	conditions := dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username,
			},
		},
	}
	err = executeUserExpirationCheckRuleAction(conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "expired users")
	}
	// the check will be skipped, the user is expired
	err = executePwdExpirationCheckRuleAction(dataprovider.EventActionPasswordExpiration{Threshold: 10}, conditions, &EventParams{})
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestEventRuleActionsNoGroupMatching checks the user-selecting rule actions
// against a group name condition that matches no user. Each must report that
// nothing was executed.
func TestEventRuleActionsNoGroupMatching(t *testing.T) {
	username := "test_user_action_group_matching"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
			HomeDir: filepath.Join(os.TempDir(), username),
		},
	}
	err := dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	// The user above is not assigned to any group.
	conditions := dataprovider.ConditionOptions{
		GroupNames: []dataprovider.ConditionPattern{
			{
				Pattern: "agroup",
			},
		},
	}
	err = executeDeleteFsRuleAction(nil, nil, conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no delete executed")
	}
	err = executeMkdirFsRuleAction(nil, nil, conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no mkdir executed")
	}
	err = executeRenameFsRuleAction(nil, nil, conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no rename executed")
	}
	err = executeExistFsRuleAction(nil, nil, conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no existence check executed")
	}
	err = executeCopyFsRuleAction(nil, nil, conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no copy executed")
	}
	err = executeUsersQuotaResetRuleAction(conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no user quota reset executed")
	}
	err = executeTransferQuotaResetRuleAction(conditions, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no transfer quota reset executed")
	}
	err = executeDataRetentionCheckRuleAction(dataprovider.EventActionDataRetentionConfig{}, conditions, &EventParams{}, "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no retention check executed")
	}
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestGetFileContent exercises getMailAttachments with:
//   - a regular file;
//   - a missing file;
//   - a directory;
//   - attachments whose combined size exceeds maxAttachmentsSize;
//   - a read failure on an encrypted filesystem.
func TestGetFileContent(t *testing.T) {
	username := "test_user_get_file_content"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
			HomeDir: filepath.Join(os.TempDir(), username),
		},
	}
	err :=
dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	err = os.MkdirAll(user.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	fileContent := []byte("test file content")
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file.txt"), fileContent, 0666)
	assert.NoError(t, err)
	conn := NewBaseConnection(xid.New().String(), protocolEventAction, "", "", user)
	replacer := strings.NewReplacer("old", "new")
	files, err := getMailAttachments(conn, []string{"/file.txt"}, replacer)
	assert.NoError(t, err)
	if assert.Len(t, files, 1) {
		var b bytes.Buffer
		_, err = files[0].Writer(&b)
		assert.NoError(t, err)
		assert.Equal(t, fileContent, b.Bytes())
	}
	// missing file
	_, err = getMailAttachments(conn, []string{"/file1.txt"}, replacer)
	assert.Error(t, err)
	// directory
	_, err = getMailAttachments(conn, []string{"/"}, replacer)
	assert.Error(t, err)
	// files too large: each file is just over half of maxAttachmentsSize, so a
	// single one is accepted while both together exceed the limit
	content := make([]byte, maxAttachmentsSize/2+1)
	_, err = rand.Read(content)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file1.txt"), content, 0666)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file2.txt"), content, 0666)
	assert.NoError(t, err)
	files, err = getMailAttachments(conn, []string{"/file1.txt"}, replacer)
	assert.NoError(t, err)
	if assert.Len(t, files, 1) {
		var b bytes.Buffer
		_, err = files[0].Writer(&b)
		assert.NoError(t, err)
		assert.Equal(t, content, b.Bytes())
	}
	_, err = getMailAttachments(conn, []string{"/file1.txt", "/file2.txt"}, replacer)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "size too large")
	}
	// change the filesystem provider
	user.FsConfig.Provider = sdk.CryptedFilesystemProvider
	user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("pwd")
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	conn = NewBaseConnection(xid.New().String(), protocolEventAction, "", "", user)
	// the file is not encrypted so reading the encryption header will fail;
	// getMailAttachments itself succeeds, the error surfaces in Writer
	files, err = getMailAttachments(conn, []string{"/file.txt"}, replacer)
	assert.NoError(t, err)
	if assert.Len(t, files, 1) {
		var b bytes.Buffer
		_, err = files[0].Writer(&b)
		assert.Error(t, err)
	}
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestFilesystemActionErrors checks the error paths of the filesystem, email
// and HTTP rule actions. It first uses a user on an SFTP backend whose
// filesystem is presumably unreachable in the test environment. It then uses
// the same user on the local filesystem with upload-only permissions.
func TestFilesystemActionErrors(t *testing.T) {
	err := executeFsRuleAction(dataprovider.EventActionFilesystemConfig{}, dataprovider.ConditionOptions{}, &EventParams{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unsupported filesystem action")
	}
	username := "test_user_for_actions"
	testReplacer := strings.NewReplacer("old", "new")
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
			HomeDir: filepath.Join(os.TempDir(), username),
		},
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: "127.0.0.1:4022",
					Username: username,
				},
				Password: kms.NewPlainSecret("pwd"),
			},
		},
	}
	// The sender is not stored in the data provider yet.
	err = executeEmailRuleAction(dataprovider.EventActionEmailConfig{
		Recipients:  []string{"test@example.net"},
		Subject:     "subject",
		Body:        "body",
		Attachments: []string{"/file.txt"},
	}, &EventParams{
		sender: username,
	})
	assert.Error(t, err)
	conn := NewBaseConnection("", protocolEventAction, "", "", user)
	err = executeDeleteFileFsAction(conn, "", nil)
	assert.Error(t, err)
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	// check root fs fails
	err = executeDeleteFsActionForUser(nil, testReplacer, user)
	assert.Error(t, err)
	err = executeMkDirsFsActionForUser(nil, testReplacer, user)
	assert.Error(t, err)
	err = executeRenameFsActionForUser(nil, testReplacer, user)
	assert.Error(t, err)
	err = executeExistFsActionForUser(nil, testReplacer, user)
	assert.Error(t, err)
	err = executeCopyFsActionForUser(nil, testReplacer, user)
	assert.Error(t, err)
	err = executeCompressFsActionForUser(dataprovider.EventActionFsCompress{},
testReplacer, user) assert.Error(t, err) _, _, _, _, err = getFileWriter(conn, "/path.txt", -1) //nolint:dogsled assert.Error(t, err) err = executeEmailRuleAction(dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.net"}, Subject: "subject", Body: "body", Attachments: []string{"/file1.txt"}, }, &EventParams{ sender: username, }) assert.Error(t, err) fn := getFileContentFn(NewBaseConnection("", protocolEventAction, "", "", user), "/f.txt", 1234) var b bytes.Buffer _, err = fn(&b) assert.Error(t, err) err = executeHTTPRuleAction(dataprovider.EventActionHTTPConfig{ Endpoint: "http://127.0.0.1:9999/", Method: http.MethodPost, Parts: []dataprovider.HTTPPart{ { Name: "p1", Filepath: "/filepath", }, }, }, &EventParams{ sender: username, }) assert.Error(t, err) user.FsConfig.Provider = sdk.LocalFilesystemProvider user.Permissions["/"] = []string{dataprovider.PermUpload} err = dataprovider.DeleteUser(username, "", "", "") assert.NoError(t, err) err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) err = executeRenameFsActionForUser([]dataprovider.RenameConfig{ { KeyValue: dataprovider.KeyValue{ Key: "/p1", Value: "/p1", }, }, }, testReplacer, user) if assert.Error(t, err) { assert.Contains(t, err.Error(), "the rename source and target cannot be the same") } err = executeRuleAction(dataprovider.BaseEventAction{ Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionRename, Renames: []dataprovider.RenameConfig{ { KeyValue: dataprovider.KeyValue{ Key: "/p2", Value: "/p2", }, }, }, }, }, }, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: username, }, }, }) assert.Error(t, err) if runtime.GOOS != osWindows { dirPath := filepath.Join(user.HomeDir, "adir", "sub") err := os.MkdirAll(dirPath, os.ModePerm) assert.NoError(t, err) filePath := filepath.Join(dirPath, "f.dat") err = 
os.WriteFile(filePath, []byte("test file content"), 0666) assert.NoError(t, err) err = os.Chmod(dirPath, 0001) assert.NoError(t, err) err = executeDeleteFsActionForUser([]string{"/adir/sub"}, testReplacer, user) assert.Error(t, err) err = executeDeleteFsActionForUser([]string{"/adir/sub/f.dat"}, testReplacer, user) assert.Error(t, err) err = os.Chmod(dirPath, 0555) assert.NoError(t, err) err = executeDeleteFsActionForUser([]string{"/adir/sub/f.dat"}, testReplacer, user) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to remove file") } err = executeRuleAction(dataprovider.BaseEventAction{ Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionDelete, Deletes: []string{"/adir/sub/f.dat"}, }, }, }, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: username, }, }, }) assert.Error(t, err) err = executeMkDirsFsActionForUser([]string{"/adir/sub/sub"}, testReplacer, user) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to create dir") } err = executeMkDirsFsActionForUser([]string{"/adir/sub/sub/sub"}, testReplacer, user) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to check parent dirs") } err = executeRuleAction(dataprovider.BaseEventAction{ Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionMkdirs, MkDirs: []string{"/adir/sub/sub1"}, }, }, }, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: username, }, }, }) assert.Error(t, err) err = os.Chmod(dirPath, os.ModePerm) assert.NoError(t, err) conn = NewBaseConnection("", protocolEventAction, "", "", user) wr := &zipWriterWrapper{ Name: "test.zip", Writer: zip.NewWriter(bytes.NewBuffer(nil)), Entries: map[string]bool{}, } err = 
addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub", nil, 0) assert.Error(t, err) assert.Contains(t, getErrorString(err), "is outside base dir") } wr := &zipWriterWrapper{ Name: xid.New().String() + ".zip", Writer: zip.NewWriter(bytes.NewBuffer(nil)), Entries: map[string]bool{}, } err = addZipEntry(wr, conn, "/p1", "/", nil, 2000) assert.ErrorIs(t, err, util.ErrRecursionTooDeep) err = dataprovider.DeleteUser(username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestQuotaActionsWithQuotaTrackDisabled(t *testing.T) { oldProviderConf := dataprovider.GetProviderConfig() providerConf := dataprovider.GetProviderConfig() providerConf.TrackQuota = 0 err := dataprovider.Close() assert.NoError(t, err) err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) username := "u1" user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, HomeDir: filepath.Join(os.TempDir(), username), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, FsConfig: vfs.Filesystem{ Provider: sdk.LocalFilesystemProvider, }, } err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) assert.NoError(t, err) err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeUserQuotaReset}, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: username, }, }, }) assert.Error(t, err) err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeTransferQuotaReset}, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: username, }, }, }) assert.Error(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.DeleteUser(username, "", "", "") assert.NoError(t, err) foldername := "f1" folder := vfs.BaseVirtualFolder{ Name: foldername, MappedPath: 
filepath.Join(os.TempDir(), foldername), } err = dataprovider.AddFolder(&folder, "", "", "") assert.NoError(t, err) err = os.MkdirAll(folder.MappedPath, os.ModePerm) assert.NoError(t, err) err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeFolderQuotaReset}, &EventParams{}, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: foldername, }, }, }) assert.Error(t, err) err = os.RemoveAll(folder.MappedPath) assert.NoError(t, err) err = dataprovider.DeleteFolder(foldername, "", "", "") assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = dataprovider.Initialize(oldProviderConf, configDir, true) assert.NoError(t, err) } func TestScheduledActions(t *testing.T) { startEventScheduler() backupsPath := filepath.Join(os.TempDir(), "backups") err := os.RemoveAll(backupsPath) assert.NoError(t, err) now := time.Now().UTC().Format(dateTimeMillisFormat) // The backup action sets the home directory to the backup path. expectedDirPath := filepath.Join(backupsPath, fmt.Sprintf("%s_%s_%s", now[0:4], now[5:7], now[8:10])) action1 := &dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeBackup, } err = dataprovider.AddEventAction(action1, "", "", "") assert.NoError(t, err) action2 := &dataprovider.BaseEventAction{ Name: "action2", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionMkdirs, MkDirs: []string{"{{.Year}}_{{.Month}}_{{.Day}}"}, }, }, } err = dataprovider.AddEventAction(action2, "", "", "") assert.NoError(t, err) rule := &dataprovider.EventRule{ Name: "rule", Status: 1, Trigger: dataprovider.EventTriggerSchedule, Conditions: dataprovider.EventConditions{ Schedules: []dataprovider.Schedule{ { Hours: "11", DayOfWeek: "*", DayOfMonth: "*", Month: "*", }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: 
action1.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } job := eventCronJob{ ruleName: rule.Name, } job.Run() // rule not found assert.NoDirExists(t, backupsPath) err = dataprovider.AddEventRule(rule, "", "", "") assert.NoError(t, err) job.Run() assert.DirExists(t, backupsPath) assert.DirExists(t, expectedDirPath) action1.Type = dataprovider.ActionTypeEmail action1.Options = dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"example@example.com"}, Subject: "test with attachments", Body: "body", Attachments: []string{"/file1.txt"}, }, } err = dataprovider.UpdateEventAction(action1, "", "", "") assert.NoError(t, err) job.Run() // action is not compatible with a scheduled rule err = dataprovider.DeleteEventRule(rule.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventAction(action1.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventAction(action2.Name, "", "", "") assert.NoError(t, err) err = os.RemoveAll(backupsPath) assert.NoError(t, err) stopEventScheduler() } func TestEventParamsCopy(t *testing.T) { params := EventParams{ Name: "name", Event: "event", Extension: "ext", Status: 1, errors: []string{"error1"}, retentionChecks: []executedRetentionCheck{}, } paramsCopy := params.getACopy() assert.Equal(t, params, *paramsCopy) params.Name = "name mod" paramsCopy.Event = "event mod" paramsCopy.Status = 2 params.errors = append(params.errors, "error2") paramsCopy.errors = append(paramsCopy.errors, "error3") assert.Equal(t, []string{"error1", "error3"}, paramsCopy.errors) assert.Equal(t, []string{"error1", "error2"}, params.errors) assert.Equal(t, "name mod", params.Name) assert.Equal(t, "name", paramsCopy.Name) assert.Equal(t, "event", params.Event) assert.Equal(t, "event mod", paramsCopy.Event) assert.Equal(t, 1, params.Status) assert.Equal(t, 2, paramsCopy.Status) params = EventParams{ retentionChecks: 
[]executedRetentionCheck{ { Username: "u", ActionName: "a", Results: []folderRetentionCheckResult{ { Path: "p", Retention: 1, }, }, }, }, } paramsCopy = params.getACopy() require.Len(t, paramsCopy.retentionChecks, 1) paramsCopy.retentionChecks[0].Username = "u_copy" paramsCopy.retentionChecks[0].ActionName = "a_copy" require.Len(t, paramsCopy.retentionChecks[0].Results, 1) paramsCopy.retentionChecks[0].Results[0].Path = "p_copy" paramsCopy.retentionChecks[0].Results[0].Retention = 2 assert.Equal(t, "u", params.retentionChecks[0].Username) assert.Equal(t, "a", params.retentionChecks[0].ActionName) assert.Equal(t, "p", params.retentionChecks[0].Results[0].Path) assert.Equal(t, 1, params.retentionChecks[0].Results[0].Retention) assert.Equal(t, "u_copy", paramsCopy.retentionChecks[0].Username) assert.Equal(t, "a_copy", paramsCopy.retentionChecks[0].ActionName) assert.Equal(t, "p_copy", paramsCopy.retentionChecks[0].Results[0].Path) assert.Equal(t, 2, paramsCopy.retentionChecks[0].Results[0].Retention) assert.Nil(t, params.IDPCustomFields) params.addIDPCustomFields(nil) assert.Nil(t, params.IDPCustomFields) params.IDPCustomFields = &map[string]string{ "field1": "val1", } paramsCopy = params.getACopy() for k, v := range *paramsCopy.IDPCustomFields { assert.Equal(t, "field1", k) assert.Equal(t, "val1", v) } assert.Equal(t, params.IDPCustomFields, paramsCopy.IDPCustomFields) (*paramsCopy.IDPCustomFields)["field1"] = "val2" assert.NotEqual(t, params.IDPCustomFields, paramsCopy.IDPCustomFields) params.Metadata = map[string]string{"key": "value"} paramsCopy = params.getACopy() params.Metadata["key1"] = "value1" require.Equal(t, map[string]string{"key": "value"}, paramsCopy.Metadata) } func TestEventParamsStatusFromError(t *testing.T) { params := EventParams{Status: 1} params.AddError(os.ErrNotExist) assert.Equal(t, 1, params.Status) params = EventParams{Status: 1, updateStatusFromError: true} params.AddError(os.ErrNotExist) assert.Equal(t, 2, params.Status) } type testWriter 
struct { errTest error sentinel string } func (w *testWriter) Write(p []byte) (int, error) { if w.errTest != nil { return 0, w.errTest } if w.sentinel == string(p) { return 0, io.ErrUnexpectedEOF } return len(p), nil } func TestWriteHTTPPartsError(t *testing.T) { m := multipart.NewWriter(&testWriter{ errTest: io.ErrShortWrite, }) err := writeHTTPPart(m, dataprovider.HTTPPart{}, nil, nil, nil, &EventParams{}, false) assert.ErrorIs(t, err, io.ErrShortWrite) body := "test body" m = multipart.NewWriter(&testWriter{sentinel: body}) err = writeHTTPPart(m, dataprovider.HTTPPart{ Body: body, }, nil, nil, nil, &EventParams{}, false) assert.ErrorIs(t, err, io.ErrUnexpectedEOF) } func TestReplacePathsPlaceholders(t *testing.T) { replacer := strings.NewReplacer("{{.VirtualPath}}", "/path1") paths := []string{"{{.VirtualPath}}", "/path1"} paths = replacePathsPlaceholders(paths, replacer) assert.Equal(t, []string{"/path1"}, paths) paths = []string{"{{.VirtualPath}}", "/path2"} paths = replacePathsPlaceholders(paths, replacer) assert.Equal(t, []string{"/path1", "/path2"}, paths) } func TestEstimateZipSizeErrors(t *testing.T) { u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "u", HomeDir: filepath.Join(os.TempDir(), "u"), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, QuotaSize: 1000, }, } err := dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) err = os.MkdirAll(u.GetHomeDir(), os.ModePerm) assert.NoError(t, err) conn := NewBaseConnection("", ProtocolFTP, "", "", u) _, _, _, _, err = getFileWriter(conn, "/missing/path/file.txt", -1) //nolint:dogsled assert.Error(t, err) _, err = getSizeForPath(conn, "/missing", vfs.NewFileInfo("missing", true, 0, time.Now(), false)) assert.True(t, conn.IsNotExistError(err)) if runtime.GOOS != osWindows { err = os.MkdirAll(filepath.Join(u.HomeDir, "d1", "d2", "sub"), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(u.HomeDir, "d1", "d2", "sub", "file.txt"), []byte("data"), 0666) 
assert.NoError(t, err) err = os.Chmod(filepath.Join(u.HomeDir, "d1", "d2"), 0001) assert.NoError(t, err) size, err := estimateZipSize(conn, "/archive.zip", []string{"/d1"}) assert.Error(t, err, "size %d", size) err = os.Chmod(filepath.Join(u.HomeDir, "d1", "d2"), os.ModePerm) assert.NoError(t, err) } err = dataprovider.DeleteUser(u.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(u.GetHomeDir()) assert.NoError(t, err) } func TestOnDemandRule(t *testing.T) { a := &dataprovider.BaseEventAction{ Name: "a", Type: dataprovider.ActionTypeBackup, Options: dataprovider.BaseEventActionOptions{}, } err := dataprovider.AddEventAction(a, "", "", "") assert.NoError(t, err) r := &dataprovider.EventRule{ Name: "test on demand rule", Status: 1, Trigger: dataprovider.EventTriggerOnDemand, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: a.Name, }, }, }, } err = dataprovider.AddEventRule(r, "", "", "") assert.NoError(t, err) err = RunOnDemandRule(r.Name) assert.NoError(t, err) r.Status = 0 err = dataprovider.UpdateEventRule(r, "", "", "") assert.NoError(t, err) err = RunOnDemandRule(r.Name) assert.ErrorIs(t, err, util.ErrValidation) assert.Contains(t, err.Error(), "is inactive") r.Status = 1 r.Trigger = dataprovider.EventTriggerCertificate err = dataprovider.UpdateEventRule(r, "", "", "") assert.NoError(t, err) err = RunOnDemandRule(r.Name) assert.ErrorIs(t, err, util.ErrValidation) assert.Contains(t, err.Error(), "is not defined as on-demand") a1 := &dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"example@example.org"}, Subject: "subject", Body: "body", Attachments: []string{"/{{.VirtualPath}}"}, }, }, } err = dataprovider.AddEventAction(a1, "", "", "") assert.NoError(t, err) r.Trigger = dataprovider.EventTriggerOnDemand r.Actions = []dataprovider.EventAction{ { BaseEventAction: 
dataprovider.BaseEventAction{ Name: a1.Name, }, }, } err = dataprovider.UpdateEventRule(r, "", "", "") assert.NoError(t, err) err = RunOnDemandRule(r.Name) assert.ErrorIs(t, err, util.ErrValidation) assert.Contains(t, err.Error(), "incosistent actions") err = dataprovider.DeleteEventRule(r.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventAction(a.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventAction(a1.Name, "", "", "") assert.NoError(t, err) err = RunOnDemandRule(r.Name) assert.ErrorIs(t, err, util.ErrNotFound) } func getErrorString(err error) string { if err == nil { return "" } return err.Error() } func TestHTTPEndpointWithPlaceholders(t *testing.T) { c := dataprovider.EventActionHTTPConfig{ Endpoint: "http://127.0.0.1:8080/base/url/{{.Name}}/{{.VirtualPath}}/upload", QueryParameters: []dataprovider.KeyValue{ { Key: "u", Value: "{{.Name}}", }, { Key: "p", Value: "{{.VirtualPath}}", }, }, } name := "uname" vPath := "/a dir/@ file.txt" replacer := strings.NewReplacer("{{.Name}}", name, "{{.VirtualPath}}", vPath) u, err := getHTTPRuleActionEndpoint(&c, replacer) assert.NoError(t, err) expected := "http://127.0.0.1:8080/base/url/" + url.PathEscape(name) + "/" + url.PathEscape(vPath) + "/upload?" + "p=" + url.QueryEscape(vPath) + "&u=" + url.QueryEscape(name) assert.Equal(t, expected, u) c.Endpoint = "http://127.0.0.1/upload" u, err = getHTTPRuleActionEndpoint(&c, replacer) assert.NoError(t, err) expected = c.Endpoint + "?p=" + url.QueryEscape(vPath) + "&u=" + url.QueryEscape(name) assert.Equal(t, expected, u) } func TestMetadataReplacement(t *testing.T) { params := &EventParams{ Metadata: map[string]string{ "key": "value", }, } replacements := params.getStringReplacements(false, 0) replacer := strings.NewReplacer(replacements...) 
reader, _, err := getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{Body: "{{.Metadata}} {{.MetadataString}}"}, replacer, nil, dataprovider.User{}, params, false) require.NoError(t, err) data, err := io.ReadAll(reader) require.NoError(t, err) assert.Equal(t, `{"key":"value"} {\"key\":\"value\"}`, string(data)) } func TestUserInactivityCheck(t *testing.T) { username1 := "user1" username2 := "user2" user1 := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username1, HomeDir: filepath.Join(os.TempDir(), username1), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, } user2 := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username2, HomeDir: filepath.Join(os.TempDir(), username2), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, } days := user1.InactivityDays(time.Now().Add(10*24*time.Hour + 5*time.Second)) assert.Equal(t, 0, days) user2.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) err := executeInactivityCheckForUser(&user2, dataprovider.EventActionUserInactivity{ DisableThreshold: 10, }, time.Now().Add(12*24*time.Hour)) assert.Error(t, err) user2.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) err = executeInactivityCheckForUser(&user2, dataprovider.EventActionUserInactivity{ DeleteThreshold: 10, }, time.Now().Add(12*24*time.Hour)) assert.Error(t, err) err = dataprovider.AddUser(&user1, "", "", "") assert.NoError(t, err) err = dataprovider.AddUser(&user2, "", "", "") assert.NoError(t, err) user1, err = dataprovider.UserExists(username1, "") assert.NoError(t, err) assert.Equal(t, 1, user1.Status) days = user1.InactivityDays(time.Now().Add(10*24*time.Hour + 5*time.Second)) assert.Equal(t, 10, days) days = user1.InactivityDays(time.Now().Add(-10*24*time.Hour + 5*time.Second)) assert.Equal(t, -9, days) err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { 
Pattern: "not matching", }, }, }, &EventParams{}, time.Now().Add(12*24*time.Hour)) assert.NoError(t, err) err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now()) assert.NoError(t, err) // no action err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now().Add(-12*24*time.Hour)) assert.NoError(t, err) // no action err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 20, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now().Add(30*24*time.Hour)) // both thresholds exceeded, the user will be disabled if assert.Error(t, err) { assert.Contains(t, err.Error(), "executed inactivity check actions for users") } user1, err = dataprovider.UserExists(username1, "") assert.NoError(t, err) assert.Equal(t, 0, user1.Status) err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now().Add(30*24*time.Hour)) assert.NoError(t, err) // already disabled, no action err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 20, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now().Add(-30*24*time.Hour)) assert.NoError(t, err) err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 20, }, 
dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now()) assert.NoError(t, err) user1, err = dataprovider.UserExists(username1, "") assert.NoError(t, err) assert.Equal(t, 0, user1.Status) err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 20, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user1.Username, }, }, }, &EventParams{}, time.Now().Add(30*24*time.Hour)) // the user is disabled, will be now deleted assert.Error(t, err) _, err = dataprovider.UserExists(username1, "") assert.ErrorIs(t, err, util.ErrNotFound) err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ DeleteThreshold: 20, }, dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user2.Username, }, }, }, &EventParams{}, time.Now().Add(30*24*time.Hour)) // no disable threshold, user deleted assert.Error(t, err) _, err = dataprovider.UserExists(username2, "") assert.ErrorIs(t, err, util.ErrNotFound) err = dataprovider.DeleteUser(username1, "", "", "") assert.Error(t, err) err = dataprovider.DeleteUser(username2, "", "", "") assert.Error(t, err) } ================================================ FILE: internal/common/eventscheduler.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "time" "github.com/robfig/cron/v3" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( eventScheduler *cron.Cron ) func stopEventScheduler() { if eventScheduler != nil { eventScheduler.Stop() eventScheduler = nil } } func startEventScheduler() { stopEventScheduler() options := []cron.Option{ cron.WithLogger(cron.DiscardLogger), } if !dataprovider.UseLocalTime() { eventManagerLog(logger.LevelDebug, "use UTC time for the scheduler") options = append(options, cron.WithLocation(time.UTC)) } eventScheduler = cron.New(options...) eventManager.loadRules() _, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules) util.PanicOnError(err) eventScheduler.Start() } ================================================ FILE: internal/common/httpauth.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "encoding/csv" "os" "strings" "sync" "github.com/GehirnInc/crypt/apr1_crypt" "github.com/GehirnInc/crypt/md5_crypt" "golang.org/x/crypto/bcrypt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( // HTTPAuthenticationHeader defines the HTTP authentication HTTPAuthenticationHeader = "WWW-Authenticate" md5CryptPwdPrefix = "$1$" apr1CryptPwdPrefix = "$apr1$" ) var ( bcryptPwdPrefixes = []string{"$2a$", "$2$", "$2x$", "$2y$", "$2b$"} ) // HTTPAuthProvider defines the interface for HTTP auth providers type HTTPAuthProvider interface { ValidateCredentials(username, password string) bool IsEnabled() bool } type basicAuthProvider struct { Path string sync.RWMutex Info os.FileInfo Users map[string]string } // NewBasicAuthProvider returns an HTTPAuthProvider implementing Basic Auth func NewBasicAuthProvider(authUserFile string) (HTTPAuthProvider, error) { basicAuthProvider := basicAuthProvider{ Path: authUserFile, Info: nil, Users: make(map[string]string), } return &basicAuthProvider, basicAuthProvider.loadUsers() } func (p *basicAuthProvider) IsEnabled() bool { return p.Path != "" } func (p *basicAuthProvider) isReloadNeeded(info os.FileInfo) bool { p.RLock() defer p.RUnlock() return p.Info == nil || p.Info.ModTime() != info.ModTime() || p.Info.Size() != info.Size() } func (p *basicAuthProvider) loadUsers() error { if !p.IsEnabled() { return nil } info, err := os.Stat(p.Path) if err != nil { logger.Debug(logSender, "", "unable to stat basic auth users file: %v", err) return err } if p.isReloadNeeded(info) { r, err := os.Open(p.Path) if err != nil { logger.Debug(logSender, "", "unable to open basic auth users file: %v", err) return err } defer r.Close() reader := csv.NewReader(r) reader.Comma = ':' reader.Comment = '#' reader.TrimLeadingSpace = true records, err := reader.ReadAll() if err != nil { logger.Debug(logSender, "", "unable to parse basic auth users file: %v", err) return err } p.Lock() defer 
p.Unlock() p.Users = make(map[string]string) for _, record := range records { if len(record) == 2 { p.Users[record[0]] = record[1] } } logger.Debug(logSender, "", "number of users loaded for httpd basic auth: %v", len(p.Users)) p.Info = info } return nil } func (p *basicAuthProvider) getHashedPassword(username string) (string, bool) { err := p.loadUsers() if err != nil { return "", false } p.RLock() defer p.RUnlock() pwd, ok := p.Users[username] return pwd, ok } // ValidateCredentials returns true if the credentials are valid func (p *basicAuthProvider) ValidateCredentials(username, password string) bool { if hashedPwd, ok := p.getHashedPassword(username); ok { if util.IsStringPrefixInSlice(hashedPwd, bcryptPwdPrefixes) { err := bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(password)) return err == nil } if strings.HasPrefix(hashedPwd, md5CryptPwdPrefix) { crypter := md5_crypt.New() err := crypter.Verify(hashedPwd, []byte(password)) return err == nil } if strings.HasPrefix(hashedPwd, apr1CryptPwdPrefix) { crypter := apr1_crypt.New() err := crypter.Verify(hashedPwd, []byte(password)) return err == nil } } return false } ================================================ FILE: internal/common/httpauth_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "os" "path/filepath" "runtime" "testing" "github.com/stretchr/testify/require" ) func TestBasicAuth(t *testing.T) { httpAuth, err := NewBasicAuthProvider("") require.NoError(t, err) require.False(t, httpAuth.IsEnabled()) _, err = NewBasicAuthProvider("missing path") require.Error(t, err) authUserFile := filepath.Join(os.TempDir(), "http_users.txt") authUserData := []byte("test1:$2y$05$bcHSED7aO1cfLto6ZdDBOOKzlwftslVhtpIkRhAtSa4GuLmk5mola\n") err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) httpAuth, err = NewBasicAuthProvider(authUserFile) require.NoError(t, err) require.True(t, httpAuth.IsEnabled()) require.False(t, httpAuth.ValidateCredentials("test1", "wrong1")) require.False(t, httpAuth.ValidateCredentials("test2", "password2")) require.True(t, httpAuth.ValidateCredentials("test1", "password1")) authUserData = append(authUserData, []byte("test2:$1$OtSSTL8b$bmaCqEksI1e7rnZSjsIDR1\n")...) err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) require.False(t, httpAuth.ValidateCredentials("test2", "wrong2")) require.True(t, httpAuth.ValidateCredentials("test2", "password2")) authUserData = append(authUserData, []byte("test2:$apr1$gLnIkRIf$Xr/6aJfmIrihP4b2N2tcs/\n")...) err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) require.False(t, httpAuth.ValidateCredentials("test2", "wrong2")) require.True(t, httpAuth.ValidateCredentials("test2", "password2")) authUserData = append(authUserData, []byte("test3:$apr1$gLnIkRIf$Xr/6aJfmIrihP4b2N2tcs/\n")...) err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) require.False(t, httpAuth.ValidateCredentials("test3", "password3")) authUserData = append(authUserData, []byte("test4:$invalid$gLnIkRIf$Xr/6$aJfmIr$ihP4b2N2tcs/\n")...) 
err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) require.False(t, httpAuth.ValidateCredentials("test4", "password3")) if runtime.GOOS != "windows" { authUserData = append(authUserData, []byte("test5:$apr1$gLnIkRIf$Xr/6aJfmIrihP4b2N2tcs/\n")...) err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) err = os.Chmod(authUserFile, 0001) require.NoError(t, err) require.False(t, httpAuth.ValidateCredentials("test5", "password2")) err = os.Chmod(authUserFile, os.ModePerm) require.NoError(t, err) } authUserData = append(authUserData, []byte("\"foo\"bar\"\r\n")...) err = os.WriteFile(authUserFile, authUserData, os.ModePerm) require.NoError(t, err) require.False(t, httpAuth.ValidateCredentials("test2", "password2")) err = os.Remove(authUserFile) require.NoError(t, err) } ================================================ FILE: internal/common/protocol_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common_test import ( "bufio" "bytes" "crypto/rand" "encoding/json" "errors" "fmt" "io" "io/fs" "math" "net" "net/http" "net/url" "os" "path" "path/filepath" "runtime" "slices" "strings" "sync" "testing" "time" _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v5/stdlib" _ "github.com/mattn/go-sqlite3" "github.com/mhale/smtpd" "github.com/minio/sio" "github.com/pkg/sftp" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/studio-b12/gowebdav" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) const ( httpAddr = "127.0.0.1:9999" httpProxyAddr = "127.0.0.1:7777" sftpServerAddr = "127.0.0.1:4022" smtpServerAddr = "127.0.0.1:2525" webDavServerPort = 9191 httpFsPort = 34567 defaultUsername = "test_common_sftp" defaultPassword = "test_password" defaultSFTPUsername = "test_common_sftpfs_user" defaultHTTPFsUsername = "httpfs_ftp_user" httpFsWellKnowDir = "/wellknow" osWindows = "windows" testFileName = "test_file_common_sftp.dat" testDir = "test_dir_common" testPubKey = "ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn NhAAAAAwEAAQAAAYEAtN449A/nY5O6cSH/9Doa8a3ISU0WZJaHydTaCLuO+dkqtNpnV5mq zFbKidXAI1eSwVctw9ReVOl1uK6aZF3lbXdOD8W9PXobR9KUUT2qBx5QC4ibfAqDKWymDA PG9ylzz64hsYBqJr7VNk9kTFEUsDmWzLabLoH42Elnp8mF/lTkWIcpVp0ly/etS08gttXo XenekJ1vRuxOYWDCEzGPU7kGc920TmM14k7IDdPoOh5+3sRUKedKeOUrVDH1f0n7QjHQsZ cbshp8tgqzf734zu8cTqNrr+6taptdEOOij1iUL/qYGfzny/hA48tO5+UFUih5W8ftp0+E NBIDkkGgk2MJ92I7QAXyMVsIABXco+mJT7pQi9tqlODGIQ3AOj0gcA3X/Ib8QX77Ih3TPi XEh77/P1XiYZOgpp2cRmNH8QbqaL9u898hDvJwIPJPuj2lIltTElH7hjBf5LQfCzrLV7BD 10rM7sl4jr+A2q8jl1Ikp+25kainBBZSbrDummT9AAAFgDU/VLk1P1S5AAAAB3NzaC1yc2 EAAAGBALTeOPQP52OTunEh//Q6GvGtyElNFmSWh8nU2gi7jvnZKrTaZ1eZqsxWyonVwCNX ksFXLcPUXlTpdbiummRd5W13Tg/FvT16G0fSlFE9qgceUAuIm3wKgylspgwDxvcpc8+uIb GAaia+1TZPZExRFLA5lsy2my6B+NhJZ6fJhf5U5FiHKVadJcv3rUtPILbV6F3p3pCdb0bs TmFgwhMxj1O5BnPdtE5jNeJOyA3T6Doeft7EVCnnSnjlK1Qx9X9J+0Ix0LGXG7IafLYKs3 +9+M7vHE6ja6/urWqbXRDjoo9YlC/6mBn858v4QOPLTuflBVIoeVvH7adPhDQSA5JBoJNj CfdiO0AF8jFbCAAV3KPpiU+6UIvbapTgxiENwDo9IHAN1/yG/EF++yId0z4lxIe+/z9V4m GToKadnEZjR/EG6mi/bvPfIQ7ycCDyT7o9pSJbUxJR+4YwX+S0Hws6y1ewQ9dKzO7JeI6/ gNqvI5dSJKftuZGopwQWUm6w7ppk/QAAAAMBAAEAAAGAHKnC+Nq0XtGAkIFE4N18e6SAwy 0WSWaZqmCzFQM0S2AhJnweOIG/0ZZHjsRzKKauOTmppQk40dgVsejpytIek9R+aH172gxJ 2n4Cx0UwduRU5x8FFQlNc/kl722B0JWfJuB/snOZXv6LJ4o5aObIkozt2w9tVFeAqjYn2S 
1UsNOfRHBXGsTYwpRDwFWP56nKo2d2wBBTHDhCy6fb2dLW1fvSi/YspueOGIlHpvlYKi2/ CWqvs9xVrwcScMtiDoQYq0khhO0efLCxvg/o+W9CLMVM2ms4G1zoSUQKN0oYWWQJyW4+VI YneWO8UpN0J3ElXKi7bhgAat7dBaM1g9IrAzk153DiEFZNsPxGOgL/+YdQN7zUBx/z7EkI jyv80RV7fpUXvcq2p+qNl6UVig3VSzRrnsaJkUWu/A0u59ha7ocv6NxDIXjxpIDJme16GF quiGVBQNnYJymS/vFEbGf6bgf7iRmMCRUMG4nqLA6fPYP9uAtch+CmDfVLZC/fIdC5AAAA wQCDissV4zH6bfqgxJSuYNk8Vbb+19cF3b7gH1rVlB3zxpCAgcRgMHC+dP1z2NRx7UW9MR nye6kjpkzZZ0OigLqo7TtEq8uTglD9o6W7mRXqhy5A/ySOmqPL3ernHHQhGuoNODYAHkOU u2Rh8HXi+VLwKZcLInPOYJvcuLG4DxN8WfeVvlMHwhAOaTNNOtL4XZDHQeIPc4qHmJymmv sV7GuyQ6yW5C10uoGdxRPd90Bh4z4h2bKfZFjvEBbSBVkqrlAAAADBAN/zNtNayd/dX7Cr Nb4sZuzCh+CW4BH8GOePZWNCATwBbNXBVb5cR+dmuTqYm+Ekz0VxVQRA1TvKncluJOQpoa Xj8r0xdIgqkehnfDPMKtYVor06B9Fl1jrXtXU0Vrr6QcBWruSVyK1ZxqcmcNK/+KolVepe A6vcl/iKaG4U7su166nxLST06M2EgcSVsFJHpKn5+WAXC+X0Gx8kNjWIIb3GpiChdc0xZD mq02xZthVJrTCVw/e7gfDoB2QRsNV8HwAAAMEAzsCghZVp+0YsYg9oOrw4tEqcbEXEMhwY 0jW8JNL8Spr1Ibp5Dw6bRSk5azARjmJtnMJhJ3oeHfF0eoISqcNuQXGndGQbVM9YzzAzc1 NbbCNsVroqKlChT5wyPNGS+phi2bPARBno7WSDvshTZ7dAVEP2c9MJW0XwoSevwKlhgSdt RLFFQ/5nclJSdzPBOmQouC0OBcMFSrYtMeknJ4VvueVvve5HcHFaEsaMc7ABAGaLYaBQOm iixITGvaNZh/tjAAAACW5pY29sYUBwMQE= -----END OPENSSH PRIVATE KEY-----` ) var ( configDir = filepath.Join(".", "..", "..") allPerms = []string{dataprovider.PermAny} homeBasePath string logFilePath string backupsPath string testFileContent = []byte("test data") lastReceivedEmail receivedEmail ) func TestMain(m *testing.M) { homeBasePath = os.TempDir() logFilePath = filepath.Join(configDir, "common_test.log") backupsPath = filepath.Join(os.TempDir(), "backups") logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel) os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") err := config.LoadConfig(configDir, "") if err != nil { logger.ErrorToConsole("error loading 
configuration: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() providerConf.BackupsPath = backupsPath logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver) err = dataprovider.Initialize(providerConf, configDir, true) if err != nil { logger.ErrorToConsole("error initializing data provider: %v", err) os.Exit(1) } err = common.Initialize(config.GetCommonConfig(), 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) } httpConfig := config.GetHTTPConfig() httpConfig.Timeout = 5 httpConfig.RetryMax = 0 httpConfig.Initialize(configDir) //nolint:errcheck kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing kms: %v", err) os.Exit(1) } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing MFA: %v", err) os.Exit(1) } sftpdConf := config.GetSFTPDConfig() sftpdConf.Bindings[0].Port = 4022 sftpdConf.EnabledSSHCommands = []string{"*"} sftpdConf.Bindings = append(sftpdConf.Bindings, sftpd.Binding{ Port: 4024, }) sftpdConf.KeyboardInteractiveAuthentication = true httpdConf := config.GetHTTPDConfig() httpdConf.Bindings[0].Port = 4080 httpdtest.SetBaseURL("http://127.0.0.1:4080") webDavConf := config.GetWebDAVDConfig() webDavConf.Bindings = []webdavd.Binding{ { Port: webDavServerPort, }, } go func() { if err := sftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server: %v", err) os.Exit(1) } }() go func() { if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } }() go func() { if err := webDavConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start WebDAV server: %v", err) os.Exit(1) } }() waitTCPListening(sftpdConf.Bindings[0].GetAddress()) waitTCPListening(httpdConf.Bindings[0].GetAddress()) 
waitTCPListening(webDavConf.Bindings[0].GetAddress()) startHTTPFs() go func() { // start a test HTTP server to receive action notifications http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintf(w, "OK\n") }) http.HandleFunc("/404", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, "Not found\n") }) http.HandleFunc("/multipart", func(w http.ResponseWriter, r *http.Request) { err := r.ParseMultipartForm(1048576) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "KO\n") return } defer r.MultipartForm.RemoveAll() //nolint:errcheck fmt.Fprintf(w, "OK\n") }) if err := http.ListenAndServe(httpAddr, nil); err != nil { logger.ErrorToConsole("could not start HTTP notification server: %v", err) os.Exit(1) } }() go func() { common.Config.ProxyProtocol = 2 listener, err := net.Listen("tcp", httpProxyAddr) if err != nil { logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err) os.Exit(1) } proxyListener, err := common.Config.GetProxyListener(listener) if err != nil { logger.ErrorToConsole("error creating proxy protocol listener: %v", err) os.Exit(1) } common.Config.ProxyProtocol = 0 s := &http.Server{} if err := s.Serve(proxyListener); err != nil { logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err) os.Exit(1) } }() go func() { if err := smtpd.ListenAndServe(smtpServerAddr, func(_ net.Addr, from string, to []string, data []byte) error { lastReceivedEmail.set(from, to, data) return nil }, "SFTPGo test", "localhost"); err != nil { logger.ErrorToConsole("could not start SMTP server: %v", err) os.Exit(1) } }() waitTCPListening(httpAddr) waitTCPListening(httpProxyAddr) waitTCPListening(smtpServerAddr) exitCode := m.Run() os.Remove(logFilePath) os.RemoveAll(backupsPath) os.Exit(exitCode) } func TestBaseConnection(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = 
os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) _, err = client.ReadDir(testDir) assert.ErrorIs(t, err, os.ErrNotExist) err = client.RemoveDirectory(testDir) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Mkdir(testDir) assert.NoError(t, err) err = client.Mkdir(testDir) assert.Error(t, err) info, err := client.Stat(testDir) if assert.NoError(t, err) { assert.True(t, info.IsDir()) } err = client.Rename(testDir, testDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "the rename source and target cannot be the same") } err = client.Rename(testDir, path.Join(testDir, "sub")) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") } err = client.RemoveDirectory(testDir) assert.NoError(t, err) err = client.Remove(testFileName) assert.ErrorIs(t, err, os.ErrNotExist) f, err := client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) linkName := testFileName + ".link" //nolint:goconst err = client.Rename(testFileName, testFileName) if assert.Error(t, err) { assert.Contains(t, err.Error(), "the rename source and target cannot be the same") } err = client.Symlink(testFileName, linkName) assert.NoError(t, err) err = client.Symlink(testFileName, testFileName) assert.Error(t, err) info, err = client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, int64(len(testFileContent)), info.Size()) assert.False(t, info.IsDir()) } info, err = client.Lstat(linkName) if assert.NoError(t, err) { assert.NotEqual(t, int64(7), info.Size()) assert.True(t, info.Mode()&os.ModeSymlink != 0) assert.False(t, info.IsDir()) } err = client.RemoveDirectory(linkName) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_FAILURE") } err = client.Remove(testFileName) assert.NoError(t, err) err = 
client.Remove(linkName) assert.NoError(t, err) err = client.Rename(testFileName, "test") assert.ErrorIs(t, err, os.ErrNotExist) f, err = client.Create(testFileName) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) err = client.Rename(testFileName, testFileName+"1") assert.NoError(t, err) err = client.Remove(testFileName + "1") assert.NoError(t, err) err = client.RemoveDirectory("missing") assert.Error(t, err) } else { printLatestLogs(10) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestRemoveAll(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) webDavClient := getWebDavClient(user) err = webDavClient.RemoveAll("/") if assert.Error(t, err) { assert.True(t, gowebdav.IsErrCode(err, http.StatusForbidden)) } testDir := "baseDir" err = webDavClient.RemoveAll(testDir) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir(testDir) assert.NoError(t, err) err = writeSFTPFile(path.Join(testDir, testFileName), 1234, client) assert.NoError(t, err) err = webDavClient.RemoveAll(path.Join(testDir, testFileName)) assert.NoError(t, err) _, err = client.Stat(path.Join(testDir, testFileName)) assert.Error(t, err) err = writeSFTPFile(path.Join(testDir, testFileName), 1234, client) assert.NoError(t, err) err = webDavClient.RemoveAll(testDir) assert.NoError(t, err) _, err = client.Stat(testDir) assert.Error(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestRelativeSymlinks(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { 
defer conn.Close() defer client.Close() linkName := testFileName + "_link" //nolint:goconst err = client.Symlink("non-existent-file", linkName) assert.NoError(t, err) err = client.Remove(linkName) assert.NoError(t, err) testDir := "sub" err = client.Mkdir(testDir) assert.NoError(t, err) f, err := client.Create(path.Join(testDir, testFileName)) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) err = client.Symlink(path.Join(testDir, testFileName), linkName) assert.NoError(t, err) _, err = client.Stat(linkName) assert.NoError(t, err) p, err := client.ReadLink(linkName) assert.NoError(t, err) assert.Equal(t, path.Join("/", testDir, testFileName), p) err = client.Remove(linkName) assert.NoError(t, err) err = client.Symlink(testFileName, path.Join(testDir, linkName)) assert.NoError(t, err) _, err = client.Stat(path.Join(testDir, linkName)) assert.NoError(t, err) p, err = client.ReadLink(path.Join(testDir, linkName)) assert.NoError(t, err) assert.Equal(t, path.Join("/", testDir, testFileName), p) f, err = client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) err = client.Symlink(testFileName, linkName) assert.NoError(t, err) _, err = client.Stat(linkName) assert.NoError(t, err) p, err = client.ReadLink(linkName) assert.NoError(t, err) assert.Equal(t, path.Join("/", testFileName), p) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestCheckFsAfterUpdate(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } // remove the home dir, it will not be re-created err = os.RemoveAll(user.GetHomeDir()) 
assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.Error(t, err) } else { printLatestLogs(10) } // update the user and login again, this time the home dir will be created _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestLoginAccessTime(t *testing.T) { u := getTestUser() u.Filters.AccessTime = []sdk.TimePeriod{ { DayOfWeek: int(time.Now().Add(-25 * time.Hour).UTC().Weekday()), From: "00:00", To: "23:59", }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) _, _, err = getSftpClient(user) assert.Error(t, err) user.Filters.AccessTime = []sdk.TimePeriod{ { DayOfWeek: int(time.Now().UTC().Weekday()), From: "00:00", To: "23:59", }, } user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err := checkBasicSFTP(client) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSetStat(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() f, err := client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) acmodTime := time.Now().Add(36 
* time.Hour) err = client.Chtimes(testFileName, acmodTime, acmodTime) assert.NoError(t, err) newFi, err := client.Lstat(testFileName) assert.NoError(t, err) diff := math.Abs(newFi.ModTime().Sub(acmodTime).Seconds()) assert.LessOrEqual(t, diff, float64(1)) if runtime.GOOS != osWindows { err = client.Chown(testFileName, os.Getuid(), os.Getgid()) assert.NoError(t, err) } newPerm := os.FileMode(0666) err = client.Chmod(testFileName, newPerm) assert.NoError(t, err) newFi, err = client.Lstat(testFileName) if assert.NoError(t, err) { assert.Equal(t, newPerm, newFi.Mode().Perm()) } err = client.Truncate(testFileName, 2) assert.NoError(t, err) info, err := client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, int64(2), info.Size()) } err = client.Remove(testFileName) assert.NoError(t, err) err = client.Truncate(testFileName, 0) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Chtimes(testFileName, acmodTime, acmodTime) assert.ErrorIs(t, err, os.ErrNotExist) if runtime.GOOS != osWindows { err = client.Chown(testFileName, os.Getuid(), os.Getgid()) assert.ErrorIs(t, err, os.ErrNotExist) } err = client.Chmod(testFileName, newPerm) assert.ErrorIs(t, err, os.ErrNotExist) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestCryptFsUserUploadErrorOverwrite(t *testing.T) { u := getCryptFsUser() u.QuotaSize = 6000 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) var buf []byte for i := 0; i < 4000; i++ { buf = append(buf, []byte("a")...) 
} bufSize := int64(len(buf)) reader := bytes.NewReader(buf) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() f, err := client.Create(testFileName + "_big") assert.NoError(t, err) n, err := io.Copy(f, reader) assert.NoError(t, err) assert.Equal(t, bufSize, n) err = f.Close() assert.NoError(t, err) encryptedSize, err := getEncryptedFileSize(bufSize) assert.NoError(t, err) expectedSize := encryptedSize user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, expectedSize, user.UsedQuotaSize) // now write a small file f, err = client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) encryptedSize, err = getEncryptedFileSize(int64(len(testFileContent))) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, expectedSize+encryptedSize, user.UsedQuotaSize) // try to overwrite this file with a big one, this cause an overquota error // the partial file is deleted and the quota updated _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) f, err = client.Create(testFileName) assert.NoError(t, err) _, err = io.Copy(f, reader) assert.Error(t, err) err = f.Close() assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, expectedSize, user.UsedQuotaSize) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestChtimesOpenHandle(t *testing.T) { localUser, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), 
http.StatusCreated) assert.NoError(t, err) u := getCryptFsUser() cryptFsUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser, cryptFsUser} { conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() f, err := client.Create(testFileName) assert.NoError(t, err, "user %v", user.Username) f1, err := client.Create(testFileName + "1") assert.NoError(t, err, "user %v", user.Username) acmodTime := time.Now().Add(36 * time.Hour) err = client.Chtimes(testFileName, acmodTime, acmodTime) assert.NoError(t, err, "user %v", user.Username) _, err = f.Write(testFileContent) assert.NoError(t, err, "user %v", user.Username) err = f.Close() assert.NoError(t, err, "user %v", user.Username) err = f1.Close() assert.NoError(t, err, "user %v", user.Username) info, err := client.Lstat(testFileName) assert.NoError(t, err, "user %v", user.Username) diff := math.Abs(info.ModTime().Sub(acmodTime).Seconds()) assert.LessOrEqual(t, diff, float64(1), "user %v", user.Username) info1, err := client.Lstat(testFileName + "1") assert.NoError(t, err, "user %v", user.Username) diff = math.Abs(info1.ModTime().Sub(acmodTime).Seconds()) assert.Greater(t, diff, float64(86400), "user %v", user.Username) } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(cryptFsUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(cryptFsUser.GetHomeDir()) assert.NoError(t, err) } func TestWaitForConnections(t *testing.T) { u := getTestUser() u.UploadBandwidth = 128 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileSize := int64(524288) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() 
err = common.CheckClosing() assert.NoError(t, err) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() time.Sleep(1 * time.Second) common.WaitForTransfers(10) common.WaitForTransfers(0) common.WaitForTransfers(10) }() err = writeSFTPFileNoCheck(testFileName, testFileSize, client) assert.NoError(t, err) wg.Wait() err = common.CheckClosing() assert.EqualError(t, err, common.ErrShuttingDown.Error()) _, err = client.Stat(testFileName) if assert.Error(t, err) { assert.Contains(t, err.Error(), common.ErrShuttingDown.Error()) } } _, _, err = getSftpClient(user) assert.Error(t, err) err = common.Initialize(common.Config, 0) assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() info, err := client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, testFileSize, info.Size()) } err = client.Remove(testFileName) assert.NoError(t, err) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() time.Sleep(1 * time.Second) common.WaitForTransfers(1) }() err = writeSFTPFileNoCheck(testFileName, testFileSize, client) // we don't have an error here because the service won't really stop assert.NoError(t, err) wg.Wait() } err = common.Initialize(common.Config, 0) assert.NoError(t, err) common.WaitForTransfers(1) err = common.Initialize(common.Config, 0) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestCheckParentDirs(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) testDir := "/path/to/sub/dir" conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() _, err = client.Stat(testDir) assert.ErrorIs(t, err, os.ErrNotExist) c := common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", user) err = c.CheckParentDirs(testDir) assert.NoError(t, err) _, 
err = client.Stat(testDir) assert.NoError(t, err) err = c.CheckParentDirs(testDir) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) u := getTestUser() u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermDownload} user, _, err = httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() c := common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", user) err = c.CheckParentDirs(testDir) assert.ErrorIs(t, err, sftp.ErrSSHFxPermissionDenied) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermissionErrors(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) u := getTestSFTPUser() subDir := "/sub" u.Permissions[subDir] = nil sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.MkdirAll(path.Join(subDir, subDir)) assert.NoError(t, err) f, err := client.Create(path.Join(subDir, subDir, testFileName)) if assert.NoError(t, err) { _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) } } conn, client, err = getSftpClient(sftpUser) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) _, err = client.ReadDir(subDir) assert.ErrorIs(t, err, os.ErrPermission) err = client.Mkdir(path.Join(subDir, subDir)) assert.ErrorIs(t, err, os.ErrPermission) err = client.RemoveDirectory(path.Join(subDir, subDir)) assert.ErrorIs(t, err, os.ErrPermission) err = client.Symlink("test", path.Join(subDir, subDir)) 
assert.ErrorIs(t, err, os.ErrPermission) err = client.Chmod(path.Join(subDir, subDir), os.ModePerm) assert.ErrorIs(t, err, os.ErrPermission) err = client.Chown(path.Join(subDir, subDir), os.Getuid(), os.Getgid()) assert.ErrorIs(t, err, os.ErrPermission) err = client.Chtimes(path.Join(subDir, subDir), time.Now(), time.Now()) assert.ErrorIs(t, err, os.ErrPermission) err = client.Truncate(path.Join(subDir, subDir), 0) assert.ErrorIs(t, err, os.ErrPermission) err = client.Remove(path.Join(subDir, subDir, testFileName)) assert.ErrorIs(t, err, os.ErrPermission) } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestHiddenPatternFilter(t *testing.T) { deniedDir := "/denied_hidden" u := getTestUser() u.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: deniedDir, DeniedPatterns: []string{"*.txt", "beta*"}, DenyPolicy: sdk.DenyPolicyHide, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) dirName := "beta" subDirName := "testDir" testFile := filepath.Join(u.GetHomeDir(), deniedDir, "file.txt") testFile1 := filepath.Join(u.GetHomeDir(), deniedDir, "beta.txt") testHiddenFile := filepath.Join(u.GetHomeDir(), deniedDir, dirName, subDirName, "hidden.jpg") err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(testFile, testFileContent, os.ModePerm) assert.NoError(t, err) err = os.WriteFile(testFile1, testFileContent, os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir, dirName, subDirName), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(testHiddenFile, testFileContent, os.ModePerm) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() files, err := client.ReadDir(deniedDir) assert.NoError(t, 
err) assert.Len(t, files, 0) err = client.Remove(path.Join(deniedDir, filepath.Base(testFile))) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Chtimes(path.Join(deniedDir, filepath.Base(testFile)), time.Now(), time.Now()) assert.ErrorIs(t, err, os.ErrNotExist) _, err = client.Stat(path.Join(deniedDir, filepath.Base(testFile1))) assert.ErrorIs(t, err, os.ErrNotExist) err = client.RemoveDirectory(path.Join(deniedDir, dirName)) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Rename(path.Join(deniedDir, dirName), path.Join(deniedDir, "newname")) assert.ErrorIs(t, err, os.ErrPermission) err = client.Mkdir(path.Join(deniedDir, "beta1")) assert.ErrorIs(t, err, os.ErrPermission) err = writeSFTPFile(path.Join(deniedDir, "afile.txt"), 1024, client) assert.ErrorIs(t, err, os.ErrPermission) err = writeSFTPFile(path.Join(deniedDir, dirName, subDirName, "afile.jpg"), 1024, client) assert.ErrorIs(t, err, os.ErrPermission) _, err = client.Open(path.Join(deniedDir, dirName, subDirName, filepath.Base(testHiddenFile))) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Symlink(path.Join(deniedDir, dirName), dirName) assert.ErrorIs(t, err, os.ErrNotExist) err = writeSFTPFile(path.Join(deniedDir, testFileName), 1024, client) assert.NoError(t, err) err = client.Symlink(path.Join(deniedDir, testFileName), path.Join(deniedDir, "symlink.txt")) assert.ErrorIs(t, err, os.ErrPermission) files, err = client.ReadDir(deniedDir) assert.NoError(t, err) assert.Len(t, files, 1) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) u.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: deniedDir, DeniedPatterns: []string{"*.txt", "beta*"}, DenyPolicy: sdk.DenyPolicyDefault, }, } user, _, err = httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() files, err := client.ReadDir(deniedDir) assert.NoError(t, err) assert.Len(t, files, 4) _, err = 
client.Stat(path.Join(deniedDir, filepath.Base(testFile)))
		assert.NoError(t, err)
		err = client.Chtimes(path.Join(deniedDir, filepath.Base(testFile)), time.Now(), time.Now())
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Mkdir(path.Join(deniedDir, "beta2"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = writeSFTPFile(path.Join(deniedDir, "afile2.txt"), 1024, client)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Symlink(path.Join(deniedDir, testFileName), path.Join(deniedDir, "link.txt"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = writeSFTPFile(path.Join(deniedDir, dirName, subDirName, "afile.jpg"), 1024, client)
		assert.NoError(t, err)
		f, err := client.Open(path.Join(deniedDir, dirName, subDirName, filepath.Base(testHiddenFile)))
		assert.NoError(t, err)
		err = f.Close()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHiddenRoot verifies a "/" filter that allows only the "ftp" entry using
// sdk.DenyPolicyHide. The other root entries are the pre-created "ftp0".."ftp9"
// directories, testFileName and "ftp.txt". They must not be listed by ReadDir,
// and Stat must report os.ErrNotExist for them. Creating new root entries and
// renaming hidden ones must fail with os.ErrPermission. Inside "/ftp" every
// operation is allowed.
func TestHiddenRoot(t *testing.T) {
	// only the "/ftp" directory is allowed and visible in the "/" path
	// within /ftp any file/directory is allowed and visible
	u := getTestUser()
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/",
			AllowedPatterns: []string{"ftp"},
			DenyPolicy:      sdk.DenyPolicyHide,
		},
		{
			Path:            "/ftp",
			AllowedPatterns: []string{"*"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// create entries directly on disk that the "/" filter does not allow
	for i := 0; i < 10; i++ {
		err = os.MkdirAll(filepath.Join(user.HomeDir, fmt.Sprintf("ftp%d", i)), os.ModePerm)
		assert.NoError(t, err)
	}
	err = os.WriteFile(filepath.Join(user.HomeDir, testFileName), []byte(""), 0666)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(user.HomeDir, "ftp.txt"), []byte(""), 0666)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// "ftp" is the only root name allowed by the filter
		err = client.Mkdir("ftp")
		assert.NoError(t, err)
		// only the allowed entry must be listed, the others are hidden
		entries, err := client.ReadDir("/")
		assert.NoError(t, err)
		if assert.Len(t, entries, 1) {
			assert.Equal(t, "ftp", entries[0].Name())
		}
		_, err = client.Stat(".")
		assert.NoError(t, err)
		for _, name := range []string{testFileName, "ftp.txt"} {
			_, err = client.Stat(name)
			assert.ErrorIs(t, err, os.ErrNotExist)
		}
		for i := 0; i < 10; i++ {
			_, err = client.Stat(fmt.Sprintf("ftp%d", i))
			assert.ErrorIs(t, err, os.ErrNotExist)
		}
		err = writeSFTPFile(testFileName, 4096, client)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = writeSFTPFile("ftp123", 4096, client)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, testFileName+"_rename") //nolint:goconst
		assert.ErrorIs(t, err, os.ErrPermission)
		// inside the allowed directory uploads, mkdir and renames work
		err = writeSFTPFile(path.Join("/ftp", testFileName), 4096, client)
		assert.NoError(t, err)
		err = client.Mkdir("/ftp/dir")
		assert.NoError(t, err)
		err = client.Rename(path.Join("/ftp", testFileName), path.Join("/ftp/dir", testFileName))
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestFileNotAllowedErrors verifies that a file matching a denied pattern
// ("*.txt" inside "/denied") cannot be removed or renamed over SFTP
// (os.ErrPermission), even though it exists on disk. The file is created
// directly in the user's home directory, not through the SFTP client.
func TestFileNotAllowedErrors(t *testing.T) {
	deniedDir := "/denied"
	u := getTestUser()
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           deniedDir,
			DeniedPatterns: []string{"*.txt"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFile := filepath.Join(u.GetHomeDir(), deniedDir, "file.txt")
		err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFile, testFileContent, os.ModePerm)
		assert.NoError(t, err)
		err = client.Remove(path.Join(deniedDir, "file.txt"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(deniedDir, "file.txt"), path.Join(deniedDir, "file1.txt"))
		assert.ErrorIs(t, err, os.ErrPermission)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRootDirVirtualFolder mounts two virtual folders: an encrypted one
// (sdk.CryptedFilesystemProvider) on "/" and a local one on "/vmapped".
// It then verifies that:
//   - files uploaded to the root are stored in the folder's mapped path and
//     not in the user's home directory;
//   - uploads to the root are counted only in the root folder's quota, while
//     uploads to "/vmapped" (QuotaFiles/QuotaSize -1) are counted in the user's quota;
//   - renames across the two virtual folders fail, while renames inside the
//     same folder succeed.
func TestRootDirVirtualFolder(t *testing.T) {
	mappedPath1 := filepath.Join(os.TempDir(), "mapped1")
	f1 := vfs.BaseVirtualFolder{
		Name:       filepath.Base(mappedPath1),
		MappedPath: mappedPath1,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret("cryptsecret"),
			},
		},
	}
	mappedPath2 := filepath.Join(os.TempDir(), "mapped2")
	f2 := vfs.BaseVirtualFolder{
		Name:       filepath.Base(mappedPath2),
		MappedPath: mappedPath2,
	}
	folder1, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	folder2, _, err := httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.QuotaFiles = 1000
	u.UploadDataTransfer = 1000
	u.DownloadDataTransfer = 5000
	// the encrypted folder is mounted on "/" and has its own file quota
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folder1.Name,
		},
		VirtualPath: "/",
		QuotaFiles:  1000,
	})
	vdirPath2 := "/vmapped"
	// -1 values: usage inside this folder is counted in the user's quota
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folder2.Name,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	f, err := user.GetVirtualFolderForPath("/")
	assert.NoError(t, err)
	assert.Equal(t, "/", f.VirtualPath)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
		// this f shadows the virtual folder above; from here on it is a remote file handle
		f, err := client.Create(testFileName)
		if assert.NoError(t, err) {
			_, err = f.Write(testFileContent)
			assert.NoError(t, err)
			err = f.Close()
			assert.NoError(t, err)
		}
		// the upload must land in the folder mapped on "/", not in the home dir
		assert.NoFileExists(t, filepath.Join(user.HomeDir, testFileName))
		assert.FileExists(t, filepath.Join(mappedPath1, testFileName))
		entries, err := client.ReadDir(".")
		if assert.NoError(t, err) {
			assert.Len(t, entries, 2)
		}
		// this user shadows the outer one until the end of the block; the
		// cleanup after the block still uses the user returned by AddUser
		user, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, user.UsedQuotaFiles)
		folder, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, folder.UsedQuotaFiles)
		f, err = client.Create(path.Join(vdirPath2, testFileName))
		if assert.NoError(t, err) {
			_, err = f.Write(testFileContent)
			assert.NoError(t, err)
			err = f.Close()
			assert.NoError(t, err)
		}
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		folder, _, err = httpdtest.GetFolderByName(folder1.Name, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, folder.UsedQuotaFiles)
		// renames between different virtual folders are expected to fail
		err = client.Rename(testFileName, path.Join(vdirPath2, testFileName+"_rename"))
		assert.Error(t, err)
		err = client.Rename(path.Join(vdirPath2, testFileName), testFileName+"_rename")
		assert.Error(t, err)
		// renames inside the same virtual folder are expected to succeed
		err = client.Rename(testFileName, testFileName+"_rename")
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+"_rename"))
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder1.Name}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder2.Name}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestTruncateQuotaLimits verifies quota accounting when files are truncated,
// both through an open handle and by path. It runs for a local user and for the
// user built by getTestSFTPUser, both with QuotaSize 20. For the local user it
// also checks a virtual folder with its own file quota and a virtual folder
// whose usage is counted in the user's quota.
func TestTruncateQuotaLimits(t *testing.T) {
	mappedPath1 := filepath.Join(os.TempDir(), "mapped1")
	f1 := vfs.BaseVirtualFolder{
		Name:       filepath.Base(mappedPath1),
		MappedPath: mappedPath1,
	}
	folder1, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	mappedPath2 := filepath.Join(os.TempDir(), "mapped2")
	f2 := vfs.BaseVirtualFolder{
		Name:       filepath.Base(mappedPath2),
		MappedPath: mappedPath2,
	}
	folder2, _, err := httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t,
err)
	u := getTestUser()
	u.QuotaSize = 20
	u.UploadDataTransfer = 1000
	u.DownloadDataTransfer = 5000
	vdirPath1 := "/vmapped1"
	// folder with its own file quota (10 files)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folder1.Name,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  10,
	})
	vdirPath2 := "/vmapped2"
	// -1 values: usage inside this folder is counted in the user's quota
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folder2.Name,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaSize = 20
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user)
		if assert.NoError(t, err) {
			// NOTE(review): these defers run when the test function returns, not at
			// the end of each iteration, so every connection stays open until then.
			defer conn.Close()
			defer client.Close()
			f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE)
			if assert.NoError(t, err) {
				n, err := f.Write(testFileContent)
				assert.NoError(t, err)
				assert.Equal(t, len(testFileContent), n)
				// after truncating an open upload the used size is updated at once,
				// while the file count is updated only when the file is closed
				err = f.Truncate(2)
				assert.NoError(t, err)
				expectedQuotaFiles := 0
				expectedQuotaSize := int64(2)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
				assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
				_, err = f.Seek(expectedQuotaSize, io.SeekStart)
				assert.NoError(t, err)
				n, err = f.Write(testFileContent)
				assert.NoError(t, err)
				assert.Equal(t, len(testFileContent), n)
				err = f.Truncate(5)
				assert.NoError(t, err)
				expectedQuotaSize = int64(5)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
				assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
				_, err = f.Seek(expectedQuotaSize, io.SeekStart)
				assert.NoError(t, err)
				n, err = f.Write(testFileContent)
				assert.NoError(t, err)
				assert.Equal(t, len(testFileContent), n)
				err = f.Close()
				assert.NoError(t, err)
				expectedQuotaFiles = 1
				expectedQuotaSize = int64(5) + int64(len(testFileContent))
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
				assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
			}
			// now truncate by path
			err = client.Truncate(testFileName, 5)
			assert.NoError(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 1, user.UsedQuotaFiles)
			assert.Equal(t, int64(5), user.UsedQuotaSize)
			// now open an existing file without truncating it, quota should not change
			f, err = client.OpenFile(testFileName, os.O_WRONLY)
			if assert.NoError(t, err) {
				err = f.Close()
				assert.NoError(t, err)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, 1, user.UsedQuotaFiles)
				assert.Equal(t, int64(5), user.UsedQuotaSize)
			}
			// open the file truncating it
			f, err = client.OpenFile(testFileName, os.O_WRONLY|os.O_TRUNC)
			if assert.NoError(t, err) {
				err = f.Close()
				assert.NoError(t, err)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, 1, user.UsedQuotaFiles)
				assert.Equal(t, int64(0), user.UsedQuotaSize)
			}
			// now test max write size
			f, err = client.OpenFile(testFileName, os.O_WRONLY)
			if assert.NoError(t, err) {
				n, err := f.Write(testFileContent)
				assert.NoError(t, err)
				assert.Equal(t, len(testFileContent), n)
				err = f.Truncate(11)
				assert.NoError(t, err)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, 1, user.UsedQuotaFiles)
				assert.Equal(t, int64(11), user.UsedQuotaSize)
				_, err = f.Seek(int64(11), io.SeekStart)
				assert.NoError(t, err)
				n, err = f.Write(testFileContent)
				assert.NoError(t, err)
				assert.Equal(t, len(testFileContent), n)
				err = f.Truncate(5)
				assert.NoError(t, err)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, 1, user.UsedQuotaFiles)
				assert.Equal(t, int64(5), user.UsedQuotaSize)
				_, err = f.Seek(int64(5), io.SeekStart)
				assert.NoError(t, err)
				n, err = f.Write(testFileContent)
				assert.NoError(t, err)
				assert.Equal(t, len(testFileContent), n)
				err = f.Truncate(12)
				assert.NoError(t, err)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, 1, user.UsedQuotaFiles)
				assert.Equal(t, int64(12), user.UsedQuotaSize)
				_, err = f.Seek(int64(12), io.SeekStart)
				assert.NoError(t, err)
				// this write must exceed the user's QuotaSize
				_, err = f.Write(testFileContent)
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error())
				}
				err = f.Close()
				assert.Error(t, err) // the file is deleted
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, 0, user.UsedQuotaFiles)
				assert.Equal(t, int64(0), user.UsedQuotaSize)
			}
			// the virtual folders above are configured only for the local user
			// (presumably the one named defaultUsername)
			if user.Username == defaultUsername {
				// basic test inside a virtual folder
				vfileName1 := path.Join(vdirPath1, testFileName)
				f, err = client.OpenFile(vfileName1, os.O_WRONLY|os.O_CREATE)
				if assert.NoError(t, err) {
					n, err := f.Write(testFileContent)
					assert.NoError(t, err)
					assert.Equal(t, len(testFileContent), n)
					err = f.Truncate(2)
					assert.NoError(t, err)
					expectedQuotaFiles := 0
					expectedQuotaSize := int64(2)
					fold, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK)
					assert.NoError(t, err)
					assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize)
					assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles)
					err = f.Close()
					assert.NoError(t, err)
					expectedQuotaFiles = 1
					fold, _, err = httpdtest.GetFolderByName(folder1.Name, http.StatusOK)
					assert.NoError(t, err)
					assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize)
					assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles)
				}
				err = client.Truncate(vfileName1, 1)
				assert.NoError(t, err)
				fold, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, int64(1), fold.UsedQuotaSize)
				assert.Equal(t, 1, fold.UsedQuotaFiles)
				// now test on vdirPath2, the folder quota is included in the user's quota
				vfileName2 := path.Join(vdirPath2, testFileName)
				f, err = client.OpenFile(vfileName2, os.O_WRONLY|os.O_CREATE)
				if assert.NoError(t, err) {
					n, err := f.Write(testFileContent)
					assert.NoError(t, err)
					assert.Equal(t, len(testFileContent), n)
					err = f.Truncate(3)
					assert.NoError(t, err)
					// NOTE(review): this initial 0 is never read; it is overwritten
					// with 1 before the only assertion that uses it
					expectedQuotaFiles := 0
					expectedQuotaSize := int64(3)
					fold, _, err := httpdtest.GetFolderByName(folder2.Name, http.StatusOK)
					assert.NoError(t, err)
					assert.Equal(t, int64(0), fold.UsedQuotaSize)
					assert.Equal(t, 0, fold.UsedQuotaFiles)
					err = f.Close()
					assert.NoError(t, err)
					expectedQuotaFiles = 1
					fold, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK)
					assert.NoError(t, err)
					assert.Equal(t, int64(0), fold.UsedQuotaSize)
					assert.Equal(t, 0, fold.UsedQuotaFiles)
					user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
					assert.NoError(t, err)
					assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
					assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
				}
				// cleanup
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				// NOTE(review): this repeats the enclosing condition. user may have been
				// reloaded from the server above, so it can only differ if that lookup
				// failed. Re-adding the user without QuotaSize is presumably needed by
				// the SFTP-backed user of the next iteration; confirm against getTestSFTPUser.
				if user.Username == defaultUsername {
					_, err = httpdtest.RemoveUser(user, http.StatusOK)
					assert.NoError(t, err)
					user.Password = defaultPassword
					user.QuotaSize = 0
					user.ID = 0
					user.CreatedAt = 0
					_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
					assert.NoError(t, err, string(resp))
				}
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder2, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestVirtualFoldersQuotaRenameOverwrite verifies quota checks for renames into
// virtual folders that have their own limits, both with and without
// overwriting an existing file:
//   - /vdir1: at most 2 files, no size limit;
//   - /vdir2: no file limit, size limited to testFileSize+testFileSize1+1;
//   - /vdir3: at most 2 files and 2*testFileSize bytes.
// The user's own QuotaFiles and QuotaSize are set to 0.
func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) {
	testFileSize := int64(131072)
	testFileSize1 := int64(65537)
	testFileName1 := "test_file1.dat" //nolint:goconst
	u := getTestUser()
	u.QuotaFiles = 0
	u.QuotaSize = 0
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1" //nolint:goconst
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2" //nolint:goconst
	mappedPath3 := filepath.Join(os.TempDir(), "vdir3")
	folderName3 := filepath.Base(mappedPath3)
	vdirPath3 := "/vdir3"
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	f3 := vfs.BaseVirtualFolder{
		Name:       folderName3,
		MappedPath: mappedPath3,
	}
	_, _, err = httpdtest.AddFolder(f3, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  2,
		QuotaSize:   0,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  0,
		QuotaSize:   testFileSize + testFileSize1 + 1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName3,
		},
		VirtualPath: vdirPath3,
		QuotaFiles:  2,
		QuotaSize:   testFileSize * 2,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.NoError(t,
err)
		f, err := client.Open(path.Join(vdirPath1, testFileName))
		assert.NoError(t, err)
		contents, err := io.ReadAll(f)
		assert.NoError(t, err)
		err = f.Close()
		assert.NoError(t, err)
		assert.Len(t, contents, int(testFileSize))
		err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath3, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath3, testFileName+"1"), testFileSize, client)
		assert.NoError(t, err)
		// vdir1 already contains 2 files, so renaming a new file into it must fail
		err = client.Rename(testFileName, path.Join(vdirPath1, testFileName+".rename")) //nolint:goconst
		assert.Error(t, err)
		// we overwrite an existing file and we have unlimited size
		err = client.Rename(testFileName, path.Join(vdirPath1, testFileName))
		assert.NoError(t, err)
		// we have no space and we try to overwrite a bigger file with a smaller one, this should succeed
		err = client.Rename(testFileName1, path.Join(vdirPath2, testFileName))
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		// we have no space and we try to overwrite a smaller file with a bigger one, this should fail
		err = client.Rename(testFileName, path.Join(vdirPath2, testFileName1))
		assert.Error(t, err)
		fi, err := client.Stat(path.Join(vdirPath1, testFileName1))
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize1, fi.Size())
		}
		// we are over quota inside vdir3: files 2/2 and size 262144/262144
		err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName1+".rename"))
		assert.Error(t, err)
		// we overwrite an existing file and we have enough size
		err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName))
		assert.NoError(t, err)
		testFileName2 := "test_file2.dat"
		err = writeSFTPFile(testFileName2, testFileSize+testFileSize1, client)
		assert.NoError(t, err)
		// we overwrite an existing file and we haven't enough size
		err = client.Rename(testFileName2, path.Join(vdirPath3, testFileName))
		assert.Error(t, err)
		// now remove a file from vdir3, create a dir with 2 files and try to rename it in vdir3
		// this will fail since the rename will result in 3 files inside vdir3 and quota limits only
		// allow 2 total files there
		err = client.Remove(path.Join(vdirPath3, testFileName+"1"))
		assert.NoError(t, err)
		aDir := "a dir"
		err = client.Mkdir(aDir)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(aDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(aDir, testFileName1+"1"), testFileSize1, client)
		assert.NoError(t, err)
		err = client.Rename(aDir, path.Join(vdirPath3, aDir))
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath3)
	assert.NoError(t, err)
}

// TestQuotaRenameOverwrite verifies that after a file is renamed over an
// existing one, the user's quota counts only the surviving file. This is
// checked both when a bigger file overwrites a smaller one and the other way
// round. It also asserts that the used upload/download data transfer counters
// stay at zero.
func TestQuotaRenameOverwrite(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 100
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		testFileSize1 := int64(65537)
		testFileName1 := "test_file1.dat"
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		f, err := client.Open(testFileName)
		assert.NoError(t, err)
		contents := make([]byte, testFileSize)
		n, err := io.ReadFull(f, contents)
		assert.NoError(t, err)
		assert.Equal(t, int(testFileSize), n)
		err = f.Close()
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), user.UsedDownloadDataTransfer)
		assert.Equal(t, int64(0), user.UsedUploadDataTransfer)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		// overwrite the smaller file with the bigger one
		err = client.Rename(testFileName, testFileName1)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), user.UsedDownloadDataTransfer)
		assert.Equal(t, int64(0), user.UsedUploadDataTransfer)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		err = client.Remove(testFileName1)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		// overwrite the bigger file with the smaller one
		err = client.Rename(testFileName1, testFileName)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestVirtualFoldersQuotaValues verifies how quota is assigned for two virtual
// folders:
//   - /vdir1 (QuotaFiles/QuotaSize -1): usage is counted in the user's quota;
//   - /vdir2 (0/0): unlimited, tracked in the folder's own quota.
// Each folder file is uploaded twice to check that an overwrite is not counted
// twice. Removing the files must reset the folder quotas.
func TestVirtualFoldersQuotaValues(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	vdirPath1 := "/vdir1"
folderName1 := filepath.Base(mappedPath1)
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	vdirPath2 := "/vdir2"
	folderName2 := filepath.Base(mappedPath2)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		// we copy the same file two times to test quota update on file overwrite
		err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		// the user counts the root file plus the vdir1 file, vdir2 is tracked separately
		expectedQuotaFiles := 2
		expectedQuotaSize := testFileSize * 2
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		err = client.Remove(path.Join(vdirPath1, testFileName))
		assert.NoError(t, err)
		err = client.Remove(path.Join(vdirPath2, testFileName))
		assert.NoError(t, err)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestQuotaRenameInsideSameVirtualFolder verifies that renaming files and
// directories inside the same virtual folder keeps quotas consistent. It covers
// a folder whose usage is counted in the user's quota (/vdir1, -1/-1) and an
// unlimited folder with its own quota (/vdir2, 0/0). When a rename overwrites an
// existing file, the overwritten file must be removed from the relevant quota.
func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	vdirPath1 := "/vdir1"
	folderName1 := filepath.Base(mappedPath1)
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	vdirPath2 := "/vdir2"
	folderName2 := filepath.Base(mappedPath2)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		dir1 := "dir1" //nolint:goconst
		dir2 := "dir2" //nolint:goconst
		// err was already checked by the enclosing if, no need to assert it again
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		// initial files:
		// - vdir1/dir1/testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		//
		// rename a file inside vdir1, it is included inside user quota, so we have:
		// - vdir1/dir1/testFileName.rename
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath1, dir1, testFileName+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// rename a file inside vdir2, it isn't included inside user quota, so we have:
		// - vdir1/dir1/testFileName.rename
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName.rename
		// - vdir2/dir2/testFileName1
		err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath2, dir1, testFileName+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// rename a file inside vdir2 overwriting an existing file, we now have:
		// - vdir1/dir1/testFileName.rename
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName.rename (initial testFileName1)
		err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// rename a file inside vdir1 overwriting an existing file, we now have:
		// - vdir1/dir1/testFileName.rename (initial testFileName1)
		// - vdir2/dir1/testFileName.rename (initial testFileName1)
		err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath1, dir1, testFileName+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a directory inside the same virtual folder, quota should not change
		err = client.RemoveDirectory(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.RemoveDirectory(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirPath1, dir1), path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirPath2, dir1), path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
// quota is unlimited and excluded from user's one QuotaFiles: 0, QuotaSize: 0, }) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileName1 := "test_file1.dat" testFileSize := int64(131072) testFileSize1 := int64(65535) dir1 := "dir1" dir2 := "dir2" err = client.Mkdir(path.Join(vdirPath1, dir1)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath1, dir2)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath2, dir1)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath2, dir2)) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) assert.NoError(t, err) // initial files: // - vdir1/dir1/testFileName // - vdir1/dir2/testFileName1 // - vdir2/dir1/testFileName // - vdir2/dir2/testFileName1 // // rename a file from vdir1 to vdir2, vdir1 is included inside user quota, so we have: // - vdir1/dir1/testFileName // - vdir2/dir1/testFileName // - vdir2/dir2/testFileName1 // - vdir2/dir1/testFileName1.rename err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName1+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize, user.UsedQuotaSize) f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = 
httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize) assert.Equal(t, 3, f.UsedQuotaFiles) // rename a file from vdir2 to vdir1, vdir2 is not included inside user quota, so we have: // - vdir1/dir1/testFileName // - vdir1/dir2/testFileName.rename // - vdir2/dir2/testFileName1 // - vdir2/dir1/testFileName1.rename err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath1, dir2, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize*2, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1*2, f.UsedQuotaSize) assert.Equal(t, 2, f.UsedQuotaFiles) // rename a file from vdir1 to vdir2 overwriting an existing file, vdir1 is included inside user quota, so we have: // - vdir1/dir2/testFileName.rename // - vdir2/dir2/testFileName1 (is the initial testFileName) // - vdir2/dir1/testFileName1.rename err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath2, dir2, testFileName1)) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize) assert.Equal(t, 2, 
f.UsedQuotaFiles) // rename a file from vdir2 to vdir1 overwriting an existing file, vdir2 is not included inside user quota, so we have: // - vdir1/dir2/testFileName.rename (is the initial testFileName1) // - vdir2/dir2/testFileName1 (is the initial testFileName) err = client.Rename(path.Join(vdirPath2, dir1, testFileName1+".rename"), path.Join(vdirPath1, dir2, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize1, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize, f.UsedQuotaSize) assert.Equal(t, 1, f.UsedQuotaFiles) err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName), testFileSize, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName), testFileSize1, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName+"1.dupl"), testFileSize1, client) assert.NoError(t, err) err = client.RemoveDirectory(path.Join(vdirPath1, dir1)) assert.NoError(t, err) err = client.RemoveDirectory(path.Join(vdirPath2, dir1)) assert.NoError(t, err) // - vdir1/dir2/testFileName.rename (initial testFileName1) // - vdir1/dir2/testFileName // - vdir2/dir2/testFileName1 (initial testFileName) // - vdir2/dir2/testFileName (initial testFileName1) // - vdir2/dir2/testFileName1.dupl // rename directories between the two virtual folders err = client.Rename(path.Join(vdirPath2, dir2), path.Join(vdirPath1, dir1)) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 5, user.UsedQuotaFiles) assert.Equal(t, 
testFileSize1*3+testFileSize*2, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) // now move on vpath2 err = client.Rename(path.Join(vdirPath1, dir2), path.Join(vdirPath2, dir1)) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 3, user.UsedQuotaFiles) assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) assert.Equal(t, 2, f.UsedQuotaFiles) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath1) assert.NoError(t, err) err = os.RemoveAll(mappedPath2) assert.NoError(t, err) } func TestQuotaRenameFromVirtualFolder(t *testing.T) { u := getTestUser() u.QuotaFiles = 100 mappedPath1 := filepath.Join(os.TempDir(), "vdir1") folderName1 := filepath.Base(mappedPath1) vdirPath1 := "/vdir1" mappedPath2 := filepath.Join(os.TempDir(), "vdir2") folderName2 := filepath.Base(mappedPath2) vdirPath2 := "/vdir2" f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, } _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) 
assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		dir1 := "dir1"
		dir2 := "dir2"
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - vdir1/dir1/testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		//
		// rename a file from vdir1 to the user home dir, vdir1 is included in user quota so we have:
		// - testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		// rename a file from vdir2 to the user home dir, vdir2 is not included in user quota so we have:
		// - testFileName
		// - testFileName1
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from vdir1 to the user home dir overwriting an existing file, vdir1 is included in user quota so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1
		// - vdir2/dir1/testFileName
		err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from vdir2 to the user home dir overwriting an existing file, vdir2 is not included in user quota so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// dir rename: populate vdir1/dir1 and vdir2/dir1 with two files each,
		// then move both directories into the user home dir
		err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// after moving vdir2/dir1 to /dir1 we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		// - vdir1/dir1/testFileName
		// - vdir1/dir1/testFileName1
		// - dir1/testFileName
		// - dir1/testFileName1
		// the moved files leave vdir2's quota and are charged to the user
		err = client.Rename(path.Join(vdirPath2, dir1), dir1)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// after moving vdir1/dir1 to /dir2 we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		// - dir2/testFileName
		// - dir2/testFileName1
		// - dir1/testFileName
		// - dir1/testFileName1
		// vdir1 files were already charged to the user, so the counters must not change
		err = client.Rename(path.Join(vdirPath1, dir1), dir2)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestQuotaRenameToVirtualFolder checks quota accounting when files and
// directories are renamed over SFTP from the user's home directory into a
// virtual folder, with and without overwriting existing files. /vdir1 usage is
// charged to the user (QuotaFiles/QuotaSize -1); /vdir2 usage is charged to the
// folder and excluded from the user's quota (QuotaFiles/QuotaSize 0).
func TestQuotaRenameToVirtualFolder(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 :=
"/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	// explicit permission set for /vdir1, including rename and overwrite
	u.Permissions[vdirPath1] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload,
		dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermCreateSymlinks, dataprovider.PermCreateDirs,
		dataprovider.PermRename}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		dir1 := "dir1"
		dir2 := "dir2"
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - testFileName
		// - testFileName1
		//
		// rename a file from user home dir to vdir1, vdir1 is included in user quota so we have:
		// - testFileName
		// - /vdir1/dir1/testFileName1
		err = client.Rename(testFileName1, path.Join(vdirPath1, dir1, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
		// - /vdir2/dir1/testFileName
		// - /vdir1/dir1/testFileName1
		err = client.Rename(testFileName, path.Join(vdirPath2, dir1, testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// upload two new files to the user home dir so we have:
		// - testFileName
		// - testFileName1
		// - /vdir1/dir1/testFileName1
		// - /vdir2/dir1/testFileName
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
		// rename a file from user home dir to vdir1 overwriting an existing file, vdir1 is included in user quota so we have:
		// - testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName
		err = client.Rename(testFileName, path.Join(vdirPath1, dir1, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from user home dir to vdir2 overwriting an existing file, vdir2 is not included in user quota so we have:
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		err = client.Rename(testFileName1, path.Join(vdirPath2, dir1, testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// directory renames: first create /dir1 with two files in the home dir
		err = client.Mkdir(dir1)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// - /dir1/testFileName
		// - /dir1/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// moving /dir1 into vdir1 keeps the files charged to the user:
		// - /vdir1/adir/testFileName
		// - /vdir1/adir/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		err = client.Rename(dir1, path.Join(vdirPath1, "adir"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		err = client.Mkdir(dir1)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// moving a new /dir1 into vdir2 moves its usage from the user to vdir2:
		// - /vdir1/adir/testFileName
		// - /vdir1/adir/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		// - /vdir2/adir/testFileName
		// - /vdir2/adir/testFileName1
		err = client.Rename(dir1, path.Join(vdirPath2, "adir"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 3, f.UsedQuotaFiles)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestTransferQuotaLimits checks the data transfer limits enforced over SFTP.
//
// Phase 1: with TotalDataTransfer set to 1 (presumably MB — TODO confirm the
// unit), one 524288-byte upload plus one full download of the same file
// exhaust the limit. After that, opening the file for reading fails with
// ErrReadQuotaExceeded and a new upload fails with ErrQuotaExceeded.
//
// Phase 2: with separate UploadDataTransfer/DownloadDataTransfer limits of 1,
// a 450000-byte file can be uploaded and downloaded once. The second download
// fails mid-transfer (and its Close reports an error). A second upload fails
// with ErrQuotaExceeded.
func TestTransferQuotaLimits(t *testing.T) {
	u := getTestUser()
	u.TotalDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(524288)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		f, err := client.Open(testFileName)
		// guard the read: if Open failed, f is nil and reading from it would
		// panic and abort the whole test binary instead of failing this test
		if assert.NoError(t, err) {
			contents := make([]byte, testFileSize)
			n, err := io.ReadFull(f, contents)
			assert.NoError(t, err)
			assert.Equal(t, int(testFileSize), n)
			assert.Len(t, contents, int(testFileSize))
			err = f.Close()
			assert.NoError(t, err)
		}
		// the total transfer limit is now exhausted for both directions
		_, err = client.Open(testFileName)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_FAILURE")
			assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error())
		}
		err = writeSFTPFile(testFileName, testFileSize, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_FAILURE")
			assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error())
		}
	}
	// test the limit while uploading/downloading
	user.TotalDataTransfer = 0
	user.UploadDataTransfer = 1
	user.DownloadDataTransfer = 1
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(450000)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		f, err := client.Open(testFileName)
		if assert.NoError(t, err) {
			_, err = io.Copy(io.Discard, f)
			assert.NoError(t, err)
			err = f.Close()
			assert.NoError(t, err)
		}
		// the second download crosses the download limit while transferring
		f, err = client.Open(testFileName)
		if assert.NoError(t, err) {
			_, err = io.Copy(io.Discard, f)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "SSH_FX_FAILURE")
				assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error())
			}
			err = f.Close()
			assert.Error(t, err)
		}
		err = writeSFTPFile(testFileName, testFileSize, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_FAILURE")
			assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error())
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestVirtualFoldersLink checks symlink creation over SFTP for a user with two
// virtual folders. Links whose source and target stay inside the same folder
// (home dir, /vdir1 or /vdir2) are allowed. Links crossing a folder boundary,
// or whose source or target is a virtual folder root, are rejected with
// SSH_FX_OP_UNSUPPORTED. Links involving "/" are rejected with a permission
// error.
func TestVirtualFoldersLink(t *testing.T) {
	u := getTestUser()
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		testFileSize := int64(131072)
		testDir := "adir"

		// one regular file in the home dir and in each virtual folder root
		uploads := []string{
			testFileName,
			path.Join(vdirPath1, testFileName),
			path.Join(vdirPath2, testFileName),
		}
		for _, target := range uploads {
			err = writeSFTPFile(target, testFileSize, client)
			assert.NoError(t, err)
		}
		// plus a sub directory inside each virtual folder
		for _, dirName := range []string{path.Join(vdirPath1, testDir), path.Join(vdirPath2, testDir)} {
			err = client.Mkdir(dirName)
			assert.NoError(t, err)
		}

		// links whose source and target live inside the same folder are allowed
		sameFolderLinks := [][2]string{
			{testFileName, testFileName + ".link"},
			{path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".link")},
			{path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testDir, testFileName+".link")},
			{path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".link")},
			{path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testDir, testFileName+".link")},
		}
		for _, link := range sameFolderLinks {
			err = client.Symlink(link[0], link[1])
			assert.NoError(t, err)
		}

		// expectUnsupported asserts that creating the given link is refused
		// with SSH_FX_OP_UNSUPPORTED
		expectUnsupported := func(oldname, newname string) {
			err = client.Symlink(oldname, newname)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") //nolint:goconst
			}
		}

		// links crossing the boundary between the home dir and a virtual
		// folder, or between the two virtual folders, are refused
		crossFolderLinks := [][2]string{
			{path.Join("/", testFileName), path.Join(vdirPath1, testFileName+".link1")},
			{path.Join("/", testFileName), path.Join(vdirPath1, testDir, testFileName+".link1")},
			{path.Join("/", testFileName), path.Join(vdirPath2, testFileName+".link1")},
			{path.Join("/", testFileName), path.Join(vdirPath2, testDir, testFileName+".link1")},
			{path.Join(vdirPath1, testFileName), testFileName + ".link1"},
			{path.Join(vdirPath2, testFileName), testFileName + ".link1"},
			{path.Join(vdirPath1, testFileName), path.Join(vdirPath2, testDir, testFileName+".link1")},
			{path.Join(vdirPath2, testFileName), path.Join(vdirPath1, testFileName+".link1")},
		}
		for _, link := range crossFolderLinks {
			expectUnsupported(link[0], link[1])
		}

		// the root directory can be neither a link source nor a link target
		err = client.Symlink("/", "/roolink")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Symlink(testFileName, "/")
		assert.ErrorIs(t, err, os.ErrPermission)

		// linking to or from a virtual folder root is not supported
		expectUnsupported(testFileName, vdirPath1)
		expectUnsupported(vdirPath1, testFileName+".link2")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

func TestCrossFolderRename(t *testing.T) {
	folder1 := "folder1"
	folder2 := "folder2"
	folder3 := "folder3"
	folder4 := "folder4"
	folder5 := "folder5"
	folder6 := "folder6"
	folder7 := "folder7"
	baseUser, resp, err :=
httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err, string(resp)) f1 := vfs.BaseVirtualFolder{ Name: folder1, MappedPath: filepath.Join(os.TempDir(), folder1), FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folder2, MappedPath: filepath.Join(os.TempDir(), folder2), FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) f3 := vfs.BaseVirtualFolder{ Name: folder3, MappedPath: filepath.Join(os.TempDir(), folder3), FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(defaultPassword + "mod"), }, }, } _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) assert.NoError(t, err) f4 := vfs.BaseVirtualFolder{ Name: folder4, MappedPath: filepath.Join(os.TempDir(), folder4), FsConfig: vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: baseUser.Username, Prefix: path.Join("/", folder4), }, Password: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f4, http.StatusCreated) assert.NoError(t, err) f5 := vfs.BaseVirtualFolder{ Name: folder5, MappedPath: filepath.Join(os.TempDir(), folder5), FsConfig: vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: baseUser.Username, Prefix: path.Join("/", folder5), }, Password: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f5, http.StatusCreated) assert.NoError(t, err) f6 := 
vfs.BaseVirtualFolder{ Name: folder6, MappedPath: filepath.Join(os.TempDir(), folder6), FsConfig: vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: "127.0.0.1:4024", Username: baseUser.Username, Prefix: path.Join("/", folder6), }, Password: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f6, http.StatusCreated) assert.NoError(t, err) f7 := vfs.BaseVirtualFolder{ Name: folder7, MappedPath: filepath.Join(os.TempDir(), folder7), FsConfig: vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: baseUser.Username, Prefix: path.Join("/", folder4), }, Password: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f7, http.StatusCreated) assert.NoError(t, err) u := getCryptFsUser() u.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder1, }, VirtualPath: path.Join("/", folder1), QuotaSize: -1, QuotaFiles: -1, }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder2, }, VirtualPath: path.Join("/", folder2), QuotaSize: -1, QuotaFiles: -1, }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder3, }, VirtualPath: path.Join("/", folder3), QuotaSize: -1, QuotaFiles: -1, }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder4, }, VirtualPath: path.Join("/", folder4), QuotaSize: -1, QuotaFiles: -1, }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder5, }, VirtualPath: path.Join("/", folder5), QuotaSize: -1, QuotaFiles: -1, }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder6, }, VirtualPath: path.Join("/", folder6), QuotaSize: -1, QuotaFiles: -1, }, { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder7, }, VirtualPath: path.Join("/", folder7), QuotaSize: -1, QuotaFiles: -1, }, } user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, 
string(resp)) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() subDir := "testSubDir" err = client.Mkdir(subDir) assert.NoError(t, err) err = writeSFTPFile(path.Join(subDir, "afile.bin"), 64, client) assert.NoError(t, err) err = client.Rename(subDir, path.Join("/", folder1, subDir)) assert.NoError(t, err) _, err = client.Stat(path.Join("/", folder1, subDir)) assert.NoError(t, err) _, err = client.Stat(path.Join("/", folder1, subDir, "afile.bin")) assert.NoError(t, err) err = client.Rename(path.Join("/", folder1, subDir), path.Join("/", folder2, subDir)) assert.NoError(t, err) _, err = client.Stat(path.Join("/", folder2, subDir)) assert.NoError(t, err) _, err = client.Stat(path.Join("/", folder2, subDir, "afile.bin")) assert.NoError(t, err) err = client.Rename(path.Join("/", folder2, subDir), path.Join("/", folder3, subDir)) assert.ErrorIs(t, err, os.ErrPermission) err = writeSFTPFile(path.Join("/", folder3, "file.bin"), 64, client) assert.NoError(t, err) err = client.Rename(path.Join("/", folder3, "file.bin"), "/renamed.bin") assert.ErrorIs(t, err, os.ErrPermission) err = client.Rename(path.Join("/", folder3, "file.bin"), path.Join("/", folder2, "/renamed.bin")) assert.ErrorIs(t, err, os.ErrPermission) err = client.Rename(path.Join("/", folder3, "file.bin"), path.Join("/", folder3, "/renamed.bin")) assert.NoError(t, err) err = writeSFTPFile("/afile.bin", 64, client) assert.NoError(t, err) err = client.Rename("afile.bin", path.Join("/", folder4, "afile_renamed.bin")) assert.ErrorIs(t, err, os.ErrPermission) err = writeSFTPFile(path.Join("/", folder4, "afile.bin"), 64, client) assert.NoError(t, err) err = client.Rename(path.Join("/", folder4, "afile.bin"), path.Join("/", folder5, "afile_renamed.bin")) assert.NoError(t, err) err = client.Rename(path.Join("/", folder5, "afile_renamed.bin"), path.Join("/", folder6, "afile_renamed.bin")) assert.ErrorIs(t, err, os.ErrPermission) err = writeSFTPFile(path.Join("/", 
folder4, "afile.bin"), 64, client) assert.NoError(t, err) _, err = client.Stat(path.Join("/", folder7, "afile.bin")) assert.NoError(t, err) err = client.Rename(path.Join("/", folder4, "afile.bin"), path.Join("/", folder7, "afile.bin")) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(baseUser.GetHomeDir()) assert.NoError(t, err) for _, folderName := range []string{folder1, folder2, folder3, folder4, folder5, folder6, folder7} { _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(filepath.Join(os.TempDir(), folderName)) assert.NoError(t, err) } } func TestDirs(t *testing.T) { u := getTestUser() mappedPath := filepath.Join(os.TempDir(), "vdir") folderName := filepath.Base(mappedPath) vdirPath := "/path/vdir" f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdirPath, }) u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermCreateDirs, dataprovider.PermRename, dataprovider.PermListItems} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() info, err := client.ReadDir("/") if assert.NoError(t, err) { if assert.Len(t, info, 1) { assert.Equal(t, "path", info[0].Name()) } } fi, err := client.Stat(path.Dir(vdirPath)) if assert.NoError(t, err) { assert.True(t, fi.IsDir()) } err = client.RemoveDirectory("/") assert.ErrorIs(t, err, os.ErrPermission) 
err = client.RemoveDirectory(vdirPath)
		assert.ErrorIs(t, err, os.ErrPermission)
		// "/path" is only the parent of the virtual folder mount point:
		// removing or renaming it, or renaming onto it, is unsupported
		err = client.RemoveDirectory(path.Dir(vdirPath))
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
		err = client.Mkdir(vdirPath)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Mkdir("adir")
		assert.NoError(t, err)
		err = client.Rename("/adir", path.Dir(vdirPath))
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
		// renaming a directory onto the existing "/subdir/adir" directory is unsupported
		err = client.MkdirAll("/subdir/adir")
		assert.NoError(t, err)
		err = client.Rename("adir", "subdir/adir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
		err = writeSFTPFile("/subdir/afile.bin", 64, client)
		assert.NoError(t, err)
		err = writeSFTPFile("/afile.bin", 32, client)
		assert.NoError(t, err)
		// replacing an existing file inside "/subdir" is denied (presumably
		// because its permission list has no overwrite permission), while
		// renaming to a new name there is allowed
		err = client.Rename("afile.bin", "subdir/afile.bin")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename("afile.bin", "subdir/afile1.bin")
		assert.NoError(t, err)
		err = client.Rename(path.Dir(vdirPath), "renamed_vdir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestCryptFsStat checks that, for a crypt fs user, SFTP stat reports the size
// that was written while the file stored on disk is larger.
func TestCryptFsStat(t *testing.T) {
	user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(4096)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		info, err := client.Stat(testFileName)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, info.Size())
		}
		info, err = os.Stat(filepath.Join(user.HomeDir, testFileName))
		if assert.NoError(t, err) {
			assert.Greater(t, info.Size(), testFileSize)
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestFsPermissionErrors checks that local filesystem permission errors on a
// crypt fs user's home are reported to the SFTP client as os.ErrPermission.
// Skipped on Windows.
func TestFsPermissionErrors(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testDir := "tDir"
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		// with an execute-only home, entries can be neither removed nor renamed
		err = os.Chmod(user.GetHomeDir(), 0111)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testDir, testDir+"1")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = os.Chmod(user.GetHomeDir(), os.ModePerm)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRenameErrorOutsideHomeDir uploads with UploadModeAtomicWithResume and a
// temp path outside the user's home, after making the home read-only (0555).
// Closing the uploaded file must fail with os.ErrPermission, presumably when
// the temporary file is moved into the home, and the user's quota must not be
// updated. The global upload mode and temp path are restored at the end.
// Skipped on Windows.
func TestRenameErrorOutsideHomeDir(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	oldUploadMode := common.Config.UploadMode
	oldTempPath := common.Config.TempPath

	common.Config.UploadMode = common.UploadModeAtomicWithResume
	common.Config.TempPath = filepath.Clean(os.TempDir())
	vfs.SetTempPath(common.Config.TempPath)

	u := getTestUser()
	u.QuotaFiles = 1000
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		err = os.Chmod(user.GetHomeDir(), 0555)
		assert.NoError(t, err)

		err = checkBasicSFTP(client)
		assert.NoError(t, err)
		f, err := client.Create(testFileName)
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.ErrorIs(t, err, os.ErrPermission)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, user.UsedQuotaFiles)
		assert.Equal(t, int64(0), user.UsedQuotaSize)
		// Restore the permission bits. os.ModeDir is a file type bit that chmod
		// ignores, so using it here would leave the home directory at mode 0000.
		err = os.Chmod(user.GetHomeDir(), os.ModePerm)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	common.Config.UploadMode = oldUploadMode
	common.Config.TempPath = oldTempPath
	vfs.SetTempPath(oldTempPath)
}

// TestResolvePathError checks that connection filesystem operations fail when
// paths cannot be resolved. The cases are:
//   - a relative home dir;
//   - a virtual folder with a relative mapped path;
//   - creating a symlink whose name is an existing link pointing outside the home.
func TestResolvePathError(t *testing.T) {
	u := getTestUser()
	u.HomeDir = "relative_path"
	conn := common.NewBaseConnection("", common.ProtocolFTP, "", "", u)
	testPath := "apath"
	_, err := conn.ListDir(testPath)
	assert.Error(t, err)
	err = conn.CreateDir(testPath, true)
	assert.Error(t, err)
	err = conn.RemoveDir(testPath)
	assert.Error(t, err)
	err = conn.Rename(testPath, testPath+"1")
	assert.Error(t, err)
	err = conn.CreateSymlink(testPath, testPath+".sym")
	assert.Error(t, err)
	_, err = conn.DoStat(testPath, 0, false)
	assert.Error(t, err)
	err = conn.RemoveAll(testPath)
	assert.Error(t, err)
	err = conn.SetStat(testPath, &common.StatAttributes{
		Atime: time.Now(),
		Mtime: time.Now(),
	})
	assert.Error(t, err)

	// valid home, but the virtual folder has a relative mapped path
	u = getTestUser()
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: "relative_mapped_path",
		},
		VirtualPath: "/vpath",
	})
	err = os.MkdirAll(u.HomeDir, os.ModePerm)
	assert.NoError(t, err)
	conn.User = u
	err = conn.Rename(testPath, "/vpath/subpath")
	assert.Error(t, err)

	// the symlink name already exists as a link to a file outside the home
	outHomePath := filepath.Join(os.TempDir(), testFileName)
	err = os.WriteFile(outHomePath, testFileContent, os.ModePerm)
	assert.NoError(t, err)
	err = os.Symlink(outHomePath, filepath.Join(u.HomeDir, testFileName+".link"))
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(u.HomeDir, testFileName), testFileContent, os.ModePerm)
	assert.NoError(t, err)
	err = conn.CreateSymlink(testFileName, testFileName+".link")
	assert.Error(t, err)
	err = os.RemoveAll(u.GetHomeDir())
	assert.NoError(t, err)
	err = os.Remove(outHomePath)
	assert.NoError(t, err)
}

// TestUserPasswordHashing switches the provider password hashing algorithm to
// argon2id. It checks that an existing "$2a$" (bcrypt) hash is kept and still
// allows login, and that a user created afterwards gets an "$argon2id$" hash.
// The default provider configuration is restored at the end. Not supported
// with the memory provider.
func
TestUserPasswordHashing(t *testing.T) {
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		t.Skip("this test is not supported with the memory provider")
	}
	// create the user with the provider's initial hashing configuration
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// restart the data provider with argon2id as the hashing algorithm
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.PasswordHashing.Algo = dataprovider.HashingAlgoArgon2ID
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	// the hash stored before the switch is unchanged and still accepted for login
	currentUser, err := dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.True(t, strings.HasPrefix(currentUser.Password, "$2a$"))
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// a user created after the switch gets an argon2id hash
	u = getTestUser()
	user, _, err = httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	currentUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.True(t, strings.HasPrefix(currentUser.Password, "$argon2id$"))
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// restore the default provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestAllowList enables the IP allow list (AllowListStatus = 1) and checks
// which IP/protocol combinations IsNewConnectionAllowed accepts. The original
// common.Config is restored and the remaining entries are removed at the end.
func TestAllowList(t *testing.T) {
	configCopy := common.Config

	// Protocols looks like a per-protocol bitmask where 0 means "all protocols".
	// Judging from the assertions below, 1 covers SSH, 4 WebDAV and 8 HTTP.
	// NOTE(review): confirm against the dataprovider protocol definitions.
	entries := []dataprovider.IPListEntry{
		{
			IPOrNet:   "172.18.1.1/32",
			Type:      dataprovider.IPListTypeAllowList,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 0,
		},
		{
			IPOrNet:   "172.18.1.2/32",
			Type:      dataprovider.IPListTypeAllowList,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 0,
		},
		{
			IPOrNet:   "10.8.7.0/24",
			Type:      dataprovider.IPListTypeAllowList,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 5,
		},
		{
			IPOrNet:   "0.0.0.0/0",
			Type:      dataprovider.IPListTypeAllowList,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 8,
		},
		{
			IPOrNet:   "::/0",
			Type:      dataprovider.IPListTypeAllowList,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 8,
		},
	}
	for _, e := range entries {
		_, resp, err := httpdtest.AddIPListEntry(e, http.StatusCreated)
		assert.NoError(t, err, string(resp))
	}
	common.Config.AllowListStatus = 1
	err := common.Initialize(common.Config, 0)
	assert.NoError(t, err)
	assert.True(t, common.Config.IsAllowListEnabled())

	testIP := "172.18.1.1"
	assert.NoError(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolFTP))
	// restrict the first entry: FTP is now rejected, SSH still accepted
	entry := entries[0]
	entry.Protocols = 1
	_, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)
	assert.Error(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolFTP))
	assert.NoError(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolSSH))
	// after removing the entry, testIP is only matched by the catch-all entries
	_, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)
	entries = entries[1:]
	assert.Error(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolSSH))
	assert.Error(t, common.Connections.IsNewConnectionAllowed("172.18.1.3", common.ProtocolSSH))
	assert.NoError(t, common.Connections.IsNewConnectionAllowed("172.18.1.3", common.ProtocolHTTP))
	assert.NoError(t, common.Connections.IsNewConnectionAllowed("10.8.7.3", common.ProtocolWebDAV))
	assert.NoError(t, common.Connections.IsNewConnectionAllowed("10.8.7.4", common.ProtocolSSH))
	assert.Error(t, common.Connections.IsNewConnectionAllowed("10.8.7.4", common.ProtocolFTP))
	assert.NoError(t, common.Connections.IsNewConnectionAllowed("10.8.7.4", common.ProtocolHTTP))
	assert.NoError(t,
		common.Connections.IsNewConnectionAllowed("2001:0db8::1428:57ab", common.ProtocolHTTP))
	assert.Error(t, common.Connections.IsNewConnectionAllowed("2001:0db8::1428:57ab", common.ProtocolSSH))
	assert.Error(t, common.Connections.IsNewConnectionAllowed("10.8.8.2", common.ProtocolWebDAV))
	assert.Error(t, common.Connections.IsNewConnectionAllowed("invalid IP", common.ProtocolHTTP))

	common.Config = configCopy
	err = common.Initialize(common.Config, 0)
	assert.NoError(t, err)
	assert.False(t, common.Config.IsAllowListEnabled())

	for _, e := range entries {
		_, err := httpdtest.RemoveIPListEntry(e, http.StatusOK)
		assert.NoError(t, err)
	}
}

// TestDbDefenderErrors uses the data-provider backed defender. It checks that
// defender queries work while the provider is open and return errors after
// the provider is closed. It then re-initializes the provider, cleans up the
// defender entries and restores common.Config.
func TestDbDefenderErrors(t *testing.T) {
	if !isDbDefenderSupported() {
		t.Skip("this test is not supported with the current database provider")
	}
	configCopy := common.Config
	common.Config.DefenderConfig.Enabled = true
	common.Config.DefenderConfig.Driver = common.DefenderDriverProvider
	err := common.Initialize(common.Config, 0)
	assert.NoError(t, err)

	testIP := "127.1.1.1"
	hosts, err := common.GetDefenderHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 0)
	common.AddDefenderEvent(testIP, common.ProtocolSSH, common.HostEventLimitExceeded)
	hosts, err = common.GetDefenderHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 1)
	// with this configuration one HostEventLimitExceeded event scores 3 and does not ban
	score, err := common.GetDefenderScore(testIP)
	assert.NoError(t, err)
	assert.Equal(t, 3, score)
	banTime, err := common.GetDefenderBanTime(testIP)
	assert.NoError(t, err)
	assert.Nil(t, banTime)

	// every defender query fails while the data provider is closed
	err = dataprovider.Close()
	assert.NoError(t, err)

	common.AddDefenderEvent(testIP, common.ProtocolFTP, common.HostEventLimitExceeded)
	_, err = common.GetDefenderHosts()
	assert.Error(t, err)
	_, err = common.GetDefenderHost(testIP)
	assert.Error(t, err)
	_, err = common.GetDefenderBanTime(testIP)
	assert.Error(t, err)
	_, err = common.GetDefenderScore(testIP)
	assert.Error(t, err)

	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	// the cutoff is one hour in the future, presumably so that every stored
	// defender entry is removed
	err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Hour)))
	assert.NoError(t, err)

	common.Config = configCopy
	err = common.Initialize(common.Config, 0)
	assert.NoError(t, err)
}

// TestDelayedQuotaUpdater restarts the data provider with DelayedQuotaUpdate
// set to 120. It checks that user and folder quota updates made with a false
// last argument are visible through GetUsedQuota / GetUsedVirtualFolderQuota
// but are not yet stored on the user/folder. Updates made with true as the
// last argument (presumably the "reset" flag) are stored immediately. The
// default provider configuration is restored at the end.
func TestDelayedQuotaUpdater(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.DelayedQuotaUpdate = 120
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	u := getTestUser()
	u.QuotaFiles = 100
	u.TotalDataTransfer = 2000
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// delayed updates: reported by GetUsedQuota...
	err = dataprovider.UpdateUserQuota(&user, 10, 6000, false)
	assert.NoError(t, err)
	err = dataprovider.UpdateUserTransferQuota(&user, 100, 200, false)
	assert.NoError(t, err)
	files, size, ulSize, dlSize, err := dataprovider.GetUsedQuota(user.Username)
	assert.NoError(t, err)
	assert.Equal(t, 10, files)
	assert.Equal(t, int64(6000), size)
	assert.Equal(t, int64(100), ulSize)
	assert.Equal(t, int64(200), dlSize)

	// ...but not yet stored on the user
	userGet, err := dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.Equal(t, 0, userGet.UsedQuotaFiles)
	assert.Equal(t, int64(0), userGet.UsedQuotaSize)
	assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer)
	assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer)

	// with the last argument set to true the values are stored immediately
	err = dataprovider.UpdateUserQuota(&user, 10, 6000, true)
	assert.NoError(t, err)
	err = dataprovider.UpdateUserTransferQuota(&user, 100, 200, true)
	assert.NoError(t, err)
	files, size, ulSize, dlSize, err = dataprovider.GetUsedQuota(user.Username)
	assert.NoError(t, err)
	assert.Equal(t, 10, files)
	assert.Equal(t, int64(6000), size)
	assert.Equal(t, int64(100), ulSize)
	assert.Equal(t, int64(200), dlSize)

	userGet, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.Equal(t, 10, userGet.UsedQuotaFiles)
	assert.Equal(t, int64(6000), userGet.UsedQuotaSize)
	assert.Equal(t,
int64(100), userGet.UsedUploadDataTransfer)
	assert.Equal(t, int64(200), userGet.UsedDownloadDataTransfer)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	// same checks for virtual folder quota
	folder := vfs.BaseVirtualFolder{
		Name:       "folder",
		MappedPath: filepath.Join(os.TempDir(), "p"),
	}
	err = dataprovider.AddFolder(&folder, "", "", "")
	assert.NoError(t, err)
	err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 6000, false)
	assert.NoError(t, err)
	files, size, err = dataprovider.GetUsedVirtualFolderQuota(folder.Name)
	assert.NoError(t, err)
	assert.Equal(t, 10, files)
	assert.Equal(t, int64(6000), size)
	folderGet, err := dataprovider.GetFolderByName(folder.Name)
	assert.NoError(t, err)
	assert.Equal(t, 0, folderGet.UsedQuotaFiles)
	assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
	err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 6000, true)
	assert.NoError(t, err)
	files, size, err = dataprovider.GetUsedVirtualFolderQuota(folder.Name)
	assert.NoError(t, err)
	assert.Equal(t, 10, files)
	assert.Equal(t, int64(6000), size)
	folderGet, err = dataprovider.GetFolderByName(folder.Name)
	assert.NoError(t, err)
	assert.Equal(t, 10, folderGet.UsedQuotaFiles)
	assert.Equal(t, int64(6000), folderGet.UsedQuotaSize)
	err = dataprovider.DeleteFolder(folder.Name, "", "", "")
	assert.NoError(t, err)

	// restore the default provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestPasswordCaching checks the password cache queried through
// dataprovider.CheckCachedUserPassword, which returns (found, match):
//   - nothing is cached before a successful login, even after a failed one;
//   - after a login the cache matches only the used username/password pair;
//   - a user update that keeps the stored hash keeps the cache entry;
//   - changing the stored hash (same or new password, via UpdateUser or
//     UpdateUserPassword) invalidates it;
//   - removing the user removes the entry.
func TestPasswordCaching(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	dbUser, err := dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	// nothing cached before the first login
	found, match := dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)

	// a failed login does not populate the cache
	user.Password = "wrong"
	_, _, err = getSftpClient(user)
	assert.Error(t, err)
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)

	// presumably getSftpClient falls back to defaultPassword when the password is empty
	user.Password = ""
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.True(t, found)
	assert.True(t, match)
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword+"_", dbUser.Password)
	assert.True(t, found)
	assert.False(t, match)
	found, match = dataprovider.CheckCachedUserPassword(user.Username+"_", defaultPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)

	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// the password was not changed
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.True(t, found)
	assert.True(t, match)

	// the password hash will change
	user.Password = defaultPassword
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)

	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.True(t, found)
	assert.True(t, match)

	// change password
	newPassword := defaultPassword + "mod"
	user.Password = newPassword
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t,
err)
	found, match = dataprovider.CheckCachedUserPassword(user.Username, newPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.True(t, found)
	assert.False(t, match)
	found, match = dataprovider.CheckCachedUserPassword(user.Username, newPassword, dbUser.Password)
	assert.True(t, found)
	assert.True(t, match)
	// update the password
	err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "")
	assert.NoError(t, err)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	// the stored hash does not match
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)
	user.Password = defaultPassword
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.True(t, found)
	assert.True(t, match)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password)
	assert.False(t, found)
	assert.False(t, match)
}

// TestEventRule exercises filesystem and provider event rules end to end:
//   - "test rule1" runs on successful uploads matching "/subdir/*.dat" or
//     "/**/*.txt". Its actions are: action1, a synchronous, stop-on-failure
//     command action running the upload script, which is expected to move the
//     uploaded file to movedPath; action2, a backup; action3, a notification
//     email; and action4, an email sent only on failure.
//   - "test rule2" sends the action3 email for downloads matching "/**/*.dat",
//     with action4 as its failure action.
//   - "test rule3" sends the action3 email on provider "delete" events.
//
// Emails are captured by the test SMTP server into lastReceivedEmail, which is
// reset before each operation whose email is asserted. Skipped on Windows.
func TestEventRule(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)

	// action1 starts as an HTTP action and is turned into a command action below
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint: "http://localhost",
				Timeout:  20,
				Method:   http.MethodGet,
			},
		},
	}
	a2 := dataprovider.BaseEventAction{
		Name: "action2",
		Type: dataprovider.ActionTypeBackup,
	}
	a3 := dataprovider.BaseEventAction{
		Name: "action3",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"test1@example.com", "test2@example.com"},
				Bcc:        []string{"test3@example.com"},
				Subject:    `New "{{.Event}}" from "{{.Name}}" status {{.StatusString}}`,
				Body:       "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} Data: {{.ObjectData}} {{.ErrorString}}",
			},
		},
	}
	a4 := dataprovider.BaseEventAction{
		Name: "action4",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"failure@example.com"},
				Subject:    `Failed "{{.Event}}" from "{{.Name}}"`,
				Body:       "Fs path {{.FsPath}}, protocol: {{.Protocol}}, IP: {{.IP}} {{.ErrorString}}",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err)
	action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated)
	assert.NoError(t, err)
	action4, _, err := httpdtest.AddEventAction(a4, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "test rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
			Options: dataprovider.ConditionOptions{
				EventStatuses: []int{1},
				FsPaths: []dataprovider.ConditionPattern{
					{
						Pattern: "/subdir/*.dat",
					},
					{
						Pattern: "/**/*.txt",
					},
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync:   true,
					StopOnFailure: true,
				},
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 2,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action3.Name,
				},
				Order: 3,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action4.Name,
				},
				Order: 4,
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	r2 := dataprovider.EventRule{
		Name:    "test rule2",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"download"},
			Options: dataprovider.ConditionOptions{
				FsPaths: []dataprovider.ConditionPattern{
					{
						Pattern: "/**/*.dat",
					},
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action3.Name,
				},
				Order: 1,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action4.Name,
				},
				Order: 2,
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
			},
		},
	}
	rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated)
	assert.NoError(t, err)
	r3 := dataprovider.EventRule{
		Name:    "test rule3",
		Status:  1,
		Trigger: dataprovider.EventTriggerProviderEvent,
		Conditions: dataprovider.EventConditions{
			ProviderEvents: []string{"delete"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action3.Name,
				},
				Order: 1,
			},
		},
	}
	rule3, _, err := httpdtest.AddEventRule(r3, http.StatusCreated)
	assert.NoError(t, err)

	uploadScriptPath := filepath.Join(os.TempDir(), "upload.sh")
	u := getTestUser()
	u.DownloadDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	movedFileName := "moved.dat"
	movedPath := filepath.Join(user.HomeDir, movedFileName)
	err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 0), 0755)
	assert.NoError(t, err)

	dataprovider.EnabledActionCommands = []string{uploadScriptPath}
	defer func() {
		dataprovider.EnabledActionCommands = nil
	}()

	action1.Type = dataprovider.ActionTypeCommand
	action1.Options = dataprovider.BaseEventActionOptions{
		CmdConfig: dataprovider.EventActionCommandConfig{
			Cmd:     uploadScriptPath,
			Timeout: 10,
			EnvVars: []dataprovider.KeyValue{
				{
					Key:   "SFTPGO_ACTION_PATH",
					Value: "{{.FsPath}}",
				},
				{
					Key:   "CUSTOM_ENV_VAR",
					Value: "value",
				},
			},
		},
	}
	action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	dirName := "subdir"
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		size := int64(32768)
		// rule conditions do not match
		err = writeSFTPFileNoCheck(testFileName, size, client)
		assert.NoError(t, err)
		info, err := client.Stat(testFileName)
		if assert.NoError(t, err) {
			assert.Equal(t, size, info.Size())
		}
		err = client.Mkdir(dirName)
		assert.NoError(t, err)
		err = client.Mkdir("subdir1")
		assert.NoError(t, err)
		// rule conditions match: the upload script moves the file to movedPath
		lastReceivedEmail.reset()
		err = writeSFTPFileNoCheck(path.Join(dirName, testFileName), size, client)
		assert.NoError(t, err)
		_, err = client.Stat(path.Join(dirName, testFileName))
		assert.Error(t, err)
		info, err = client.Stat(movedFileName)
		if assert.NoError(t, err) {
			assert.Equal(t, size, info.Size())
		}
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3000*time.Millisecond, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 3)
		assert.True(t, slices.Contains(email.To, "test1@example.com"))
		assert.True(t, slices.Contains(email.To, "test2@example.com"))
		assert.True(t, slices.Contains(email.To, "test3@example.com"))
		assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "upload" from "%s" status OK`, user.Username))

		// test the failure action, we download a file that exceeds the transfer quota limit
		err = writeSFTPFileNoCheck(path.Join("subdir1", testFileName), 1*1024*1024+65535, client)
		assert.NoError(t, err)
		lastReceivedEmail.reset()
		f, err := client.Open(path.Join("subdir1", testFileName))
		assert.NoError(t, err)
		_, err = io.ReadAll(f)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error())
		}
		err = f.Close()
		assert.Error(t, err)
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3000*time.Millisecond, 100*time.Millisecond)
		email = lastReceivedEmail.get()
		assert.Len(t, email.To, 3)
		assert.True(t, slices.Contains(email.To, "test1@example.com"))
		assert.True(t, slices.Contains(email.To, "test2@example.com"))
		assert.True(t, slices.Contains(email.To, "test3@example.com"))
		assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s" status KO`, user.Username))
		assert.Contains(t, email.Data, `"download" failed`)
		assert.Contains(t, email.Data, common.ErrReadQuotaExceeded.Error())
		_, err = httpdtest.UpdateTransferQuotaUsage(user, "", http.StatusOK)
		assert.NoError(t, err)

		// remove the upload script to test the failure action
		err = os.Remove(uploadScriptPath)
		assert.NoError(t, err)
		lastReceivedEmail.reset()
		err = writeSFTPFileNoCheck(path.Join(dirName, testFileName), size, client)
		assert.Error(t, err)
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3000*time.Millisecond, 100*time.Millisecond)
		email = lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.True(t, slices.Contains(email.To, "failure@example.com"))
		assert.Contains(t, email.Data, fmt.Sprintf(`Subject: Failed "upload" from "%s"`, user.Username))
		assert.Contains(t, email.Data, fmt.Sprintf(`action %q failed`, action1.Name))

		// now test the download rule
		lastReceivedEmail.reset()
		f, err = client.Open(movedFileName)
		assert.NoError(t, err)
		contents, err := io.ReadAll(f)
		assert.NoError(t, err)
		err = f.Close()
		assert.NoError(t, err)
		assert.Len(t, contents, int(size))
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3000*time.Millisecond, 100*time.Millisecond)
		email = lastReceivedEmail.get()
		assert.Len(t, email.To, 3)
		assert.True(t, slices.Contains(email.To, "test1@example.com"))
		assert.True(t, slices.Contains(email.To, "test2@example.com"))
		assert.True(t, slices.Contains(email.To, "test3@example.com"))
		assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s"`, user.Username))
	}
	// test upload action command with arguments
	action1.Options.CmdConfig.Args = []string{"{{.Event}}", "{{.VirtualPath}}", "custom_arg"}
	action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	uploadLogFilePath := filepath.Join(os.TempDir(), "upload.log")
	err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, uploadLogFilePath, 0), 0755)
	assert.NoError(t, err)

	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		// Reset before triggering the upload. action3 is not marked ExecuteSync,
		// so its email may arrive while the log file is being checked; resetting
		// afterwards could discard it and make the Eventually below time out.
		lastReceivedEmail.reset()
		err = writeSFTPFileNoCheck(path.Join(dirName, testFileName), 123, client)
		assert.NoError(t, err)
		logContent, err := os.ReadFile(uploadLogFilePath)
		assert.NoError(t, err)
		assert.Equal(t, fmt.Sprintf("upload %s custom_arg", util.CleanPath(path.Join(dirName, testFileName))),
			strings.TrimSpace(string(logContent)))
		err = os.Remove(uploadLogFilePath)
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3000*time.Millisecond, 100*time.Millisecond)
	}

	// removing the rules triggers rule3 ("delete" provider event)
	lastReceivedEmail.reset()
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule2, http.StatusOK)
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return lastReceivedEmail.get().From != ""
	}, 3000*time.Millisecond, 100*time.Millisecond)
	email := lastReceivedEmail.get()
	assert.Len(t, email.To, 3)
	assert.True(t, slices.Contains(email.To, "test1@example.com"))
	assert.True(t, slices.Contains(email.To, "test2@example.com"))
	assert.True(t, slices.Contains(email.To, "test3@example.com"))
	assert.Contains(t, email.Data, `Subject: New "delete" from "admin"`)

	_, err = httpdtest.RemoveEventRule(rule3, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action3, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action4, http.StatusOK)
	assert.NoError(t, err)
	lastReceivedEmail.reset()
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the upload script was written again for the second session and would
	// otherwise be left behind in the temp dir
	err = os.Remove(uploadScriptPath)
	assert.NoError(t, err)

	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

func TestEventRuleStatues(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"test6@example.com"},
				Subject:    `New "{{.Event}}" error`,
				Body:       "{{.ErrorString}}",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	r := dataprovider.EventRule{
		Name:    "rule",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
			Options: dataprovider.ConditionOptions{
				EventStatuses: []int{3},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
		},
	}
	rule, resp, err := httpdtest.AddEventRule(r, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestUser()
	u.UploadDataTransfer = 1
	u.DownloadDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
testFileSize := int64(999999) err = writeSFTPFile(testFileName, testFileSize, client) assert.NoError(t, err) f, err := client.Open(testFileName) assert.NoError(t, err) contents := make([]byte, testFileSize) n, err := io.ReadFull(f, contents) assert.NoError(t, err) assert.Equal(t, int(testFileSize), n) assert.Len(t, contents, int(testFileSize)) err = f.Close() assert.NoError(t, err) lastReceivedEmail.reset() assert.Eventually(t, func() bool { return lastReceivedEmail.get().From == "" }, 600*time.Millisecond, 500*time.Millisecond) err = writeSFTPFile(testFileName, testFileSize, client) assert.Error(t, err) lastReceivedEmail.reset() assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test6@example.com")) assert.Contains(t, email.Data, `Subject: New "upload" error`) assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error()) } _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } func TestEventRuleDisabledCommand(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh") outPath := filepath.Join(os.TempDir(), "provider_out.json") err = os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755) assert.NoError(t, err) a1 := 
dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeCommand, Options: dataprovider.BaseEventActionOptions{ CmdConfig: dataprovider.EventActionCommandConfig{ Cmd: saveObjectScriptPath, Timeout: 10, EnvVars: []dataprovider.KeyValue{ { Key: "SFTPGO_OBJECT_DATA", Value: "{{.ObjectData}}", }, }, }, }, } a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test3@example.com"}, Subject: `New "{{.Event}}" from "{{.Name}}"`, Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}} Data: {{.ObjectData}}", }, }, } a3 := dataprovider.BaseEventAction{ Name: "a3", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"failure@example.com"}, Subject: `Failed "{{.Event}}" from "{{.Name}}"`, Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}}, IP: {{.IP}}", }, }, } _, _, err = httpdtest.AddEventAction(a1, http.StatusBadRequest) assert.NoError(t, err) // Enable the command to allow saving dataprovider.EnabledActionCommands = []string{a1.Options.CmdConfig.Cmd} action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) assert.NoError(t, err) r := dataprovider.EventRule{ Name: "rule", Status: 1, Trigger: dataprovider.EventTriggerProviderEvent, Conditions: dataprovider.EventConditions{ ProviderEvents: []string{"add"}, Options: dataprovider.ConditionOptions{ ProviderObjects: []string{"folder"}, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ StopOnFailure: true, }, }, { BaseEventAction: 
dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action3.Name, }, Order: 3, Options: dataprovider.EventActionOptions{ IsFailureAction: true, StopOnFailure: true, }, }, }, } rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) assert.NoError(t, err) // restrict command execution dataprovider.EnabledActionCommands = nil lastReceivedEmail.reset() // create a folder to trigger the rule folder := vfs.BaseVirtualFolder{ Name: "ftest failed command", MappedPath: filepath.Join(os.TempDir(), "p"), } folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated) assert.NoError(t, err) assert.NoFileExists(t, outPath) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "failure@example.com")) assert.Contains(t, email.Data, `Subject: Failed "add" from "admin"`) assert.Contains(t, email.Data, fmt.Sprintf("Object name: %s object type: folder", folder.Name)) lastReceivedEmail.reset() _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) assert.NoError(t, err) } func TestEventRuleProviderEvents(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh") outPath := filepath.Join(os.TempDir(), "provider_out.json") err = 
os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755) assert.NoError(t, err) dataprovider.EnabledActionCommands = []string{saveObjectScriptPath} defer func() { dataprovider.EnabledActionCommands = nil }() a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeCommand, Options: dataprovider.BaseEventActionOptions{ CmdConfig: dataprovider.EventActionCommandConfig{ Cmd: saveObjectScriptPath, Timeout: 10, EnvVars: []dataprovider.KeyValue{ { Key: "SFTPGO_OBJECT_DATA", Value: "{{.ObjectData}}", }, }, }, }, } a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test3@example.com"}, Subject: `New "{{.Event}}" from "{{.Name}}"`, Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}} Data: {{.ObjectData}}", }, }, } a3 := dataprovider.BaseEventAction{ Name: "a3", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"failure@example.com"}, Subject: `Failed "{{.Event}}" from "{{.Name}}"`, Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}}, IP: {{.IP}}", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) assert.NoError(t, err) r := dataprovider.EventRule{ Name: "rule", Status: 1, Trigger: dataprovider.EventTriggerProviderEvent, Conditions: dataprovider.EventConditions{ ProviderEvents: []string{"update"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ StopOnFailure: true, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: 
action2.Name, }, Order: 2, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action3.Name, }, Order: 3, Options: dataprovider.EventActionOptions{ IsFailureAction: true, StopOnFailure: true, }, }, }, } rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) assert.NoError(t, err) lastReceivedEmail.reset() // create and update a folder to trigger the rule folder := vfs.BaseVirtualFolder{ Name: "ftest rule", MappedPath: filepath.Join(os.TempDir(), "p"), } folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated) assert.NoError(t, err) // no action is triggered on add assert.NoFileExists(t, outPath) // update the folder _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) assert.NoError(t, err) if assert.Eventually(t, func() bool { _, err := os.Stat(outPath) return err == nil }, 2*time.Second, 100*time.Millisecond) { content, err := os.ReadFile(outPath) assert.NoError(t, err) var folderGet vfs.BaseVirtualFolder err = json.Unmarshal(content, &folderGet) assert.NoError(t, err) assert.Equal(t, folder, folderGet) err = os.Remove(outPath) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test3@example.com")) assert.Contains(t, email.Data, `Subject: New "update" from "admin"`) } // now delete the script to generate an error lastReceivedEmail.reset() err = os.Remove(saveObjectScriptPath) assert.NoError(t, err) _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) assert.NoError(t, err) assert.NoFileExists(t, outPath) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "failure@example.com")) assert.Contains(t, email.Data, `Subject: Failed "update" from "admin"`) 
assert.Contains(t, email.Data, fmt.Sprintf("Object name: %s object type: folder", folder.Name)) lastReceivedEmail.reset() // generate an error for the failure action smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) assert.NoError(t, err) assert.NoFileExists(t, outPath) email = lastReceivedEmail.get() assert.Len(t, email.To, 0) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) assert.NoError(t, err) } func TestEventRuleFsActions(t *testing.T) { dirsToCreate := []string{ "/basedir/1", "/basedir/sub/2", "/basedir/3", } a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionMkdirs, MkDirs: dirsToCreate, }, }, } a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionRename, Renames: []dataprovider.RenameConfig{ { KeyValue: dataprovider.KeyValue{ Key: "/{{.VirtualDirPath}}/{{.ObjectName}}", Value: "/{{.ObjectName}}_renamed", }, }, }, }, }, } a3 := dataprovider.BaseEventAction{ Name: "a3", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionDelete, Deletes: []string{"/{{.ObjectName}}_renamed"}, }, }, } a4 := dataprovider.BaseEventAction{ Name: "a4", Type: dataprovider.ActionTypeFolderQuotaReset, } 
a5 := dataprovider.BaseEventAction{ Name: "a5", Type: dataprovider.ActionTypeUserQuotaReset, } a6 := dataprovider.BaseEventAction{ Name: "a6", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionExist, Exist: []string{"/{{.VirtualPath}}"}, }, }, } action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err, string(resp)) action2, resp, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err, string(resp)) action3, resp, err := httpdtest.AddEventAction(a3, http.StatusCreated) assert.NoError(t, err, string(resp)) action4, resp, err := httpdtest.AddEventAction(a4, http.StatusCreated) assert.NoError(t, err, string(resp)) action5, resp, err := httpdtest.AddEventAction(a5, http.StatusCreated) assert.NoError(t, err, string(resp)) action6, resp, err := httpdtest.AddEventAction(a6, http.StatusCreated) assert.NoError(t, err, string(resp)) r1 := dataprovider.EventRule{ Name: "r1", Status: 1, Trigger: dataprovider.EventTriggerProviderEvent, Conditions: dataprovider.EventConditions{ ProviderEvents: []string{"add"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } r2 := dataprovider.EventRule{ Name: "r2", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"upload"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ ExecuteSync: true, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action5.Name, }, Order: 2, }, }, } r3 := dataprovider.EventRule{ Name: "r3", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"mkdir"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: 
dataprovider.BaseEventAction{ Name: action3.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action6.Name, }, Order: 2, }, }, } r4 := dataprovider.EventRule{ Name: "r4", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"rmdir"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action4.Name, }, Order: 1, }, }, } r5 := dataprovider.EventRule{ Name: "r5", Status: 1, Trigger: dataprovider.EventTriggerProviderEvent, Conditions: dataprovider.EventConditions{ ProviderEvents: []string{"add"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action4.Name, }, Order: 1, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err) rule3, _, err := httpdtest.AddEventRule(r3, http.StatusCreated) assert.NoError(t, err) rule4, _, err := httpdtest.AddEventRule(r4, http.StatusCreated) assert.NoError(t, err) rule5, _, err := httpdtest.AddEventRule(r5, http.StatusCreated) assert.NoError(t, err) folderMappedPath := filepath.Join(os.TempDir(), "folder") err = os.MkdirAll(folderMappedPath, os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(folderMappedPath, "file.txt"), []byte("1"), 0666) assert.NoError(t, err) folder, _, err := httpdtest.AddFolder(vfs.BaseVirtualFolder{ Name: "test folder", MappedPath: folderMappedPath, }, http.StatusCreated) assert.NoError(t, err) assert.Eventually(t, func() bool { folderGet, _, err := httpdtest.GetFolderByName(folder.Name, http.StatusOK) if err != nil { return false } return folderGet.UsedQuotaFiles == 1 && folderGet.UsedQuotaSize == 1 }, 2*time.Second, 100*time.Millisecond) u := getTestUser() u.Filters.DisableFsChecks = true user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := 
getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() // check initial directories creation for _, dir := range dirsToCreate { assert.Eventually(t, func() bool { _, err := client.Stat(dir) return err == nil }, 2*time.Second, 100*time.Millisecond) } // upload a file and check the sync rename size := int64(32768) err = writeSFTPFileNoCheck(path.Join("basedir", testFileName), size, client) assert.NoError(t, err) _, err = client.Stat(path.Join("basedir", testFileName)) assert.Error(t, err) info, err := client.Stat(testFileName + "_renamed") //nolint:goconst if assert.NoError(t, err) { assert.Equal(t, size, info.Size()) } assert.NoError(t, err) assert.Eventually(t, func() bool { userGet, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK) if err != nil { return false } return userGet.UsedQuotaFiles == 1 && userGet.UsedQuotaSize == size }, 2*time.Second, 100*time.Millisecond) for i := 0; i < 2; i++ { err = client.Mkdir(testFileName) assert.NoError(t, err) assert.Eventually(t, func() bool { _, err = client.Stat(testFileName + "_renamed") return err != nil }, 2*time.Second, 100*time.Millisecond) err = client.RemoveDirectory(testFileName) assert.NoError(t, err) } err = client.Mkdir(testFileName + "_renamed") assert.NoError(t, err) err = client.Mkdir(testFileName) assert.NoError(t, err) assert.Eventually(t, func() bool { _, err = client.Stat(testFileName + "_renamed") return err != nil }, 2*time.Second, 100*time.Millisecond) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(folderMappedPath) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule3, http.StatusOK) assert.NoError(t, 
err) _, err = httpdtest.RemoveEventRule(rule4, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule5, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action4, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action5, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action6, http.StatusOK) assert.NoError(t, err) } func TestEventActionObjectBaseName(t *testing.T) { a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionRename, Renames: []dataprovider.RenameConfig{ { KeyValue: dataprovider.KeyValue{ Key: "/{{.VirtualDirPath}}/{{.ObjectName}}", Value: "/{{.ObjectBaseName}}", }, }, }, }, }, } action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err, string(resp)) r1 := dataprovider.EventRule{ Name: "r2", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"upload"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ ExecuteSync: true, }, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testDir := "test dir name" err = client.Mkdir(testDir) fileSize := int64(32768) assert.NoError(t, err) err = 
writeSFTPFileNoCheck(path.Join(testDir, testFileName), fileSize, client) assert.NoError(t, err) _, err = client.Stat(path.Join(testDir, testFileName)) assert.ErrorIs(t, err, os.ErrNotExist) _, err = client.Stat(strings.TrimSuffix(testFileName, path.Ext(testFileName))) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) } func TestUploadEventRule(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test1@example.com"}, Subject: `New "{{.Event}}" from "{{.Name}}" status {{.StatusString}}`, Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} Data: {{.ObjectData}} {{.ErrorString}}", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test rule1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"upload"}, Options: dataprovider.ConditionOptions{ FsPaths: []dataprovider.ConditionPattern{ { Pattern: "/**/*.filepart", InverseMatch: true, }, }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ ExecuteSync: true, }, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), 
http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() lastReceivedEmail.reset() err = writeSFTPFileNoCheck("/test.filepart", 32768, client) assert.NoError(t, err) email := lastReceivedEmail.get() assert.Empty(t, email.From) lastReceivedEmail.reset() err = writeSFTPFileNoCheck(testFileName, 32768, client) assert.NoError(t, err) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.Contains(t, email.Data, `Subject: New "upload"`) } r2 := dataprovider.EventRule{ Name: "test rule2", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"rename"}, Options: dataprovider.ConditionOptions{ FsPaths: []dataprovider.ConditionPattern{ { Pattern: "/**/*.filepart", }, }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() tempName := "file.filepart" lastReceivedEmail.reset() err = writeSFTPFileNoCheck(tempName, 32768, client) assert.NoError(t, err) email := lastReceivedEmail.get() assert.Empty(t, email.From) lastReceivedEmail.reset() err = client.Rename(tempName, testFileName) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.Contains(t, email.Data, `Subject: New "rename"`) } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) 
assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } func TestEventRulePreDelete(t *testing.T) { movePath := "recycle bin" a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionRename, Renames: []dataprovider.RenameConfig{ { KeyValue: dataprovider.KeyValue{ Key: "/{{.VirtualPath}}", Value: fmt.Sprintf("/%s/{{.VirtualPath}}", movePath), }, UpdateModTime: true, }, }, }, }, } action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err, string(resp)) r1 := dataprovider.EventRule{ Name: "rule1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"pre-delete"}, Options: dataprovider.ConditionOptions{ FsPaths: []dataprovider.ConditionPattern{ { Pattern: fmt.Sprintf("/%s/**", movePath), InverseMatch: true, }, }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ ExecuteSync: true, }, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) f := vfs.BaseVirtualFolder{ Name: movePath, MappedPath: filepath.Join(os.TempDir(), movePath), } _, _, err = httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) u := getTestUser() u.QuotaFiles = 1000 u.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: movePath, }, VirtualPath: "/" + movePath, QuotaFiles: 1000, }, } localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser() u.QuotaFiles = 1000 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range 
[]dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testDir := "sub dir" err = client.MkdirAll(testDir) assert.NoError(t, err) err = writeSFTPFile(testFileName, 100, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(testDir, testFileName), 100, client) assert.NoError(t, err) modTime := time.Now().Add(-36 * time.Hour) err = client.Chtimes(testFileName, modTime, modTime) assert.NoError(t, err) err = client.Remove(testFileName) assert.NoError(t, err) err = client.Remove(path.Join(testDir, testFileName)) assert.NoError(t, err) // check files _, err = client.Stat(testFileName) assert.ErrorIs(t, err, os.ErrNotExist) _, err = client.Stat(path.Join(testDir, testFileName)) assert.ErrorIs(t, err, os.ErrNotExist) info, err := client.Stat(path.Join("/", movePath, testFileName)) assert.NoError(t, err) diff := math.Abs(time.Until(info.ModTime()).Seconds()) assert.LessOrEqual(t, diff, float64(2)) _, err = client.Stat(path.Join("/", movePath, testDir, testFileName)) assert.NoError(t, err) // check quota user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if user.Username == localUser.Username { assert.Equal(t, 0, user.UsedQuotaFiles) assert.Equal(t, int64(0), user.UsedQuotaSize) folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, folder.UsedQuotaFiles) assert.Equal(t, int64(200), folder.UsedQuotaSize) } else { assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(100), user.UsedQuotaSize) } // pre-delete action is not executed in movePath err = client.Remove(path.Join("/", movePath, testFileName)) assert.NoError(t, err) if user.Username == localUser.Username { // check quota folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, folder.UsedQuotaFiles) assert.Equal(t, int64(100), folder.UsedQuotaSize) err = 
os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
			}
		}
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: movePath}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Join(os.TempDir(), movePath))
	assert.NoError(t, err)
}

// TestEventRulePreDownloadUpload checks a rule bound to the "pre-download" and
// "pre-upload" events whose action is executed synchronously:
//   - while the action (mkdirs testDir) succeeds, uploads and downloads work
//     and testDir gets created;
//   - once the rule is disabled, an upload no longer creates testDir;
//   - once the rule runs an action that fails (renaming a missing path),
//     opening a file for download and uploading are both rejected with a
//     permission error, while Remove still succeeds.
func TestEventRulePreDownloadUpload(t *testing.T) {
	testDir := "/d"
	// a1 only creates testDir, so it is expected to succeed
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type:   dataprovider.FilesystemActionMkdirs,
				MkDirs: []string{testDir},
			},
		},
	}
	action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	// a2 renames a source path that does not exist, so it is expected to fail
	a2 := dataprovider.BaseEventAction{
		Name: "a2",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionRename,
				Renames: []dataprovider.RenameConfig{
					{
						KeyValue: dataprovider.KeyValue{
							Key:   "/missing source",
							Value: "/missing target",
						},
					},
				},
			},
		},
	}
	action2, resp, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"pre-download", "pre-upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// the rule will always succeed, so uploads/downloads will work
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		_, err = client.Stat(testDir)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.NoError(t, err)
		// opening the file for download runs the "pre-download" action,
		// which recreates testDir
		f, err := client.Open(testFileName)
		assert.NoError(t, err)
		contents := make([]byte, 100)
		n, err := io.ReadFull(f, contents)
		assert.NoError(t, err)
		assert.Equal(t, int(100), n)
		err = f.Close()
		assert.NoError(t, err)
		// disable the rule
		rule1.Status = 0
		_, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK)
		assert.NoError(t, err)
		// testDir must exist again here (recreated by the download above)
		err = client.RemoveDirectory(testDir)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		// with the rule disabled the upload does not create testDir
		_, err = client.Stat(testDir)
		assert.ErrorIs(t, err, fs.ErrNotExist)
		// now update the rule so that it will always fail
		rule1.Status = 1
		rule1.Actions = []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		}
		_, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK)
		assert.NoError(t, err)
		// both transfer directions are denied, other operations are not affected
		_, err = client.Open(testFileName)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, 100, client)
		assert.ErrorIs(t, err, os.ErrPermission)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestEventActionCommandEnvVars checks a command action whose EnvVars entry
// maps envName to "$". With the action executed synchronously on upload, the
// upload fails while envName is not set in the test process environment and
// succeeds once it is set.
// NOTE(review): presumably "$" means "take the value from the SFTPGo process
// environment", and the script produced by getUploadScriptEnvContent exits
// with an error when envName is unset — confirm against those definitions.
func TestEventActionCommandEnvVars(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	envName := "MY_ENV"
	uploadScriptPath := filepath.Join(os.TempDir(), "upload.sh")
	// register the script in EnabledActionCommands for this test, reset on exit
	dataprovider.EnabledActionCommands = []string{uploadScriptPath}
	defer func() {
		dataprovider.EnabledActionCommands = nil
	}()
	err := os.WriteFile(uploadScriptPath, getUploadScriptEnvContent(envName), 0755)
	assert.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeCommand,
		Options: dataprovider.BaseEventActionOptions{
			CmdConfig: dataprovider.EventActionCommandConfig{
				Cmd:     uploadScriptPath,
				Timeout: 10,
				EnvVars: []dataprovider.KeyValue{
					{
						Key:   envName,
						Value: "$",
					},
				},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "test rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// envName is not set yet: the upload is expected to fail
		err = writeSFTPFileNoCheck(testFileName, 100, client)
		assert.Error(t, err)
	}
	os.Setenv(envName, "1")
	defer os.Unsetenv(envName)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// envName is now set: the upload is expected to succeed
		err = writeSFTPFileNoCheck(testFileName, 100, client)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1,
http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.Remove(uploadScriptPath)
	assert.NoError(t, err)
}

// TestFsActionCopy checks the filesystem copy action executed synchronously on
// upload: the uploaded file is copied into dirCopy even though the user's
// primary group only grants list and delete permissions there, and once the
// action is changed to copy a missing path the upload fails.
func TestFsActionCopy(t *testing.T) {
	dirCopy := "/dircopy"
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionCopy,
				Copy: []dataprovider.KeyValue{
					{
						Key:   "/{{.VirtualPath}}/",
						Value: dirCopy + "/",
					},
				},
			},
		},
	}
	action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	g1 := dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name: "group1",
		},
		UserSettings: dataprovider.GroupUserSettings{
			BaseGroupUserSettings: sdk.BaseGroupUserSettings{
				Permissions: map[string][]string{
					// Restrict permissions in dirCopy to check that the action
					// will have full permissions anyway.
					dirCopy: {dataprovider.PermListItems, dataprovider.PermDelete},
				},
			},
		},
	}
	group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: group1.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		// the copy must exist even though the user cannot upload into dirCopy
		_, err = client.Stat(path.Join(dirCopy, testFileName))
		assert.NoError(t, err)
		action1.Options.FsConfig.Copy = []dataprovider.KeyValue{
			{
				Key:   "/missing path",
				Value: "/copied path",
			},
		}
		_, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK)
		assert.NoError(t, err)
		// copying a missing path fails, and the synchronous action makes the
		// upload fail as well
		err = writeSFTPFile(testFileName, 100, client)
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
}

// TestEventFsActionsGroupFilters checks the GroupNames condition of a
// filesystem-event rule. The email action runs on upload only when one of the
// user's groups matches the "group*" pattern:
//   - no group: no match;
//   - primary group "agroup1": no match;
//   - additional secondary group "group2": match.
func TestEventFsActionsGroupFilters(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"example@example.net"},
				Subject:    `New "{{.Event}}" from "{{.Name}}" status {{.StatusString}}`,
				Body:       "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} {{.ErrorString}}",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1,
http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
			Options: dataprovider.ConditionOptions{
				GroupNames: []dataprovider.ConditionPattern{
					{
						Pattern: "group*",
					},
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// the user has no group, so the rule does not match
		lastReceivedEmail.reset()
		err = writeSFTPFile(testFileName, 32, client)
		assert.NoError(t, err)
		assert.Empty(t, lastReceivedEmail.get().From)
	}
	g1 := dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name: "agroup1",
		},
	}
	group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated)
	assert.NoError(t, err)
	g2 := dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name: "group2",
		},
	}
	group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated)
	assert.NoError(t, err)
	user.Groups = []sdk.GroupMapping{
		{
			Name: group1.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// the group "agroup1" does not match "group*"
		lastReceivedEmail.reset()
		err = writeSFTPFile(testFileName, 32, client)
		assert.NoError(t, err)
		assert.Empty(t, lastReceivedEmail.get().From)
	}
	user.Groups = append(user.Groups, sdk.GroupMapping{
		Name: group2.Name,
		Type: sdk.GroupTypeSecondary,
	})
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// the secondary group "group2" matches
		lastReceivedEmail.reset()
		err = writeSFTPFile(testFileName, 32, client)
		assert.NoError(t, err)
		assert.NotEmpty(t, lastReceivedEmail.get().From)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group2, http.StatusOK)
	assert.NoError(t, err)
	// reset the SMTP configuration for the following tests
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestEventProviderActionGroupFilters checks the GroupNames condition of a
// provider-event rule on user "add"/"update". An email is expected only when
// the user's group matches "group_*": "group_2" matches, "agroup_1" does not.
// The rule action does not set ExecuteSync. For that reason the test polls for
// the email in the positive cases and waits a fixed delay before checking the
// negative case.
func TestEventProviderActionGroupFilters(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"example@example.net"},
				Subject:    `New "{{.Event}}" from "{{.Name}}"`,
				Body:       "IP: {{.IP}}",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerProviderEvent,
		Conditions: dataprovider.EventConditions{
			ProviderEvents: []string{"add", "update"},
			Options: dataprovider.ConditionOptions{
				GroupNames: []dataprovider.ConditionPattern{
					{
						Pattern: "group_*",
					},
				},
				ProviderObjects: []string{"user"},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	g1 := dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name: "agroup_1",
		},
	}
	group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated)
	assert.NoError(t, err)
	g2 := dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name: "group_2",
		},
	}
	group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: group2.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	// adding a user whose primary group is "group_2" matches the rule
	lastReceivedEmail.reset()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return lastReceivedEmail.get().From != ""
	}, 1500*time.Millisecond, 100*time.Millisecond)
	email := lastReceivedEmail.get()
	assert.Len(t, email.To, 1)
	// moving the user to "agroup_1" does not match: no email is expected
	user.Groups = []sdk.GroupMapping{
		{
			Name: group1.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	lastReceivedEmail.reset()
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	time.Sleep(300 * time.Millisecond)
	email = lastReceivedEmail.get()
	assert.Len(t, email.To, 0)
	// moving the user back to "group_2" matches again
	user.Groups = []sdk.GroupMapping{
		{
			Name: group2.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	lastReceivedEmail.reset()
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return lastReceivedEmail.get().From != ""
	}, 1500*time.Millisecond, 100*time.Millisecond)
	email = lastReceivedEmail.get()
	assert.Len(t, email.To, 1)
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group2, http.StatusOK)
	assert.NoError(t, err)
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestBackupAsAttachment checks a certificate-event rule that runs a backup
// action followed by an email action attaching "/{{.VirtualPath}}". The test
// expects the email to contain an application/json part.
// NOTE(review): presumably {{.VirtualPath}} resolves to the backup file
// produced by the first action — confirm.
func TestBackupAsAttachment(t *testing.T) {
	smtpCfg := smtp.Config{
		Host: "127.0.0.1",
		Port:
2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "a1 with space",
		Type: dataprovider.ActionTypeBackup,
	}
	a2 := dataprovider.BaseEventAction{
		Name: "a2",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients:  []string{"test@example.com"},
				Subject:     `"{{.Event}} {{.StatusString}}"`,
				Body:        "Domain: {{.Name}}",
				Attachments: []string{"/{{.VirtualPath}}"},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "test rule certificate",
		Status:  1,
		Trigger: dataprovider.EventTriggerCertificate,
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 2,
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	lastReceivedEmail.reset()
	renewalEvent := "Certificate renewal"
	// fire a certificate event with Status 1, rendered as "OK" in the subject
	common.HandleCertificateEvent(common.EventParams{
		Name:      "example.com",
		Timestamp: time.Now(),
		Status:    1,
		Event:     renewalEvent,
	})
	assert.Eventually(t, func() bool {
		return lastReceivedEmail.get().From != ""
	}, 3000*time.Millisecond, 100*time.Millisecond)
	email := lastReceivedEmail.get()
	assert.Len(t, email.To, 1)
	assert.True(t, slices.Contains(email.To, "test@example.com"))
	assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent))
	assert.Contains(t, email.Data, `Domain: example.com`)
	// presumably the attached backup file
	assert.Contains(t, email.Data, "Content-Type: application/json")
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestEventActionHTTPMultipart checks an HTTP action that sends a multipart PUT
// request to the test server's /multipart endpoint. The request has a JSON part
// and a part with the uploaded file, and the action runs synchronously on
// upload. After a part referencing the missing file "/missing" is added, the
// action fails and the upload reports an error on Close.
func TestEventActionHTTPMultipart(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint: fmt.Sprintf("http://%s/multipart", httpAddr),
				Method:   http.MethodPut,
				Parts: []dataprovider.HTTPPart{
					{
						Name: "part1",
						Headers: []dataprovider.KeyValue{
							{
								Key:   "Content-Type",
								Value: "application/json",
							},
						},
						Body: `{"FilePath": "{{.VirtualPath}}"}`,
					},
					{
						Name:     "file",
						Filepath: "/{{.VirtualPath}}",
					},
				},
			},
		},
	}
	action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	r1 := dataprovider.EventRule{
		Name:    "test http multipart",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
				Order: 1,
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		f, err := client.Create(testFileName)
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.NoError(t, err)
		// now add a missing file to the http multipart action
		action1.Options.HTTPConfig.Parts = append(action1.Options.HTTPConfig.Parts, dataprovider.HTTPPart{
			Name:     "file1",
			Filepath: "/missing",
		})
		_, resp, err = httpdtest.UpdateEventAction(action1, http.StatusOK)
		assert.NoError(t, err, string(resp))
		f, err = client.Create("testfile.txt")
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		// the failing synchronous action surfaces as an error from Close
		err = f.Close()
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestEventActionCompress checks the compress action, executed synchronously on
// upload, that archives the uploaded file into "<file>.zip". It runs for a
// local user, an SFTP user with BufferSize 1 and an encrypted (cryptfs) user.
// The user's quota must count both the file and the archive, also after the
// file is overwritten. For the encrypted user the expected sizes are computed
// with getEncryptedFileSize.
func TestEventActionCompress(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionCompress,
				Compress: dataprovider.EventActionFsCompress{
					Name:  "/{{.VirtualPath}}.zip",
					Paths: []string{"/{{.VirtualPath}}"},
				},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "test compress",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.QuotaFiles = 1000
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.FsConfig.SFTPConfig.BufferSize = 1
	u.QuotaFiles = 1000
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getCryptFsUser()
	u.QuotaFiles = 1000
	cryptFsUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser, cryptFsUser} {
		// cleanup home dir
		err = os.RemoveAll(user.GetHomeDir())
		assert.NoError(t, err)
		// restrict the rule to the user under test
		rule1.Conditions.Options.Names =
[]dataprovider.ConditionPattern{ { Pattern: user.Username, }, } _, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() expectedQuotaSize := int64(len(testFileContent)) expectedQuotaFiles := 1 if user.Username == cryptFsUser.Username { encryptedFileSize, err := getEncryptedFileSize(expectedQuotaSize) assert.NoError(t, err) expectedQuotaSize = encryptedFileSize } f, err := client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) info, err := client.Stat(testFileName + ".zip") //nolint:goconst if assert.NoError(t, err) { assert.Greater(t, info.Size(), int64(0)) // check quota archiveSize := info.Size() if user.Username == cryptFsUser.Username { encryptedFileSize, err := getEncryptedFileSize(archiveSize) assert.NoError(t, err) archiveSize = encryptedFileSize } user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles+1, user.UsedQuotaFiles, "quota file does no match for user %q", user.Username) assert.Equal(t, expectedQuotaSize+archiveSize, user.UsedQuotaSize, "quota size does no match for user %q", user.Username) } // now overwrite the same file f, err = client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) info, err = client.Stat(testFileName + ".zip") if assert.NoError(t, err) { assert.Greater(t, info.Size(), int64(0)) archiveSize := info.Size() if user.Username == cryptFsUser.Username { encryptedFileSize, err := getEncryptedFileSize(archiveSize) assert.NoError(t, err) archiveSize = encryptedFileSize } user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles+1, user.UsedQuotaFiles, "quota file after 
overwrite does no match for user %q", user.Username) assert.Equal(t, expectedQuotaSize+archiveSize, user.UsedQuotaSize, "quota size after overwrite does no match for user %q", user.Username) } } if user.Username == localUser.Username { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(cryptFsUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(cryptFsUser.GetHomeDir()) assert.NoError(t, err) } func TestEventActionCompressQuotaErrors(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) testDir := "archiveDir" zipPath := "/archive.zip" a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeFilesystem, Options: dataprovider.BaseEventActionOptions{ FsConfig: dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionCompress, Compress: dataprovider.EventActionFsCompress{ Name: zipPath, Paths: []string{"/" + testDir}, }, }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "action2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"Compress failed"`, Body: "Error: {{.ErrorString}}", }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test compress", 
Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"rename"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				// the email is sent only when the compress action fails
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
				Order: 2,
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	fileSize := int64(100)
	u := getTestUser()
	u.QuotaSize = 10 * fileSize
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.MkdirAll(path.Join(testDir, "1", "1"))
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "1", testFileName), fileSize, client)
		assert.NoError(t, err)
		err = client.MkdirAll(path.Join(testDir, "2", "2"))
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "2", testFileName), fileSize, client)
		assert.NoError(t, err)
		err = client.Symlink(path.Join(testDir, "2", testFileName), path.Join(testDir, "2", testFileName+"_link"))
		assert.NoError(t, err)
		// trigger the compress action
		err = client.Mkdir("a")
		assert.NoError(t, err)
		err = client.Rename("a", "b")
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			_, err := client.Stat(zipPath)
			return err == nil
		}, 3*time.Second, 100*time.Millisecond)
		err = client.Remove(zipPath)
		assert.NoError(t, err)
		// add 6 more files (8 in total), the compress action should fail with a quota error
		err = writeSFTPFile(path.Join(testDir, "1", "1", testFileName), fileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "2", "2", testFileName), fileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "1", "1", testFileName+"1"), fileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "2", "2", testFileName+"2"), fileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "1", testFileName+"1"), fileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, "2", testFileName+"2"), fileSize, client)
		assert.NoError(t, err)
		lastReceivedEmail.reset()
		// renaming back triggers the compress action again
		err = client.Rename("b", "a")
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3*time.Second, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.True(t, slices.Contains(email.To, "test@example.com"))
		assert.Contains(t, email.Data, `Subject: "Compress failed"`)
		assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
		// update quota size so the user is already overquota: 8 files of
		// fileSize bytes are stored, more than 7 * fileSize
		user.QuotaSize = 7 * fileSize
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		lastReceivedEmail.reset()
		err = client.Rename("a", "b")
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3*time.Second, 100*time.Millisecond)
		email = lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.True(t, slices.Contains(email.To, "test@example.com"))
		assert.Contains(t, email.Data, `Subject: "Compress failed"`)
		assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
		// remove the path to compress to trigger an error for size estimation
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %s", testDir), user)
		assert.NoError(t, err, string(out))
		lastReceivedEmail.reset()
		err = client.Rename("b", "a")
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3*time.Second, 100*time.Millisecond)
		email = lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.True(t, slices.Contains(email.To, "test@example.com"))
		assert.Contains(t, email.Data, `Subject: "Compress failed"`)
		assert.Contains(t, email.Data, "unable to estimate archive size")
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestEventActionCompressQuotaFolder checks quota accounting for a compress
// action, executed synchronously on upload, that archives both the uploaded
// file and testDir (which also holds a symlink). Uploads go into testDir and
// into a virtual folder configured with QuotaSize/QuotaFiles -1. In both cases
// the test expects the file and the archive to be charged to the user's quota,
// while the virtual folder's own usage stays at zero.
func TestEventActionCompressQuotaFolder(t *testing.T) {
	testDir := "/folder"
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionCompress,
				Compress: dataprovider.EventActionFsCompress{
					Name:  "/{{.VirtualPath}}.zip",
					Paths: []string{"/{{.VirtualPath}}", testDir},
				},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "test compress",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.QuotaFiles = 1000
	mappedPath := filepath.Join(os.TempDir(), "virtualpath")
	folderName := filepath.Base(mappedPath)
	vdirPath := "/virtualpath"
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err = httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		expectedQuotaSize := int64(len(testFileContent))
		expectedQuotaFiles := 1
		// the symlink is created before its target file exists
		err = client.Symlink(path.Join(testDir, testFileName), path.Join(testDir, testFileName+"_link"))
		assert.NoError(t, err)
		f, err := client.Create(path.Join(testDir, testFileName))
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.NoError(t, err)
		info, err := client.Stat(path.Join(testDir, testFileName) + ".zip")
		if assert.NoError(t, err) {
			assert.Greater(t, info.Size(), int64(0))
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			// the archive adds one file and its size to the user's quota
			expectedQuotaFiles++
			expectedQuotaSize += info.Size()
			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		}
		vfolder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, vfolder.UsedQuotaFiles)
		assert.Equal(t, int64(0), vfolder.UsedQuotaSize)
		// upload in the virtual path
		f, err = client.Create(path.Join(vdirPath, testFileName))
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.NoError(t, err)
		info, err = client.Stat(path.Join(vdirPath, testFileName) + ".zip")
		if assert.NoError(t, err) {
			assert.Greater(t, info.Size(), int64(0))
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			// both the uploaded file and its archive are charged to the user,
			// not to the virtual folder
			expectedQuotaFiles += 2
			expectedQuotaSize += info.Size() + int64(len(testFileContent))
			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
			vfolder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 0, vfolder.UsedQuotaFiles)
			assert.Equal(t, int64(0), vfolder.UsedQuotaSize)
		}
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestEventActionCompressErrors checks that failures of the synchronous
// compress action make the upload fail (f.Close returns an error) when:
//   - the archive would have to include itself;
//   - the path to compress is missing;
//   - the archive name is an existing directory.
func TestEventActionCompressErrors(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionCompress,
				Compress: dataprovider.EventActionFsCompress{
					Name:  "/{{.VirtualPath}}.zip",
					Paths: []string{"/{{.VirtualPath}}.zip"}, // cannot compress itself
				},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "test compress",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		f, err := client.Create(testFileName)
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.Error(t, err)
	}
	// try to compress a missing file
	action1.Options.FsConfig.Compress.Paths = []string{"/missing file"}
	_, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		f, err := client.Create(testFileName)
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.Error(t, err)
	}
	// try to overwrite a directory
	testDir := "/adir"
	action1.Options.FsConfig.Compress.Name = testDir
	action1.Options.FsConfig.Compress.Paths = []string{"/{{.VirtualPath}}"}
	_, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		f, err := client.Create(testFileName)
		assert.NoError(t, err)
		_, err = f.Write(testFileContent)
		assert.NoError(t, err)
		err = f.Close()
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestEventActionEmailAttachments(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notify@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionCompress,
				Compress: dataprovider.EventActionFsCompress{
					Name:  "/archive/{{.VirtualPath}}.zip",
					Paths: []string{"/{{.VirtualPath}}"},
				},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	a2 := dataprovider.BaseEventAction{
		Name: "action2",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig:
dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"{{.Event}}" from "{{.Name}}"`, Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} {{.EscapedVirtualPath}}", Attachments: []string{"/archive/{{.VirtualPath}}.zip"}, }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test email with attachment", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"upload"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) localUser, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) u := getTestSFTPUser() u.FsConfig.SFTPConfig.BufferSize = 1 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) cryptFsUser, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser, cryptFsUser} { /* NOTE(review): these defers run only when the test function returns, so every connection opened in this loop stays open until then */ conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() lastReceivedEmail.reset() f, err := client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) /* the rule actions are not ExecuteSync, so poll for the email instead of reading it right away */ assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 1500*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, `Subject: "upload" from`) assert.Contains(t, email.Data, url.QueryEscape("/"+testFileName)) assert.Contains(t, email.Data, 
"Content-Disposition: attachment") } } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(cryptFsUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(cryptFsUser.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventActionsRetentionReports runs a data retention check on /d followed by an email, an HTTP body and a multipart HTTP action that all consume the retention report placeholder; after the retention action is removed from the rule, the upload is expected to fail. */ func TestEventActionsRetentionReports(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) testDir := "/d" a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeDataRetentionCheck, Options: dataprovider.BaseEventActionOptions{ RetentionConfig: dataprovider.EventActionDataRetentionConfig{ Folders: []dataprovider.FolderRetention{ { Path: testDir, Retention: 1, DeleteEmptyDirs: true, }, }, }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "action2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"{{.Event}}" from "{{.Name}}"`, Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}}", Attachments: []string{dataprovider.RetentionReportPlaceHolder}, }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) a3 := dataprovider.BaseEventAction{ Name: "action3", Type: 
dataprovider.ActionTypeHTTP, Options: dataprovider.BaseEventActionOptions{ HTTPConfig: dataprovider.EventActionHTTPConfig{ Endpoint: fmt.Sprintf("http://%s/", httpAddr), Timeout: 20, Method: http.MethodPost, Body: dataprovider.RetentionReportPlaceHolder, }, }, } action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) assert.NoError(t, err) a4 := dataprovider.BaseEventAction{ Name: "action4", Type: dataprovider.ActionTypeHTTP, Options: dataprovider.BaseEventActionOptions{ HTTPConfig: dataprovider.EventActionHTTPConfig{ Endpoint: fmt.Sprintf("http://%s/multipart", httpAddr), Timeout: 20, Method: http.MethodPost, Parts: []dataprovider.HTTPPart{ { Name: "reports.zip", Filepath: dataprovider.RetentionReportPlaceHolder, }, }, }, }, } action4, _, err := httpdtest.AddEventAction(a4, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test rule1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"upload"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: true, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: true, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action3.Name, }, Order: 3, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: true, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action4.Name, }, Order: 4, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: true, }, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer 
conn.Close() defer client.Close() subdir := path.Join(testDir, "sub") err = client.MkdirAll(subdir) assert.NoError(t, err) lastReceivedEmail.reset() f, err := client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) /* all actions are ExecuteSync, so the email has already been sent when Close returns */ email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "upload" from "%s"`, user.Username)) assert.Contains(t, email.Data, "Content-Disposition: attachment") _, err = client.Stat(testDir) assert.NoError(t, err) _, err = client.Stat(subdir) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Mkdir(subdir) assert.NoError(t, err) newName := path.Join(testDir, testFileName) err = client.Rename(testFileName, newName) assert.NoError(t, err) err = client.Chtimes(newName, time.Now().Add(-24*time.Hour), time.Now().Add(-24*time.Hour)) assert.NoError(t, err) lastReceivedEmail.reset() f, err = client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) _, err = client.Stat(subdir) assert.ErrorIs(t, err, os.ErrNotExist) /* the renamed file had its times moved 24h back, beyond the folder retention (Retention: 1, presumably hours), so the retention check must have removed it as well */ _, err = client.Stat(newName) assert.ErrorIs(t, err, os.ErrNotExist) } // now remove the retention check to test errors rule1.Actions = []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: false, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action3.Name, }, Order: 3, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: false, }, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action4.Name, }, Order: 4, Options: dataprovider.EventActionOptions{ ExecuteSync: true, StopOnFailure: false, }, }, } _, _, err = 
httpdtest.UpdateEventRule(rule1, http.StatusOK) assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() f, err := client.Create(testFileName) assert.NoError(t, err) _, err = f.Write(testFileContent) assert.NoError(t, err) err = f.Close() /* without the retention action the report placeholder presumably cannot be resolved, so the sync actions fail and the error is reported to the uploader */ assert.Error(t, err) } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action4, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventRuleFirstUploadDownloadActions checks that "first-upload" and "first-download" rules notify only once: a later upload of a different file and a second download of the same file send no email. */ func TestEventRuleFirstUploadDownloadActions(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"{{.Event}}" from "{{.Name}}"`, Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}}", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test first upload rule", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"first-upload"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } 
rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) r2 := dataprovider.EventRule{ Name: "test first download rule", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"first-download"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileSize := int64(32768) lastReceivedEmail.reset() err = writeSFTPFileNoCheck(testFileName, testFileSize, client) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 1500*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-upload" from "%s"`, user.Username)) lastReceivedEmail.reset() // a new upload will not produce a new notification err = writeSFTPFileNoCheck(testFileName+"_1", testFileSize, client) assert.NoError(t, err) /* Never polls for the whole window, so no email may arrive for the second upload */ assert.Never(t, func() bool { return lastReceivedEmail.get().From != "" }, 1000*time.Millisecond, 100*time.Millisecond) // the same for download f, err := client.Open(testFileName) assert.NoError(t, err) contents := make([]byte, testFileSize) n, err := io.ReadFull(f, contents) assert.NoError(t, err) assert.Equal(t, int(testFileSize), n) err = f.Close() assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 1500*time.Millisecond, 100*time.Millisecond) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, 
email.Data, fmt.Sprintf(`Subject: "first-download" from "%s"`, user.Username)) // download again lastReceivedEmail.reset() f, err = client.Open(testFileName) assert.NoError(t, err) contents = make([]byte, testFileSize) n, err = io.ReadFull(f, contents) assert.NoError(t, err) assert.Equal(t, int(testFileSize), n) err = f.Close() assert.NoError(t, err) assert.Never(t, func() bool { return lastReceivedEmail.get().From != "" }, 1000*time.Millisecond, 100*time.Millisecond) } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventRuleRenameEvent checks that a rename sends an HTML (ContentType 1) email whose body renders the rename target path and a user name containing "&". */ func TestEventRuleRenameEvent(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"{{.Event}}" from "{{.Name}}"`, ContentType: 1, Body: `


Fs path {{.FsPath}}, Name: {{.Name}}, Target path "{{.VirtualTargetDirPath}}/{{.TargetName}}", size: {{.FileSize}}


`, }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test rename rule", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"rename"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) u := getTestUser() /* the assertions below expect "Name: test & chars," verbatim in the HTML email body */ u.Username = "test & chars" user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileSize := int64(32768) lastReceivedEmail.reset() err = writeSFTPFileNoCheck(testFileName, testFileSize, client) assert.NoError(t, err) err = client.Mkdir("subdir") assert.NoError(t, err) err = client.Rename(testFileName, path.Join("/subdir", testFileName)) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 1500*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "rename" from "%s"`, user.Username)) assert.Contains(t, email.Data, "Content-Type: text/html") assert.Contains(t, email.Data, fmt.Sprintf("Target path %q", path.Join("/subdir", testFileName))) assert.Contains(t, email.Data, "Name: test & chars,") } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventRuleIDPLogin covers IDP login rules. The account check action only takes effect when executed synchronously, and it creates (Mode 1) or updates (Mode 0) the user/admin from JSON templates. Two matching account check rules are an error, and an invalid admin template fails validation. */ func TestEventRuleIDPLogin(t 
*testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) lastReceivedEmail.reset() username := `test_'idp_'login` custom1 := `cust"oa"1` u := map[string]any{ "username": "{{.Name}}", "status": 1, "home_dir": filepath.Join(os.TempDir(), "{{.IDPFieldcustom1}}"), "permissions": map[string][]string{ "/": {dataprovider.PermAny}, }, } userTmpl, err := json.Marshal(u) require.NoError(t, err) a := map[string]any{ "username": "{{.Name}}", "status": 1, "permissions": []string{dataprovider.PermAdminAny}, } adminTmpl, err := json.Marshal(a) require.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeIDPAccountCheck, Options: dataprovider.BaseEventActionOptions{ IDPConfig: dataprovider.EventActionIDPAccountCheck{ Mode: 1, // create if not exists TemplateUser: string(userTmpl), TemplateAdmin: string(adminTmpl), }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"{{.Event}} {{.StatusString}}"`, Body: "{{.Name}} Custom field: {{.IDPFieldcustom1}}", }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test rule IDP login", Status: 1, Trigger: dataprovider.EventTriggerIDPLogin, Conditions: dataprovider.EventConditions{ IDPLoginEvent: dataprovider.IDPLoginUser, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, // the rule is not sync and will be skipped }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } rule1, resp, err := 
httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err, string(resp)) customFields := map[string]any{ "custom1": custom1, } user, admin, err := common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginUser, Status: 1, }, &customFields) assert.Nil(t, user) assert.Nil(t, admin) assert.NoError(t, err) rule1.Actions[0].Options.ExecuteSync = true rule1, resp, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) assert.NoError(t, err, string(resp)) user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginUser, Status: 1, }, &customFields) /* check the login error now: err is reused by RemoveUser below and would otherwise mask it */ assert.NoError(t, err) if assert.NotNil(t, user) { assert.Equal(t, filepath.Join(os.TempDir(), custom1), user.GetHomeDir()) _, err = httpdtest.RemoveUser(*user, http.StatusOK) assert.NoError(t, err) } assert.Nil(t, admin) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginUser)) assert.Contains(t, email.Data, username) assert.Contains(t, email.Data, custom1) user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.Nil(t, user) assert.Nil(t, admin) assert.NoError(t, err) rule1.Conditions.IDPLoginEvent = dataprovider.IDPLoginAny rule1.Actions = []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Options: dataprovider.EventActionOptions{ ExecuteSync: true, }, Order: 1, }, } rule1, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) assert.NoError(t, err) r2 := dataprovider.EventRule{ Name: "test email on IDP login", Status: 1, Trigger: dataprovider.EventTriggerIDPLogin, Conditions: dataprovider.EventConditions{ IDPLoginEvent: 
dataprovider.IDPLoginAdmin, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 1, }, }, } rule2, resp, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err, string(resp)) lastReceivedEmail.reset() user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.Nil(t, user) /* admin is dereferenced below (admin.Status, *admin): stop the test instead of panicking if it is missing */ require.NotNil(t, admin) assert.Equal(t, 1, admin.Status) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginAdmin)) assert.Contains(t, email.Data, username) assert.Contains(t, email.Data, custom1) admin.Status = 0 _, _, err = httpdtest.UpdateAdmin(*admin, http.StatusOK) assert.NoError(t, err) user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.Nil(t, user) if assert.NotNil(t, admin) { assert.Equal(t, 0, admin.Status) } assert.NoError(t, err) action1.Options.IDPConfig.Mode = 0 action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) assert.NoError(t, err) user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.Nil(t, user) /* the returned admin is removed below via *admin */ require.NotNil(t, admin) assert.Equal(t, 1, admin.Status) assert.NoError(t, err) _, err = httpdtest.RemoveAdmin(*admin, http.StatusOK) assert.NoError(t, err) r3 := dataprovider.EventRule{ Name: "test rule2 IDP login", Status: 1, Trigger: dataprovider.EventTriggerIDPLogin, Conditions: dataprovider.EventConditions{ IDPLoginEvent: dataprovider.IDPLoginAny, }, Actions: []dataprovider.EventAction{ { 
BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, Options: dataprovider.EventActionOptions{ ExecuteSync: true, }, }, }, } rule3, resp, err := httpdtest.AddEventRule(r3, http.StatusCreated) assert.NoError(t, err, string(resp)) user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.Nil(t, user) assert.Nil(t, admin) if assert.Error(t, err) { assert.Contains(t, err.Error(), "more than one account check action rules matches") } _, err = httpdtest.RemoveEventRule(rule3, http.StatusOK) assert.NoError(t, err) action1.Options.IDPConfig.TemplateAdmin = `{}` action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) assert.NoError(t, err) _, _, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.ErrorIs(t, err, util.ErrValidation) _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ Name: username, Event: common.IDPLoginAdmin, Status: 1, }, &customFields) assert.Nil(t, user) assert.Nil(t, admin) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventRuleEmailField checks that the {{.Email}} recipient placeholder resolves to the user's email for provider (user add) and filesystem (mkdir) events. It also checks that adding a user without an email makes the action fail, so the failure action reports "no recipient addresses set". */ func TestEventRuleEmailField(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) lastReceivedEmail.reset() a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: 
dataprovider.EventActionEmailConfig{ Recipients: []string{"{{.Email}}"}, Subject: `"{{.Event}}" from "{{.Name}}"`, Body: "Sample email body", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "action2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"failure@example.com"}, Subject: `"Failure`, /* NOTE(review): this subject has an unbalanced quote; it is never asserted, so it looks harmless but may be a typo */ Body: "{{.ErrorString}}", }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "r1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"mkdir"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, }, }, } r2 := dataprovider.EventRule{ Name: "test rule2", Status: 1, Trigger: dataprovider.EventTriggerProviderEvent, Conditions: dataprovider.EventConditions{ ProviderEvents: []string{"add"}, Options: dataprovider.ConditionOptions{ ProviderObjects: []string{"user"}, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Options: dataprovider.EventActionOptions{ IsFailureAction: true, }, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err) u := getTestUser() u.Email = "user@example.com" user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, user.Email)) assert.Contains(t, 
email.Data, `Subject: "add" from "admin"`) // if we add a user without email the notification will fail lastReceivedEmail.reset() u1 := getTestUser() u1.Username += "_1" user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "failure@example.com")) assert.Contains(t, email.Data, `no recipient addresses set`) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() lastReceivedEmail.reset() err = client.Mkdir(testFileName) assert.NoError(t, err) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, user.Email)) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "mkdir" from "%s"`, user.Username)) } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user1, http.StatusOK) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventRuleCertificate checks the certificate trigger: a successful renewal (Status 1) sends a plain text "OK" email, and a failed one (Status 2) sends a "KO" email containing the UTC date time and the error message. Once the rules are removed, the event is handled with no rule to run. */ func TestEventRuleCertificate(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notify@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) lastReceivedEmail.reset() a1 := 
dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test@example.com"}, Subject: `"{{.Event}} {{.StatusString}}"`, ContentType: 0, Body: "Domain: {{.Name}} Timestamp: {{.Timestamp}} {{.ErrorString}} Date time: {{.DateTime}}", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "action2", Type: dataprovider.ActionTypeFolderQuotaReset, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test rule certificate", Status: 1, Trigger: dataprovider.EventTriggerCertificate, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) r2 := dataprovider.EventRule{ Name: "test rule 2", Status: 1, Trigger: dataprovider.EventTriggerCertificate, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err) renewalEvent := "Certificate renewal" common.HandleCertificateEvent(common.EventParams{ Name: "example.com", Timestamp: time.Now(), Status: 1, Event: renewalEvent, }) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent)) assert.Contains(t, email.Data, "Content-Type: text/plain") assert.Contains(t, email.Data, `Domain: 
example.com Timestamp`) /* NOTE(review): rule1 and rule2 both run action1, so the first event may still deliver a second "OK" email after this reset; confirm it cannot satisfy the KO check below */ lastReceivedEmail.reset() dateTime := time.Now() params := common.EventParams{ Name: "example.com", Timestamp: dateTime, Status: 2, Event: renewalEvent, } errRenew := errors.New("generic renew error") params.AddError(errRenew) common.HandleCertificateEvent(params) assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email = lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.True(t, slices.Contains(email.To, "test@example.com")) assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s KO"`, renewalEvent)) assert.Contains(t, email.Data, `Domain: example.com Timestamp`) assert.Contains(t, email.Data, dateTime.UTC().Format("2006-01-02T15:04:05.000")) assert.Contains(t, email.Data, errRenew.Error()) _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) // ignored no more certificate rules common.HandleCertificateEvent(common.EventParams{ Name: "example.com", Timestamp: time.Now(), Status: 1, Event: renewalEvent, }) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } /* TestEventRuleIPBlocked enables the defender (Threshold 3, ScoreLimitExceeded 2) and makes three SFTP logins with a wrong password, after which a login with the right password is also refused. It then checks that the IP blocked rules email both recipients. The previous common config is restored at the end. */ func TestEventRuleIPBlocked(t *testing.T) { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 err := common.Initialize(cfg, 0) assert.NoError(t, err) smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "action1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ 
EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test3@example.com", "test4@example.com"}, Subject: `New "{{.Event}}"`, Body: "IP: {{.IP}} Timestamp: {{.Timestamp}}", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "action2", Type: dataprovider.ActionTypeFolderQuotaReset, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "test rule ip blocked", Status: 1, Trigger: dataprovider.EventTriggerIPBlocked, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, }, } rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err) r2 := dataprovider.EventRule{ Name: "test rule 2", Status: 1, Trigger: dataprovider.EventTriggerIPBlocked, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) assert.NoError(t, err) u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) lastReceivedEmail.reset() time.Sleep(300 * time.Millisecond) assert.Empty(t, lastReceivedEmail.get().From, lastReceivedEmail.get().Data) for i := 0; i < 3; i++ { user.Password = "wrong_pwd" _, _, err = getSftpClient(user) assert.Error(t, err) } // the client is now banned user.Password = defaultPassword _, _, err = getSftpClient(user) assert.Error(t, err) // check the email notification assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 3000*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 2) assert.True(t, slices.Contains(email.To, "test3@example.com")) assert.True(t, slices.Contains(email.To, 
"test4@example.com")) assert.Contains(t, email.Data, `Subject: New "IP Blocked"`) err = dataprovider.DeleteEventRule(rule1.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventRule(rule2.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventAction(action1.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteEventAction(action2.Name, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } func TestEventRuleRotateLog(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeRotateLogs, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"success@example.net"}, Subject: `OK`, Body: "OK action", }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "rule1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"mkdir"}, Options: dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user.Username, }, }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, { 
BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err, string(resp)) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() lastReceivedEmail.reset() err := client.Mkdir("just a test dir") assert.NoError(t, err) // just check that the action is executed assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 1500*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.Contains(t, email.To, "success@example.net") } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } func TestEventRuleInactivityCheck(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) require.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeUserInactivityCheck, Options: dataprovider.BaseEventActionOptions{ UserInactivityConfig: dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 20, }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: 
dataprovider.EventActionEmailConfig{ Recipients: []string{"success@example.net"}, Subject: `OK`, Body: "OK action", }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "rule1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"mkdir"}, Options: dataprovider.ConditionOptions{ Names: []dataprovider.ConditionPattern{ { Pattern: user.Username, }, }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 2, }, }, } rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err, string(resp)) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() lastReceivedEmail.reset() err := client.Mkdir("just a test dir") assert.NoError(t, err) // just check that the action is executed assert.Eventually(t, func() bool { return lastReceivedEmail.get().From != "" }, 1500*time.Millisecond, 100*time.Millisecond) email := lastReceivedEmail.get() assert.Len(t, email.To, 1) assert.Contains(t, email.To, "success@example.net") } _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) } func TestEventRulePasswordExpiration(t *testing.T) { smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 2525, From: "notification@example.com", TemplatesPath: "templates", } err := smtpCfg.Initialize(configDir, true) 
require.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) a1 := dataprovider.BaseEventAction{ Name: "a1", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"failure@example.net"}, Subject: `Failure`, Body: "Failure action", }, }, } action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) assert.NoError(t, err) a2 := dataprovider.BaseEventAction{ Name: "a2", Type: dataprovider.ActionTypePasswordExpirationCheck, Options: dataprovider.BaseEventActionOptions{ PwdExpirationConfig: dataprovider.EventActionPasswordExpiration{ Threshold: 10, }, }, } action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) assert.NoError(t, err) a3 := dataprovider.BaseEventAction{ Name: "a3", Type: dataprovider.ActionTypeEmail, Options: dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"success@example.net"}, Subject: `OK`, Body: "OK action", }, }, } action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) assert.NoError(t, err) r1 := dataprovider.EventRule{ Name: "rule1", Status: 1, Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"mkdir"}, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: action2.Name, }, Order: 1, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action3.Name, }, Order: 2, }, { BaseEventAction: dataprovider.BaseEventAction{ Name: action1.Name, }, Options: dataprovider.EventActionOptions{ IsFailureAction: true, }, }, }, } rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) assert.NoError(t, err, string(resp)) dirName := "aTestDir" conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() lastReceivedEmail.reset() err := client.Mkdir(dirName) 
assert.NoError(t, err)
		// the user has no password expiration, the check will be skipped and the ok action executed
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 1500*time.Millisecond, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.Contains(t, email.To, "success@example.net")
		err = client.RemoveDirectory(dirName)
		assert.NoError(t, err)
	}
	user.Filters.PasswordExpiration = 20
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		lastReceivedEmail.reset()
		err := client.Mkdir(dirName)
		assert.NoError(t, err)
		// the password is not about to expire, the check will be skipped and the ok action executed
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 1500*time.Millisecond, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.Contains(t, email.To, "success@example.net")
		err = client.RemoveDirectory(dirName)
		assert.NoError(t, err)
	}
	user.Filters.PasswordExpiration = 5
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		lastReceivedEmail.reset()
		err := client.Mkdir(dirName)
		assert.NoError(t, err)
		// the password is about to expire, the user has no email, the failure action will be executed
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 1500*time.Millisecond, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.Contains(t, email.To, "failure@example.net")
		err = client.RemoveDirectory(dirName)
		assert.NoError(t, err)
	}
	// remove the success action
	rule1.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: action2.Name,
			},
			Order: 1,
		},
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: action1.Name,
			},
			Options: dataprovider.EventActionOptions{
				IsFailureAction: true,
			},
		},
	}
	_, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	user.Email = "user@example.net"
	user.Filters.AdditionalEmails = []string{"additional@example.net"}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		lastReceivedEmail.reset()
		err := client.Mkdir(dirName)
		assert.NoError(t, err)
		// the password expiration will be notified
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 1500*time.Millisecond, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 2)
		assert.Contains(t, email.To, user.Email)
		assert.Contains(t, email.To, user.Filters.AdditionalEmails[0])
		assert.Contains(t, email.Data, "your SFTPGo password expires in 5 days")
		err = client.RemoveDirectory(dirName)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action3, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestSyncUploadAction configures a synchronous upload hook script and checks
// how quota accounting behaves when the hook moves the uploaded file, fails
// after moving it, fails without moving it, and when an existing file is
// overwritten. The global Actions settings are saved first and restored at
// the end so the test does not leak its hook configuration into later tests.
func TestSyncUploadAction(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	// remember the global action settings so they can be restored on exit
	oldExecuteOn := common.Config.Actions.ExecuteOn
	oldExecuteSync := common.Config.Actions.ExecuteSync
	oldHook := common.Config.Actions.Hook

	uploadScriptPath := filepath.Join(os.TempDir(), "upload.sh")
	common.Config.Actions.ExecuteOn = []string{"upload"}
	common.Config.Actions.ExecuteSync = []string{"upload"}
	common.Config.Actions.Hook = uploadScriptPath

	u := getTestUser()
	u.QuotaFiles = 1000
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	movedFileName := "moved.dat"
	movedPath := filepath.Join(user.HomeDir, movedFileName)
	err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 0), 0755)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		size := int64(32768)
		err = writeSFTPFileNoCheck(testFileName, size, client)
		assert.NoError(t, err)
		_, err = client.Stat(testFileName)
		assert.Error(t, err)
		info, err := client.Stat(movedFileName)
		if assert.NoError(t, err) {
			assert.Equal(t, size, info.Size())
		}
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, size, user.UsedQuotaSize)
		// test some hook failure
		// the uploaded file is moved and the hook fails, it will be not removed from the quota
		err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 1), 0755)
		assert.NoError(t, err)
		err = writeSFTPFileNoCheck(testFileName+"_1", size, client)
		assert.Error(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, size*2, user.UsedQuotaSize)
		// the uploaded file is not moved and the hook fails, the uploaded file will be deleted
		// and removed from the quota
		movedPath = filepath.Join(user.HomeDir, "missing dir", movedFileName)
		err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 1), 0755)
		assert.NoError(t, err)
		err = writeSFTPFileNoCheck(testFileName+"_2", size, client)
		assert.Error(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, size*2, user.UsedQuotaSize)
		// overwrite an existing file
		_, err = client.Stat(movedFileName)
		assert.NoError(t, err)
		err = writeSFTPFileNoCheck(movedFileName, size, client)
		assert.Error(t, err)
		_, err = client.Stat(movedFileName)
		assert.Error(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, size, user.UsedQuotaSize)
	}
	err = os.Remove(uploadScriptPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// restore the previous settings: the original cleanup re-assigned the
	// test's own hook path instead of resetting it
	common.Config.Actions.ExecuteOn = oldExecuteOn
	common.Config.Actions.ExecuteSync = oldExecuteSync
	common.Config.Actions.Hook = oldHook
}

func TestQuotaTrackDisabled(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.TrackQuota = 0
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		err = writeSFTPFile(testFileName, 32, client)
		assert.NoError(t, err)
		err = client.Rename(testFileName, testFileName+"1")
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

func TestGetQuotaError(t *testing.T) {
	if dataprovider.GetProviderStatus().Driver == "memory" {
		t.Skip("this test is not available with the memory provider")
	}
	u := getTestUser()
	u.TotalDataTransfer = 2000
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	folderName := filepath.Base(mappedPath)
	vdirPath := "/vpath"
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_,
_, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdirPath, QuotaSize: 0, QuotaFiles: 10, }) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = writeSFTPFile(testFileName, 32, client) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = client.Rename(testFileName, path.Join(vdirPath, testFileName)) assert.Error(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) } func TestRetentionAPI(t *testing.T) { u := getTestUser() u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermOverwrite, dataprovider.PermDownload, dataprovider.PermCreateDirs, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) uploadPath := path.Join(testDir, testFileName) conn, client, err := getSftpClient(user) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir(testDir) assert.NoError(t, err) err = writeSFTPFile(uploadPath, 32, client) assert.NoError(t, err) folderRetention := []dataprovider.FolderRetention{ { Path: "/", Retention: 24, DeleteEmptyDirs: true, }, } check := common.RetentionCheck{ Folders: folderRetention, } c := common.RetentionChecks.Add(check, &user) assert.NotNil(t, c) err = c.Start() 
assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) _, err = client.Stat(uploadPath) assert.NoError(t, err) err = client.Chtimes(uploadPath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) assert.NoError(t, err) err = c.Start() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) _, err = client.Stat(uploadPath) assert.ErrorIs(t, err, os.ErrNotExist) _, err = client.Stat(testDir) assert.ErrorIs(t, err, os.ErrNotExist) err = client.Mkdir(testDir) assert.NoError(t, err) err = writeSFTPFile(uploadPath, 32, client) assert.NoError(t, err) check.Folders[0].DeleteEmptyDirs = false err = client.Chtimes(uploadPath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) assert.NoError(t, err) c = common.RetentionChecks.Add(check, &user) assert.NotNil(t, c) err = c.Start() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) _, err = client.Stat(uploadPath) assert.ErrorIs(t, err, os.ErrNotExist) _, err = client.Stat(testDir) assert.NoError(t, err) err = writeSFTPFile(uploadPath, 32, client) assert.NoError(t, err) err = client.Chtimes(uploadPath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) assert.NoError(t, err) conn.Close() client.Close() } // remove delete permissions to the user, it will be automatically granted user.Permissions["/"+testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermChtimes} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user) if assert.NoError(t, err) { innerUploadFilePath := path.Join("/"+testDir, testDir, testFileName) err = client.Mkdir(path.Join(testDir, testDir)) assert.NoError(t, 
err) err = writeSFTPFile(innerUploadFilePath, 32, client) assert.NoError(t, err) err = client.Chtimes(innerUploadFilePath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) assert.NoError(t, err) folderRetention := []dataprovider.FolderRetention{ { Path: "/missing", Retention: 24, }, { Path: "/" + testDir, Retention: 24, DeleteEmptyDirs: true, }, { Path: path.Dir(innerUploadFilePath), Retention: 0, }, } check := common.RetentionCheck{ Folders: folderRetention, } c := common.RetentionChecks.Add(check, &user) assert.NotNil(t, c) err = c.Start() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) _, err = client.Stat(uploadPath) assert.ErrorIs(t, err, os.ErrNotExist) _, err = client.Stat(innerUploadFilePath) assert.NoError(t, err) folderRetention = []dataprovider.FolderRetention{ { Path: "/" + testDir, Retention: 24, DeleteEmptyDirs: true, }, } check = common.RetentionCheck{ Folders: folderRetention, } c = common.RetentionChecks.Add(check, &user) assert.NotNil(t, c) err = c.Start() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) _, err = client.Stat(innerUploadFilePath) assert.ErrorIs(t, err, os.ErrNotExist) conn.Close() client.Close() } // finally test some errors removing files or folders if runtime.GOOS != osWindows { dirPath := filepath.Join(user.HomeDir, "adir", "sub") err := os.MkdirAll(dirPath, os.ModePerm) assert.NoError(t, err) filePath := filepath.Join(dirPath, "f.dat") err = os.WriteFile(filePath, nil, os.ModePerm) assert.NoError(t, err) err = os.Chtimes(filePath, time.Now().Add(-72*time.Hour), time.Now().Add(-72*time.Hour)) assert.NoError(t, err) err = os.Chmod(dirPath, 0001) assert.NoError(t, err) folderRetention := []dataprovider.FolderRetention{ { Path: "/adir", Retention: 24, DeleteEmptyDirs: true, }, } check := common.RetentionCheck{ 
Folders: folderRetention,
		}
		c := common.RetentionChecks.Add(check, &user)
		assert.NotNil(t, c)
		err = c.Start()
		assert.ErrorIs(t, err, os.ErrPermission)
		assert.Eventually(t, func() bool {
			return len(common.RetentionChecks.Get("")) == 0
		}, 1000*time.Millisecond, 50*time.Millisecond)

		err = os.Chmod(dirPath, 0555)
		assert.NoError(t, err)

		c = common.RetentionChecks.Add(check, &user)
		assert.NotNil(t, c)
		err = c.Start()
		assert.ErrorIs(t, err, os.ErrPermission)
		assert.Eventually(t, func() bool {
			return len(common.RetentionChecks.Get("")) == 0
		}, 1000*time.Millisecond, 50*time.Millisecond)

		err = os.Chmod(dirPath, os.ModePerm)
		assert.NoError(t, err)

		check = common.RetentionCheck{
			Folders: folderRetention,
		}
		c = common.RetentionChecks.Add(check, &user)
		assert.NotNil(t, c)
		err = c.Start()
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			return len(common.RetentionChecks.Get("")) == 0
		}, 1000*time.Millisecond, 50*time.Millisecond)
		assert.NoDirExists(t, dirPath)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	assert.Eventually(t, func() bool {
		return common.Connections.GetClientConnections() == 0
	}, 1*time.Second, 50*time.Millisecond)
}

// TestPerUserTransferLimits lowers MaxPerHostConnections to 2 and starts three
// concurrent uploads on one SFTP connection, expecting exactly one of them to
// fail. The previous MaxPerHostConnections value is restored at the end.
func TestPerUserTransferLimits(t *testing.T) {
	oldMaxPerHostConns := common.Config.MaxPerHostConnections

	common.Config.MaxPerHostConnections = 2

	u := getTestUser()
	u.UploadBandwidth = 32
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	if !assert.NoError(t, err) {
		printLatestLogs(20)
	}
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		var wg sync.WaitGroup
		// numErrors is updated from several goroutines: guard it with a mutex
		// to avoid a data race (and lost increments).
		var mu sync.Mutex
		numErrors := 0
		for i := 0; i <= 2; i++ {
			wg.Add(1)
			go func(counter int) {
				defer wg.Done()

				time.Sleep(20 * time.Millisecond)
				err := writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client)
				if err != nil {
					mu.Lock()
					numErrors++
					mu.Unlock()
				}
			}(i)
		}

		// wg.Wait establishes happens-before with every goroutine, so the
		// unlocked read below is safe.
		wg.Wait()
		assert.Equal(t, 1, numErrors)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	common.Config.MaxPerHostConnections = oldMaxPerHostConns
}

// TestMaxSessionsSameConnection sets MaxSessions to 2, starts two uploads on
// one connection and then tries to open a third session; exactly one of the
// three operations (the extra login) is expected to fail.
func TestMaxSessionsSameConnection(t *testing.T) {
	u := getTestUser()
	u.UploadBandwidth = 32
	u.MaxSessions = 2
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		var wg sync.WaitGroup
		// numErrors is updated from several goroutines: guard it with a mutex.
		var mu sync.Mutex
		numErrors := 0
		for i := 0; i <= 2; i++ {
			wg.Add(1)
			go func(counter int) {
				defer wg.Done()

				var err error
				if counter < 2 {
					err = writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client)
				} else {
					// wait for the transfers to start
					time.Sleep(50 * time.Millisecond)
					extraConn, extraClient, errConn := getSftpClient(user)
					if errConn == nil {
						// unexpected success: release the extra session instead of leaking it
						extraClient.Close()
						extraConn.Close()
					}
					err = errConn
				}
				if err != nil {
					mu.Lock()
					numErrors++
					mu.Unlock()
				}
			}(i)
		}

		wg.Wait()
		assert.Equal(t, 1, numErrors)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestRenameDir(t *testing.T) {
	u := getTestUser()
	testDir := "/dir-to-rename"
	u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName), 32, client)
		assert.NoError(t, err)
		err = client.Rename(testDir, testDir+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestBuiltinKeyboardInteractiveAuthentication(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	authMethods := []ssh.AuthMethod{
		ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string,
error) { return []string{defaultPassword}, nil }), } conn, client, err := getCustomAuthSftpClient(user, authMethods) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) err = writeSFTPFile(testFileName, 4096, client) assert.NoError(t, err) } // add multi-factor authentication configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) assert.NoError(t, err) user.Password = defaultPassword user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), Protocols: []string{common.ProtocolSSH}, } err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1) assert.NoError(t, err) passwordAsked := false passcodeAsked := false authMethods = []ssh.AuthMethod{ ssh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) { var answers []string if strings.HasPrefix(questions[0], "Password") { answers = append(answers, defaultPassword) passwordAsked = true } else { answers = append(answers, passcode) passcodeAsked = true } return answers, nil }), } conn, client, err = getCustomAuthSftpClient(user, authMethods) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) err = writeSFTPFile(testFileName, 4096, client) assert.NoError(t, err) } assert.True(t, passwordAsked) assert.True(t, passcodeAsked) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestMultiStepBuiltinKeyboardAuth(t *testing.T) { u := getTestUser() u.PublicKeys = []string{testPubKey} u.Filters.DeniedLoginMethods = []string{ dataprovider.SSHLoginMethodPublicKey, dataprovider.LoginMethodPassword, dataprovider.SSHLoginMethodKeyboardInteractive, } user, _, err := httpdtest.AddUser(u, 
http.StatusCreated) assert.NoError(t, err) signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) assert.NoError(t, err) // public key + password authMethods := []ssh.AuthMethod{ ssh.PublicKeys(signer), ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) { return []string{defaultPassword}, nil }), } conn, client, err := getCustomAuthSftpClient(user, authMethods) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) err = writeSFTPFile(testFileName, 4096, client) assert.NoError(t, err) } // add multi-factor authentication configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) assert.NoError(t, err) user.Password = defaultPassword user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), Protocols: []string{common.ProtocolSSH}, } err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1) assert.NoError(t, err) // public key + passcode authMethods = []ssh.AuthMethod{ ssh.PublicKeys(signer), ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) { return []string{passcode}, nil }), } conn, client, err = getCustomAuthSftpClient(user, authMethods) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) err = writeSFTPFile(testFileName, 4096, client) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestRenameSymlink(t *testing.T) { u := getTestUser() testDir := "/dir-no-create-links" otherDir := "otherdir" u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermCreateDirs} user, _, err := httpdtest.AddUser(u, 
http.StatusCreated)
	// Tail of a test whose beginning lies before this line range.
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(otherDir)
		assert.NoError(t, err)
		err = client.Symlink(otherDir, otherDir+".link")
		assert.NoError(t, err)
		// renaming the symlink into testDir is denied, renaming it elsewhere is allowed
		err = client.Rename(otherDir+".link", path.Join(testDir, "symlink"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(otherDir+".link", "allowed_link")
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSplittedDeletePerms verifies that file and directory deletion are
// governed by separate permissions: with PermDeleteDirs (and no
// PermDeleteFiles) removing a file fails while removing a directory succeeds;
// with PermDeleteFiles (and no PermDeleteDirs) the opposite holds.
func TestSplittedDeletePerms(t *testing.T) {
	u := getTestUser()
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDeleteDirs,
		dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.Error(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.NoError(t, err)
	}
	// Swap the delete permissions. PermOverwrite is granted because the file
	// written above could not be removed, so the next write overwrites it.
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDeleteFiles,
		dataprovider.PermCreateDirs, dataprovider.PermOverwrite}
	_, _, err = httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSplittedRenamePerms verifies that file and directory renames are
// governed by separate permissions: with PermRenameDirs only a directory can
// be renamed but a file cannot; with PermRenameFiles only the opposite holds.
func TestSplittedRenamePerms(t *testing.T) {
	u := getTestUser()
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermRenameDirs,
		dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Rename(testFileName, testFileName+"_renamed")
		assert.Error(t, err)
		err = client.Rename(testDir, testDir+"_renamed")
		assert.NoError(t, err)
	}
	// Swap the rename permissions. testDir was renamed away above, so Mkdir can
	// recreate it; the file was not, so PermOverwrite is needed to rewrite it.
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermRenameFiles,
		dataprovider.PermCreateDirs, dataprovider.PermOverwrite}
	_, _, err = httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Rename(testFileName, testFileName+"_renamed")
		assert.NoError(t, err)
		err = client.Rename(testDir, testDir+"_renamed")
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSFTPLoopError builds two users whose SFTP filesystems point at each
// other. It checks that an event rule triggered by a provider "update" runs its
// failure action (an email) when the first action fails. It also checks that
// resolving the looping virtual folder fails for WebDAV (os.ErrPermission),
// SFTP and FTP connections.
func TestSFTPLoopError(t *testing.T) {
	// Point notifications at the test SMTP server; reset at the end of the test.
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	user1 := getTestUser()
	user2 := getTestUser()
	user1.Username += "1"
	user2.Username += "2"
	// user1 is a local account with a virtual SFTP folder to user2
	// user2 has user1 as SFTP fs
	f := vfs.BaseVirtualFolder{
		Name: "sftp",
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: user2.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	folder, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user1.VirtualFolders = append(user1.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folder.Name,
		},
		VirtualPath: "/vdir",
	})
	user2.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username: user1.Username,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}
	user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeUserQuotaReset,
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	// NOTE(review): the Subject raw literal below ends with a stray `"`; the
	// assertion on the received email only checks the "Failed action" prefix.
	a2 := dataprovider.BaseEventAction{
		Name: "a2",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"failure@example.com"},
				Subject:    `Failed action"`,
				Body:       "Test body",
			},
		},
	}
	action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err)
	// rule1 runs the quota reset first; the email action is marked as a
	// failure action, so an email is expected only if action1 fails.
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerProviderEvent,
		Conditions: dataprovider.EventConditions{
			ProviderEvents: []string{"update"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 2,
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	lastReceivedEmail.reset()
	// updating user2 is the provider event that fires rule1
	_, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return lastReceivedEmail.get().From != ""
	}, 3000*time.Millisecond, 100*time.Millisecond)
	email := lastReceivedEmail.get()
	assert.Len(t, email.To, 1)
	assert.True(t, slices.Contains(email.To, "failure@example.com"))
	assert.Contains(t, email.Data, `Subject: Failed action`)
	// Re-set plain passwords on the local copies before building connections;
	// presumably the API returns them in a non-plain form — TODO confirm.
	user1.VirtualFolders[0].FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	user2.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	conn := common.NewBaseConnection("", common.ProtocolWebDAV, "", "", user1)
	_, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath)
	assert.ErrorIs(t, err, os.ErrPermission)
	conn = common.NewBaseConnection("", common.ProtocolSFTP, "", "", user1)
	_, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath)
	assert.Error(t, err)
	conn = common.NewBaseConnection("", common.ProtocolFTP, "", "", user1)
	_, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath)
	assert.Error(t, err)

	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user2.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	// restore an empty SMTP configuration
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestNonLocalCrossRename checks that renames crossing filesystem boundaries
// (user root, an SFTP virtual folder, a crypt virtual folder) are denied, while
// renames within the same filesystem succeed. The test continues past this
// line range.
func TestNonLocalCrossRename(t *testing.T) {
	baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestUser()
	u.HomeDir += "_folders"
	u.Username += "_folders"
	mappedPathSFTP := filepath.Join(os.TempDir(), "sftp")
folderNameSFTP := filepath.Base(mappedPathSFTP)
	// Continuation of TestNonLocalCrossRename, which begins before this range.
	vdirSFTPPath := "/vdir/sftp"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameSFTP,
		},
		VirtualPath: vdirSFTPPath,
	})
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
	})
	// the SFTP folder is backed by baseUser on the test SFTP server
	f1 := vfs.BaseVirtualFolder{
		Name: folderNameSFTP,
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: baseUser.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirSFTPPath, testFileName), 8192, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), 16384, client)
		assert.NoError(t, err)
		// every rename that crosses a filesystem boundary is denied
		err = client.Rename(path.Join(vdirSFTPPath, testFileName), path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirSFTPPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirSFTPPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirSFTPPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		// rename on local fs or on the same folder must work
		err = client.Rename(testFileName, testFileName+".rename")
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirSFTPPath, testFileName), path.Join(vdirSFTPPath, testFileName+"_rename"))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirCryptPath, testFileName+"_rename"))
		assert.NoError(t, err)
		// renaming a virtual folder is not allowed
		err = client.Rename(vdirSFTPPath, vdirSFTPPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(vdirCryptPath, vdirCryptPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(vdirCryptPath, path.Join(vdirCryptPath, "rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming a subdir onto the virtual folder path itself is denied too
		err = client.Mkdir(path.Join(vdirCryptPath, "subcryptdir"))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirCryptPath, "subcryptdir"), vdirCryptPath)
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming root folder is not allowed
		err = client.Rename("/", "new_name")
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming a path that contains virtual folders is not allowed
		err = client.Rename("/vdir", "new_vdir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathSFTP)
	assert.NoError(t, err)
}

// TestNonLocalCrossRenameNonLocalBaseUser mirrors TestNonLocalCrossRename for a
// user whose root filesystem is SFTP (getTestSFTPUser), with a local and a
// crypt virtual folder: cross-filesystem renames are denied, same-filesystem
// renames succeed.
func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) {
	baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestSFTPUser()
	mappedPathLocal := filepath.Join(os.TempDir(), "local")
	folderNameLocal := filepath.Base(mappedPathLocal)
	vdirLocalPath := "/vdir/local"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameLocal,
		},
		VirtualPath: vdirLocalPath,
	})
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderNameLocal,
		MappedPath: mappedPathLocal,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err =
			writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirLocalPath, testFileName), 8192, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), 16384, client)
		assert.NoError(t, err)
		// every rename that crosses a filesystem boundary is denied
		err = client.Rename(path.Join(vdirLocalPath, testFileName), path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirLocalPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirLocalPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirLocalPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		// rename within the user's root fs (SFTP here) or within the same folder must work
		err = client.Rename(testFileName, testFileName+".rename")
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirLocalPath, testFileName), path.Join(vdirLocalPath, testFileName+"_rename"))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirCryptPath, testFileName+"_rename"))
		assert.NoError(t, err)
		// renaming a virtual folder is not allowed
		err = client.Rename(vdirLocalPath, vdirLocalPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(vdirCryptPath, vdirCryptPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming root folder is not allowed
		err = client.Rename("/", "new_name")
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming a path that contains virtual folders is not allowed
		err = client.Rename("/vdir", "new_vdir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(),
				"SSH_FX_OP_UNSUPPORTED")
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameLocal}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathLocal)
	assert.NoError(t, err)
}

// TestCopyAndRemoveSSHCommands exercises the sftpgo-copy and sftpgo-remove SSH
// commands on files and directory trees. It tracks quota usage after each
// step and checks that quota limits are enforced. The test continues past this
// line range.
func TestCopyAndRemoveSSHCommands(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1000
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		fileSize := int64(32)
		err = writeSFTPFile(testFileName, fileSize, client)
		assert.NoError(t, err)
		testFileNameCopy := testFileName + "_copy"
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.NoError(t, err, string(out))
		// the resolved destination path matches the source path, so the copy must fail
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, path.Dir(testFileName)), user)
		assert.Error(t, err, string(out))
		info, err := client.Stat(testFileNameCopy)
		if assert.NoError(t, err) {
			assert.Equal(t, fileSize, info.Size())
		}
		// the directory name contains a space, so it is quoted in the command
		testDir := "test dir"
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s '%s'`, testFileName, testDir), user)
		assert.NoError(t, err, string(out))
		info, err = client.Stat(path.Join(testDir, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, fileSize, info.Size())
		}
		// original file + two copies
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3*fileSize, user.UsedQuotaSize)
assert.Equal(t, 3, user.UsedQuotaFiles)
		// Continuation of TestCopyAndRemoveSSHCommands, which begins before this range.
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %s", testFileNameCopy), user)
		assert.NoError(t, err, string(out))
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove '%s'`, testDir), user)
		assert.NoError(t, err, string(out))
		// only the original file remains
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, fileSize, user.UsedQuotaSize)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		_, err = client.Stat(testFileNameCopy)
		assert.ErrorIs(t, err, os.ErrNotExist)
		// create a dir tree
		dir1 := "dir1"
		dir2 := "dir 2"
		err = client.MkdirAll(path.Join(dir1, dir2))
		assert.NoError(t, err)
		toCreate := []string{
			path.Join(dir1, testFileName),
			path.Join(dir1, dir2, testFileName),
		}
		for _, p := range toCreate {
			err = writeSFTPFile(p, fileSize, client)
			assert.NoError(t, err)
		}
		// create a symlink, copying a symlink is not supported
		err = client.Symlink(path.Join("/", dir1, testFileName), path.Join("/", dir1, testFileName+"_link"))
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1, testFileName+"_link"),
			path.Join("/", testFileName+"_link")), user)
		assert.Error(t, err, string(out))
		// copying a dir inside itself should fail
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1), path.Join("/", dir1, "sub")), user)
		assert.Error(t, err, string(out))
		// copy source and dest must differ
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1), path.Join("/", dir1)), user)
		assert.Error(t, err, string(out))
		// copy a missing file/dir should fail
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", "missing_entry"), path.Join("/", dir1)), user)
		assert.Error(t, err, string(out))
		// try to overwrite a file with a dir
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1), testFileName), user)
		assert.Error(t, err, string(out))
		// copy dir1 (2 files) to the root-level "dir 2": 1 + 2 + 2 files
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s "%s"`,
			dir1, dir2), user)
		assert.NoError(t, err, string(out))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 5*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 5, user.UsedQuotaFiles)
		// copy again, quota must remain unchanged
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s/ "%s"`, dir1, dir2), user)
		assert.NoError(t, err, string(out))
		_, err = client.Stat(dir2)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 5*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 5, user.UsedQuotaFiles)
		// now copy inside target: two more files are accounted
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s "%s"`, dir1, dir2), user)
		assert.NoError(t, err, string(out))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 7*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 7, user.UsedQuotaFiles)
		for _, p := range []string{dir1, dir2} {
			out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove "%s"`, p), user)
			assert.NoError(t, err, string(out))
			_, err = client.Stat(p)
			assert.ErrorIs(t, err, os.ErrNotExist)
		}
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, fileSize, user.UsedQuotaSize)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		// test quota errors
		user.QuotaFiles = 1
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// quota files exceeded
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.Error(t, err, string(out))
		user.QuotaFiles = 1000
		user.QuotaSize = fileSize + 1
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// quota size exceeded after the copy
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.Error(t, err, string(out))
		user.QuotaSize = fileSize - 1
		_, _, err = httpdtest.UpdateUser(user,
			http.StatusOK, "")
		assert.NoError(t, err)
		// quota size exceeded
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.Error(t, err, string(out))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestCopyAndRemovePermissions checks that sftpgo-copy and sftpgo-remove honor
// per-path permissions and denied file patterns. It first grants no
// permissions on restrictedPath, then list/upload, then list/upload/copy.
func TestCopyAndRemovePermissions(t *testing.T) {
	u := getTestUser()
	restrictedPath := "/dir/path"
	patternFilterPath := "/patterns"
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           patternFilterPath,
			DeniedPatterns: []string{"*.dat"},
		},
	}
	u.Permissions[restrictedPath] = []string{}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.MkdirAll(restrictedPath)
		assert.NoError(t, err)
		err = client.MkdirAll(patternFilterPath)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		// getting file writer will fail
		out, err := runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.Error(t, err, string(out))
		// file pattern not allowed
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, patternFilterPath), user)
		assert.Error(t, err, string(out))
		testDir := path.Join("/", path.Base(restrictedPath))
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName), 100, client)
		assert.NoError(t, err)
		// creating target dir will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s/`, testDir, restrictedPath), user)
		assert.Error(t, err, string(out))
		// get dir contents will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s /`, restrictedPath), user)
		assert.Error(t, err, string(out))
		// get dir contents will fail (remove)
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove %s`, restrictedPath), user)
		assert.Error(t, err, string(out))
		// give list dir permissions and retry, now delete will fail
		user.Permissions[restrictedPath] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
		user.Permissions[testDir] = []string{dataprovider.PermListItems}
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// no copy permission
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.Error(t, err, string(out))
		user.Permissions[restrictedPath] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermCopy}
		user.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermCopy}
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.NoError(t, err, string(out))
		// overwrite will fail, no permission
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.Error(t, err, string(out))
		// no delete permission on restrictedPath
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove %s`, restrictedPath), user)
		assert.Error(t, err, string(out))
		// try to copy a file from testDir: it grants only list and copy (no download),
		// so getFileReader will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, path.Join(testDir, testFileName), testFileName+".copy"), user)
		assert.Error(t, err, string(out))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestCrossFoldersCopy copies a directory containing four virtual folders
// (local, nested local, crypt, SFTP) with sftpgo-copy. It checks every file
// lands in the copy with the right size, that such a directory cannot be
// removed, and that copying between overlapping (nested) mapped paths fails.
// The test continues past this line range.
func TestCrossFoldersCopy(t *testing.T) {
	baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestUser()
	u.Username += "_1"
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	u.QuotaFiles = 1000
	mappedPath1 := filepath.Join(os.TempDir(), "mapped1")
	folderName1 := filepath.Base(mappedPath1)
	vpath1 := "/vdirs/vdir1"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vpath1,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	// mappedPath2 lives inside mappedPath1: the two folders overlap on disk
	mappedPath2 := filepath.Join(os.TempDir(), "mapped1", "dir", "mapped2")
	folderName2 := filepath.Base(mappedPath2)
	vpath2 := "/vdirs/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vpath2,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	mappedPath3 := filepath.Join(os.TempDir(), "mapped3")
	folderName3 := filepath.Base(mappedPath3)
	vpath3 := "/vdirs/vdir3"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName3,
		},
		VirtualPath: vpath3,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	mappedPath4 := filepath.Join(os.TempDir(), "mapped4")
	folderName4 := filepath.Base(mappedPath4)
	vpath4 := "/vdirs/vdir4"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName4,
		},
		VirtualPath: vpath4,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	f3 := vfs.BaseVirtualFolder{
		Name:       folderName3,
		MappedPath: mappedPath3,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f3, http.StatusCreated)
	assert.NoError(t, err)
	f4 := vfs.BaseVirtualFolder{
		Name:       folderName4,
		MappedPath: mappedPath4,
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: baseUser.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f4,
		http.StatusCreated)
	assert.NoError(t, err)
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// distinct sizes so each copied file can be traced back to its folder
		baseFileSize := int64(100)
		err = writeSFTPFile(path.Join(vpath1, testFileName), baseFileSize+1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vpath2, testFileName), baseFileSize+2, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vpath3, testFileName), baseFileSize+3, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vpath4, testFileName), baseFileSize+4, client)
		assert.NoError(t, err)
		// cannot remove a directory with virtual folders inside
		out, err := runSSHCommand(fmt.Sprintf(`sftpgo-remove %s`, path.Dir(vpath1)), user)
		assert.Error(t, err, string(out))
		// copy across virtual folders
		copyDir := "/copy"
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s/`, path.Dir(vpath1), copyDir), user)
		assert.NoError(t, err, string(out))
		// check the copy
		info, err := client.Stat(path.Join(copyDir, vpath1, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, baseFileSize+1, info.Size())
		}
		info, err = client.Stat(path.Join(copyDir, vpath2, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, baseFileSize+2, info.Size())
		}
		info, err = client.Stat(path.Join(copyDir, vpath3, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, baseFileSize+3, info.Size())
		}
		info, err = client.Stat(path.Join(copyDir, vpath4, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, baseFileSize+4, info.Size())
		}
		// nested fs paths
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, vpath1, vpath2), user)
		assert.Error(t, err, string(out))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err =
os.RemoveAll(baseUser.GetHomeDir())
	// Tail of TestCrossFoldersCopy, which begins before this line range.
	assert.NoError(t, err)
	for _, folderName := range []string{folderName1, folderName2, folderName3, folderName4} {
		_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
		assert.NoError(t, err)
		err = os.RemoveAll(filepath.Join(os.TempDir(), folderName))
		assert.NoError(t, err)
	}
}

// TestHTTPFs creates a directory through an HTTPFs-backed FTP connection,
// writes a file directly under os.TempDir()/httpfs/<httpfs user>/... (presumably
// the storage served by the test HTTPFs server — TODO confirm) and then copies
// the directory through the connection.
func TestHTTPFs(t *testing.T) {
	u := getTestUserWithHTTPFs()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	err = os.MkdirAll(user.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	conn := common.NewBaseConnection(xid.New().String(), common.ProtocolFTP, "", "", user)
	err = conn.CreateDir(httpFsWellKnowDir, false)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(os.TempDir(), "httpfs", defaultHTTPFsUsername, httpFsWellKnowDir, "file.txt"),
		[]byte("data"), 0666)
	assert.NoError(t, err)
	err = conn.Copy(httpFsWellKnowDir, httpFsWellKnowDir+"_copy")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestProxyProtocol expects a plain HTTP GET to httpProxyAddr to fail.
func TestProxyProtocol(t *testing.T) {
	resp, err := httpclient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
	if !assert.Error(t, err) {
		resp.Body.Close()
	}
}

// TestSetProtocol checks that SetProtocol changes the connection ID prefix.
func TestSetProtocol(t *testing.T) {
	conn := common.NewBaseConnection("id", "sshd_exec", "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}})
	conn.SetProtocol(common.ProtocolSCP)
	require.Equal(t, "SCP_id", conn.GetID())
}

// TestGetFsError checks that resolving a path fails when the user's GCS
// filesystem has invalid credentials.
func TestGetFsError(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "test"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
	conn := common.NewBaseConnection("", common.ProtocolFTP, "", "", u)
	_, _, err := conn.GetFsAndResolvedPath("/vpath")
	assert.Error(t, err)
}

// waitTCPListening blocks until a TCP connection to address succeeds, retrying
// every 100ms. It has no timeout.
func waitTCPListening(address string) {
	for {
		conn, err := net.Dial("tcp", address)
		if err != nil {
			logger.WarnToConsole("tcp server %v not listening: %v", address, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		logger.InfoToConsole("tcp server %v now listening", address)
		conn.Close()
		return
	}
}

// checkBasicSFTP performs a Getwd and a ReadDir(".") and returns the first error.
func checkBasicSFTP(client *sftp.Client) error {
	_, err := client.Getwd()
	if err != nil {
		return err
	}
	_, err = client.ReadDir(".")
	return err
}

// getCustomAuthSftpClient dials the test SFTP server as user with the given
// auth methods and opens an SFTP client on the connection. If creating the
// SFTP client fails, the SSH connection is closed; the (possibly closed)
// connection is still returned together with the error.
func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMethod) (*ssh.Client, *sftp.Client, error) {
	var sftpClient *sftp.Client
	config := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth:            authMethods,
		Timeout:         5 * time.Second,
	}
	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
	if err != nil {
		return conn, sftpClient, err
	}
	sftpClient, err = sftp.NewClient(conn)
	if err != nil {
		conn.Close()
	}
	return conn, sftpClient, err
}

// getSftpClient is getCustomAuthSftpClient with password authentication,
// using user.Password or defaultPassword when it is empty.
func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) {
	password := defaultPassword
	if user.Password != "" {
		password = user.Password
	}
	return getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.Password(password)})
}

// runSSHCommand runs command in a new SSH session on the test SFTP server,
// authenticating as user with user.Password (or defaultPassword when empty).
// It returns the command's stdout. On failure the output is nil, and the error
// wraps the underlying error and includes the remote stderr as text.
func runSSHCommand(command string, user dataprovider.User) ([]byte, error) {
	password := defaultPassword
	if user.Password != "" {
		password = user.Password
	}
	config := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth:            []ssh.AuthMethod{ssh.Password(password)},
		Timeout:         5 * time.Second,
	}
	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	sshSession, err := conn.NewSession()
	if err != nil {
		return nil, err
	}
	defer sshSession.Close()

	var stdout, stderr bytes.Buffer
	sshSession.Stdout = &stdout
	sshSession.Stderr = &stderr
	if err = sshSession.Run(command); err != nil {
		// Format stderr as a string: %v on a []byte prints the raw byte values.
		return nil, fmt.Errorf("failed to run command %v: %w, stderr: %s", command, err, stderr.String())
	}
	return stdout.Bytes(), nil
}

// getWebDavClient returns a WebDAV client for the local test server
// authenticated as user (user.Password, or defaultPassword when empty).
func getWebDavClient(user dataprovider.User) *gowebdav.Client {
	rootPath := fmt.Sprintf("http://localhost:%d/", webDavServerPort)
	pwd := defaultPassword
	if user.Password != "" {
		pwd = user.Password
	}
	client := gowebdav.NewClient(rootPath, user.Username, pwd)
	client.SetTimeout(10 * time.Second)
	return client
}

// getTestUser returns an active local user with all permissions on "/".
func getTestUser() dataprovider.User {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:       defaultUsername,
			Password:       defaultPassword,
			HomeDir:        filepath.Join(homeBasePath, defaultUsername),
			Status:         1,
			ExpirationDate: 0,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = allPerms
	return user
}

// getTestSFTPUser returns a user whose filesystem is the test SFTP server,
// logged in as defaultUsername.
func getTestSFTPUser() dataprovider.User {
	u := getTestUser()
	u.Username = defaultSFTPUsername
	u.FsConfig.Provider = sdk.SFTPFilesystemProvider
	u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
	u.FsConfig.SFTPConfig.Username = defaultUsername
	u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	return u
}

// getCryptFsUser returns a test user backed by the encrypted local filesystem.
func getCryptFsUser() dataprovider.User {
	u := getTestUser()
	u.Username += "_crypt"
	u.FsConfig.Provider = sdk.CryptedFilesystemProvider
	u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword)
	return u
}

// getTestUserWithHTTPFs returns a test user backed by the local HTTPFs server.
func getTestUserWithHTTPFs() dataprovider.User {
	u := getTestUser()
	u.FsConfig.Provider = sdk.HTTPFilesystemProvider
	u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort),
			Username: defaultHTTPFsUsername,
		},
	}
	return u
}

// writeSFTPFile uploads size random bytes to name and verifies the remote size.
func writeSFTPFile(name string, size int64, client *sftp.Client) error {
	err := writeSFTPFileNoCheck(name, size, client)
	if err != nil {
		return err
	}
	info, err := client.Stat(name)
	if err != nil {
		return err
	}
	if info.Size() != size {
		return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size())
	}
	return nil
}

// writeSFTPFileNoCheck uploads size random bytes to name without verifying the
// result. The file is closed on every path, and the Close error is returned.
func writeSFTPFileNoCheck(name string, size int64, client *sftp.Client) error {
	content := make([]byte, size)
	_, err := rand.Read(content)
	if err != nil {
		return err
	}
	f, err := client.Create(name)
	if err != nil {
		return err
	}
	_, err = io.Copy(f, bytes.NewBuffer(content))
	if err != nil {
		f.Close()
		return err
	}
	return f.Close()
}

// getUploadScriptEnvContent returns a shell script that exits 1 if the
// environment variable envVar is empty or unset, and 0 otherwise.
func getUploadScriptEnvContent(envVar string) []byte {
	content := []byte("#!/bin/sh\n\n")
	content = fmt.Appendf(content, "if [ -z \"$%s\" ]\n", envVar)
	content = append(content, "then\n"...)
	content = append(content, " exit 1\n"...)
	content = append(content, "else\n"...)
	content = append(content, " exit 0\n"...)
	return append(content, "fi\n"...)
}

// getUploadScriptContent returns a shell script that sleeps 1s, optionally
// writes its arguments to logFilePath, moves ${SFTPGO_ACTION_PATH} to
// movedPath and exits with exitStatus.
func getUploadScriptContent(movedPath, logFilePath string, exitStatus int) []byte {
	content := []byte("#!/bin/sh\n\n")
	content = append(content, "sleep 1\n"...)
	if logFilePath != "" {
		content = fmt.Appendf(content, "echo $@ > %v\n", logFilePath)
	}
	content = fmt.Appendf(content, "mv ${SFTPGO_ACTION_PATH} %v\n", movedPath)
	return fmt.Appendf(content, "exit %d", exitStatus)
}

// getSaveProviderObjectScriptContent returns a shell script that writes
// ${SFTPGO_OBJECT_DATA} to outFilePath and exits with exitStatus.
func getSaveProviderObjectScriptContent(outFilePath string, exitStatus int) []byte {
	content := []byte("#!/bin/sh\n\n")
	content = fmt.Appendf(content, "echo ${SFTPGO_OBJECT_DATA} > %v\n", outFilePath)
	return fmt.Appendf(content, "exit %d", exitStatus)
}

// generateTOTPPasscode returns the current 6-digit TOTP code for secret
// (30s period, skew 1) using the given hash algorithm.
func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) {
	return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: algo,
	})
}

// isDbDefenderSupported reports whether the configured provider driver is one
// of the SQL-based providers that support the database-backed defender.
func isDbDefenderSupported() bool {
	// SQLite shares the implementation with other SQL-based provider but it makes no sense
	// to use it outside test cases
	switch dataprovider.GetProviderStatus().Driver {
	case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
		dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
		return true
	default:
		return false
	}
}

// getEncryptedFileSize returns the on-disk size of a crypt-fs file with size
// plaintext bytes: the sio encrypted size plus 33 bytes (presumably the header
// the crypt fs writes before the payload — TODO confirm).
func getEncryptedFileSize(size int64) (int64, error) {
	encSize, err := sio.EncryptedSize(uint64(size))
	return int64(encSize) + 33, err
}

// printLatestLogs prints at most the last maxNumberOfLines lines of
// logFilePath to the console. It is best-effort: an unopenable file is
// silently ignored.
func printLatestLogs(maxNumberOfLines int) {
	var lines []string
	f, err := os.Open(logFilePath)
	if err != nil {
		return
	}
	defer f.Close()
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		lines = append(lines, scanner.Text()+"\r\n")
		for len(lines) > maxNumberOfLines {
			lines = lines[1:]
		}
	}
	if scanner.Err() != nil {
		logger.WarnToConsole("Unable to print latest logs: %v", scanner.Err())
		return
	}
	for _, line := range lines {
		logger.DebugToConsole("%s", line)
	}
}

// receivedEmail holds the last email captured by the test SMTP server,
// guarded by the embedded RWMutex.
type receivedEmail struct {
	sync.RWMutex
	From string
	To   []string
	Data string
}

// set stores a received message; "=\r\n" sequences (presumably
// quoted-printable soft line breaks) are removed from the body.
func (e *receivedEmail) set(from string, to []string, data []byte) {
	e.Lock()
	defer e.Unlock()

	e.From = from
	e.To = to
	e.Data = strings.ReplaceAll(string(data), "=\r\n", "")
}

// reset clears the stored message.
func (e *receivedEmail) reset() {
	e.Lock()
	defer e.Unlock()

	e.From = ""
	e.To = nil
	e.Data = ""
}

// get returns a snapshot of the stored message with a fresh (unlocked) mutex.
// The To slice is shared with the stored value.
func (e *receivedEmail) get() receivedEmail {
	e.RLock()
	defer e.RUnlock()

	return receivedEmail{
		From: e.From,
		To:   e.To,
		Data: e.Data,
	}
}

// startHTTPFs starts the test HTTPFs server; it continues past this line range.
func startHTTPFs() {
	go func() {
		readdirCallback := func(name string) []os.FileInfo {
			if name == httpFsWellKnowDir {
				return []os.FileInfo{vfs.NewFileInfo("ghost.txt", false, 0, time.Unix(0, 0), false)}
			}
			return nil
		}

		callbacks := &httpdtest.HTTPFsCallbacks{
			Readdir:
readdirCallback, } if err := httpdtest.StartTestHTTPFs(httpFsPort, callbacks); err != nil { logger.ErrorToConsole("could not start HTTPfs test server: %v", err) os.Exit(1) } }() waitTCPListening(fmt.Sprintf(":%d", httpFsPort)) } ================================================ FILE: internal/common/ratelimiter.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "errors" "fmt" "slices" "sort" "sync" "sync/atomic" "time" "golang.org/x/time/rate" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( errNoBucket = errors.New("no bucket found") errReserve = errors.New("unable to reserve token") rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP} ) // RateLimiterType defines the supported rate limiters types type RateLimiterType int // Supported rate limiter types const ( rateLimiterTypeGlobal RateLimiterType = iota + 1 rateLimiterTypeSource ) // RateLimiterConfig defines the configuration for a rate limiter type RateLimiterConfig struct { // Average defines the maximum rate allowed. 0 means disabled Average int64 `json:"average" mapstructure:"average"` // Period defines the period as milliseconds. Default: 1000 (1 second). // The rate is actually defined by dividing average by period. // So for a rate below 1 req/s, one needs to define a period larger than a second. 
Period int64 `json:"period" mapstructure:"period"` // Burst is the maximum number of requests allowed to go through in the // same arbitrarily small period of time. Default: 1. Burst int `json:"burst" mapstructure:"burst"` // Type defines the rate limiter type: // - rateLimiterTypeGlobal is a global rate limiter independent from the source // - rateLimiterTypeSource is a per-source rate limiter Type int `json:"type" mapstructure:"type"` // Protocols defines the protocols for this rate limiter. // Available protocols are: "SFTP", "FTP", "DAV". // A rate limiter with no protocols defined is disabled Protocols []string `json:"protocols" mapstructure:"protocols"` // If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter, // a new defender event will be generated GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"` // The number of per-ip rate limiters kept in memory will vary between the // soft and hard limit EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"` EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"` } func (r *RateLimiterConfig) isEnabled() bool { return r.Average > 0 && len(r.Protocols) > 0 } func (r *RateLimiterConfig) validate() error { if r.Burst < 1 { return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst) } if r.Period < 100 { return fmt.Errorf("invalid period %v. 
It must be >= 100", r.Period) } if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) { return fmt.Errorf("invalid type %v", r.Type) } if r.Type != int(rateLimiterTypeGlobal) { if r.EntriesSoftLimit <= 0 { return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit) } if r.EntriesHardLimit <= r.EntriesSoftLimit { return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit) } } r.Protocols = util.RemoveDuplicates(r.Protocols, true) for _, protocol := range r.Protocols { if !slices.Contains(rateLimiterProtocolValues, protocol) { return fmt.Errorf("invalid protocol %q", protocol) } } return nil } func (r *RateLimiterConfig) getLimiter() *rateLimiter { limiter := &rateLimiter{ burst: r.Burst, globalBucket: nil, generateDefenderEvents: r.GenerateDefenderEvents, } var maxDelay time.Duration period := time.Duration(r.Period) * time.Millisecond rtl := float64(r.Average*int64(time.Second)) / float64(period) limiter.rate = rate.Limit(rtl) if rtl < 1 { maxDelay = period / 2 } else { maxDelay = time.Second / (time.Duration(rtl) * 2) } if maxDelay > 10*time.Second { maxDelay = 10 * time.Second } limiter.maxDelay = maxDelay limiter.buckets = sourceBuckets{ buckets: make(map[string]sourceRateLimiter), hardLimit: r.EntriesHardLimit, softLimit: r.EntriesSoftLimit, } if r.Type != int(rateLimiterTypeSource) { limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst) } return limiter } // RateLimiter defines a rate limiter type rateLimiter struct { rate rate.Limit burst int maxDelay time.Duration globalBucket *rate.Limiter buckets sourceBuckets generateDefenderEvents bool } // Wait blocks until the limit allows one event to happen // or returns an error if the time to wait exceeds the max // allowed delay func (rl *rateLimiter) Wait(source, protocol string) (time.Duration, error) { var res *rate.Reservation if rl.globalBucket != nil { res = rl.globalBucket.Reserve() } else { var err error res, err 
= rl.buckets.reserve(source) if err != nil { rateLimiter := rate.NewLimiter(rl.rate, rl.burst) res = rl.buckets.addAndReserve(rateLimiter, source) } } if !res.OK() { return 0, errReserve } delay := res.Delay() if delay > rl.maxDelay { res.Cancel() if rl.generateDefenderEvents && rl.globalBucket == nil { AddDefenderEvent(source, protocol, HostEventLimitExceeded) } return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay) } time.Sleep(delay) return 0, nil } type sourceRateLimiter struct { lastActivity *atomic.Int64 bucket *rate.Limiter } func (s *sourceRateLimiter) updateLastActivity() { s.lastActivity.Store(time.Now().UnixNano()) } func (s *sourceRateLimiter) getLastActivity() int64 { return s.lastActivity.Load() } type sourceBuckets struct { sync.RWMutex buckets map[string]sourceRateLimiter hardLimit int softLimit int } func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) { b.RLock() defer b.RUnlock() if src, ok := b.buckets[source]; ok { src.updateLastActivity() return src.bucket.Reserve(), nil } return nil, errNoBucket } func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation { b.Lock() defer b.Unlock() b.cleanup() src := sourceRateLimiter{ lastActivity: new(atomic.Int64), bucket: r, } src.updateLastActivity() b.buckets[source] = src return src.bucket.Reserve() } func (b *sourceBuckets) cleanup() { if len(b.buckets) >= b.hardLimit { numToRemove := len(b.buckets) - b.softLimit kvList := make(kvList, 0, len(b.buckets)) for k, v := range b.buckets { kvList = append(kvList, kv{ Key: k, Value: v.getLastActivity(), }) } sort.Sort(kvList) for idx, kv := range kvList { if idx >= numToRemove { break } delete(b.buckets, kv.Key) } } } ================================================ FILE: internal/common/ratelimiter_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can 
redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package common

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestRateLimiterConfig walks RateLimiterConfig through each validate()
// rejection (burst, period, type, soft/hard limits, protocols) one field at a
// time, then checks the maxDelay and global bucket produced by getLimiter.
func TestRateLimiterConfig(t *testing.T) {
	config := RateLimiterConfig{}
	err := config.validate()
	require.Error(t, err)
	config.Burst = 1
	config.Period = 10
	err = config.validate()
	require.Error(t, err)
	config.Period = 1000
	config.Type = 100
	err = config.validate()
	require.Error(t, err)
	config.Type = int(rateLimiterTypeSource)
	config.EntriesSoftLimit = 0
	err = config.validate()
	require.Error(t, err)
	config.EntriesSoftLimit = 150
	config.EntriesHardLimit = 0
	err = config.validate()
	require.Error(t, err)
	config.EntriesHardLimit = 200
	config.Protocols = []string{"unsupported protocol"}
	err = config.validate()
	require.Error(t, err)
	config.Protocols = rateLimiterProtocolValues
	err = config.validate()
	require.NoError(t, err)
	// Average is 0 here, so the rate is < 1 and maxDelay is half the period.
	limiter := config.getLimiter()
	require.Equal(t, 500*time.Millisecond, limiter.maxDelay)
	require.Nil(t, limiter.globalBucket)
	config.Type = int(rateLimiterTypeGlobal)
	config.Average = 1
	config.Period = 10000
	limiter = config.getLimiter()
	require.Equal(t, 5*time.Second, limiter.maxDelay)
	require.NotNil(t, limiter.globalBucket)
	// Half of 100s exceeds the 10s cap.
	config.Period = 100000
	limiter = config.getLimiter()
	require.Equal(t, 10*time.Second, limiter.maxDelay)
	// 2 events/s: maxDelay is half of one 500ms interval.
	config.Period = 500
	config.Average = 1
	limiter = config.getLimiter()
	require.Equal(t, 250*time.Millisecond, limiter.maxDelay)
}

// TestRateLimiter checks that with burst 1 the second immediate event is
// rejected (global and per-source), that sources are limited independently,
// and that a zero burst makes every reservation fail with errReserve.
func TestRateLimiter(t *testing.T) {
	config := RateLimiterConfig{
		Average:   1,
		Period:    1000,
		Burst:     1,
		Type:      int(rateLimiterTypeGlobal),
		Protocols: rateLimiterProtocolValues,
	}
	limiter := config.getLimiter()
	_, err := limiter.Wait("", ProtocolFTP)
	require.NoError(t, err)
	_, err = limiter.Wait("", ProtocolSSH)
	require.Error(t, err)

	config.Type = int(rateLimiterTypeSource)
	config.GenerateDefenderEvents = true
	config.EntriesSoftLimit = 5
	config.EntriesHardLimit = 10
	limiter = config.getLimiter()

	source := "192.168.1.2"
	_, err = limiter.Wait(source, ProtocolSSH)
	require.NoError(t, err)
	_, err = limiter.Wait(source, ProtocolSSH)
	require.Error(t, err)
	// a different source should work
	_, err = limiter.Wait(source+"1", ProtocolSSH)
	require.NoError(t, err)

	config.Burst = 0
	limiter = config.getLimiter()
	_, err = limiter.Wait(source, ProtocolSSH)
	require.ErrorIs(t, err, errReserve)
}

// TestLimiterCleanup checks per-source eviction with soft limit 1 and hard
// limit 3: adding a fourth source evicts the two least recently active
// buckets. The sleeps keep the lastActivity timestamps strictly ordered.
func TestLimiterCleanup(t *testing.T) {
	config := RateLimiterConfig{
		Average:          100,
		Period:           1000,
		Burst:            1,
		Type:             int(rateLimiterTypeSource),
		Protocols:        rateLimiterProtocolValues,
		EntriesSoftLimit: 1,
		EntriesHardLimit: 3,
	}
	limiter := config.getLimiter()
	source1 := "10.8.0.1"
	source2 := "10.8.0.2"
	source3 := "10.8.0.3"
	source4 := "10.8.0.4"
	_, err := limiter.Wait(source1, ProtocolSSH)
	assert.NoError(t, err)
	time.Sleep(20 * time.Millisecond)
	_, err = limiter.Wait(source2, ProtocolSSH)
	assert.NoError(t, err)
	time.Sleep(20 * time.Millisecond)
	assert.Len(t, limiter.buckets.buckets, 2)
	_, ok := limiter.buckets.buckets[source1]
	assert.True(t, ok)
	_, ok = limiter.buckets.buckets[source2]
	assert.True(t, ok)
	// Cleanup runs before insertion; with 2 entries the hard limit is not reached yet.
	_, err = limiter.Wait(source3, ProtocolSSH)
	assert.NoError(t, err)
	assert.Len(t, limiter.buckets.buckets, 3)
	_, ok = limiter.buckets.buckets[source1]
	assert.True(t, ok)
	_, ok = limiter.buckets.buckets[source2]
	assert.True(t, ok)
	_, ok = limiter.buckets.buckets[source3]
	assert.True(t, ok)
	time.Sleep(20 * time.Millisecond)
	// 3 entries reach the hard limit: 3 - 1 = 2 oldest entries are evicted.
	_, err = limiter.Wait(source4, ProtocolSSH)
	assert.NoError(t, err)
	assert.Len(t, limiter.buckets.buckets, 2)
	_, ok = limiter.buckets.buckets[source3]
	assert.True(t, ok)
	_, ok = limiter.buckets.buckets[source4]
	assert.True(t, ok)
}

================================================
FILE: internal/common/tlsutils.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package common

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io/fs"
	"math/rand"
	"os"
	"path/filepath"
	"slices"
	"sync"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

const (
	// DefaultTLSKeyPaidID defines the id to use for non-binding specific key pairs
	DefaultTLSKeyPaidID = "default"
	// pemCRLType is the PEM block type expected for revocation lists.
	pemCRLType = "X509 CRL"
)

var (
	// pemCRLPrefix detects PEM-encoded CRLs; other CRL files are parsed as raw bytes.
	pemCRLPrefix = []byte("-----BEGIN X509 CRL")
)

// TLSKeyPair defines the paths and the unique identifier for a TLS key pair
type TLSKeyPair struct {
	Cert string
	Key  string
	ID   string
}

// CertManager defines a TLS certificate manager
type CertManager struct {
	keyPairs  []TLSKeyPair
	configDir string
	logSender string
	// The embedded RWMutex guards certs, certsInfo, rootCAs and crls.
	sync.RWMutex
	caCertificates    []string
	caRevocationLists []string
	// monitorList holds the certificate and CRL files whose size and
	// modification time are checked by monitor.
	monitorList []string
	certs       map[string]*tls.Certificate
	certsInfo   map[string]fs.FileInfo
	rootCAs     *x509.CertPool
	crls        []*x509.RevocationList
}

// Reload tries to reload certificate and CRLs.
// Both are attempted even if the certificates fail to load; the certificate
// error, if any, takes precedence. Root CAs are not reloaded here.
func (m *CertManager) Reload() error {
	errCrt := m.loadCertificates()
	errCRLs := m.LoadCRLs()

	if errCrt != nil {
		return errCrt
	}
	return errCRLs
}

// loadCertificates tries to load the
configured x509 key pairs func (m *CertManager) loadCertificates() error { if len(m.keyPairs) == 0 { return errors.New("no key pairs defined") } certs := make(map[string]*tls.Certificate) for _, keyPair := range m.keyPairs { if keyPair.ID == "" { return errors.New("TLS certificate without ID") } newCert, err := tls.LoadX509KeyPair(keyPair.Cert, keyPair.Key) if err != nil { logger.Error(m.logSender, "", "unable to load X509 key pair, cert file %q key file %q error: %v", keyPair.Cert, keyPair.Key, err) return err } if _, ok := certs[keyPair.ID]; ok { logger.Error(m.logSender, "", "TLS certificate with id %q is duplicated", keyPair.ID) return fmt.Errorf("TLS certificate with id %q is duplicated", keyPair.ID) } logger.Debug(m.logSender, "", "TLS certificate %q successfully loaded, id %v", keyPair.Cert, keyPair.ID) certs[keyPair.ID] = &newCert if !slices.Contains(m.monitorList, keyPair.Cert) { m.monitorList = append(m.monitorList, keyPair.Cert) } } m.Lock() defer m.Unlock() m.certs = certs return nil } // HasCertificate returns true if there is a certificate for the specified certID func (m *CertManager) HasCertificate(certID string) bool { m.RLock() defer m.RUnlock() _, ok := m.certs[certID] return ok } // GetCertificateFunc returns the loaded certificate func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { m.RLock() defer m.RUnlock() val, ok := m.certs[certID] if !ok { logger.Error(m.logSender, "", "no certificate for id %s", certID) return nil, fmt.Errorf("no certificate for id %s", certID) } return val, nil } } // IsRevoked returns true if the specified certificate has been revoked func (m *CertManager) IsRevoked(crt *x509.Certificate, caCrt *x509.Certificate) bool { m.RLock() defer m.RUnlock() if crt == nil || caCrt == nil { logger.Error(m.logSender, "", "unable to verify crt %v, ca crt %v", crt, caCrt) return len(m.crls) > 0 } for _, crl := 
range m.crls { if crl.CheckSignatureFrom(caCrt) == nil { for _, rc := range crl.RevokedCertificateEntries { if rc.SerialNumber.Cmp(crt.SerialNumber) == 0 { return true } } } } return false } // LoadCRLs tries to load certificate revocation lists from the given paths func (m *CertManager) LoadCRLs() error { if len(m.caRevocationLists) == 0 { return nil } var crls []*x509.RevocationList for _, revocationList := range m.caRevocationLists { if !util.IsFileInputValid(revocationList) { return fmt.Errorf("invalid root CA revocation list %q", revocationList) } if revocationList != "" && !filepath.IsAbs(revocationList) { revocationList = filepath.Join(m.configDir, revocationList) } crlBytes, err := os.ReadFile(revocationList) if err != nil { logger.Error(m.logSender, "", "unable to read revocation list %q", revocationList) return err } if bytes.HasPrefix(crlBytes, pemCRLPrefix) { block, _ := pem.Decode(crlBytes) if block != nil && block.Type == pemCRLType { crlBytes = block.Bytes } } crl, err := x509.ParseRevocationList(crlBytes) if err != nil { logger.Error(m.logSender, "", "unable to parse revocation list %q", revocationList) return err } logger.Debug(m.logSender, "", "CRL %q successfully loaded", revocationList) crls = append(crls, crl) if !slices.Contains(m.monitorList, revocationList) { m.monitorList = append(m.monitorList, revocationList) } } m.Lock() defer m.Unlock() m.crls = crls return nil } // GetRootCAs returns the set of root certificate authorities that servers // use if required to verify a client certificate func (m *CertManager) GetRootCAs() *x509.CertPool { m.RLock() defer m.RUnlock() return m.rootCAs } // LoadRootCAs tries to load root CA certificate authorities from the given paths func (m *CertManager) LoadRootCAs() error { if len(m.caCertificates) == 0 { return nil } rootCAs := x509.NewCertPool() for _, rootCA := range m.caCertificates { if !util.IsFileInputValid(rootCA) { return fmt.Errorf("invalid root CA certificate %q", rootCA) } if rootCA != "" && 
!filepath.IsAbs(rootCA) { rootCA = filepath.Join(m.configDir, rootCA) } crt, err := os.ReadFile(rootCA) if err != nil { logger.Error(m.logSender, "", "unable to read root CA from file %q: %v", rootCA, err) return err } if rootCAs.AppendCertsFromPEM(crt) { logger.Debug(m.logSender, "", "TLS certificate authority %q successfully loaded", rootCA) } else { err := fmt.Errorf("unable to load TLS certificate authority %q", rootCA) logger.Error(m.logSender, "", "%v", err) return err } } m.Lock() defer m.Unlock() m.rootCAs = rootCAs return nil } // SetCACertificates sets the root CA authorities file paths. // This should not be changed at runtime func (m *CertManager) SetCACertificates(caCertificates []string) { m.caCertificates = util.RemoveDuplicates(caCertificates, true) } // SetCARevocationLists sets the CA revocation lists file paths. // This should not be changed at runtime func (m *CertManager) SetCARevocationLists(caRevocationLists []string) { m.caRevocationLists = util.RemoveDuplicates(caRevocationLists, true) } func (m *CertManager) monitor() { certsInfo := make(map[string]fs.FileInfo) for _, crt := range m.monitorList { info, err := os.Stat(crt) if err != nil { logger.Warn(m.logSender, "", "unable to stat certificate to monitor %q: %v", crt, err) return } certsInfo[crt] = info } m.Lock() isChanged := false for k, oldInfo := range m.certsInfo { newInfo, ok := certsInfo[k] if ok { if newInfo.Size() != oldInfo.Size() || newInfo.ModTime() != oldInfo.ModTime() { logger.Debug(m.logSender, "", "change detected for certificate %q, reload required", k) isChanged = true } } } m.certsInfo = certsInfo m.Unlock() if isChanged { m.Reload() //nolint:errcheck } } // NewCertManager creates a new certificate manager func NewCertManager(keyPairs []TLSKeyPair, configDir, logSender string) (*CertManager, error) { manager := &CertManager{ keyPairs: keyPairs, configDir: configDir, logSender: logSender, certs: make(map[string]*tls.Certificate), certsInfo: make(map[string]fs.FileInfo), } 
err := manager.loadCertificates() if err != nil { return nil, err } randSecs := rand.Intn(59) manager.monitor() if eventScheduler != nil { _, err = eventScheduler.AddFunc(fmt.Sprintf("@every 8h0m%ds", randSecs), manager.monitor) } return manager, err } ================================================ FILE: internal/common/tlsutils_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "crypto/tls" "crypto/x509" "os" "path/filepath" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( serverCert = `-----BEGIN CERTIFICATE----- MIIEIjCCAgqgAwIBAgIQfxHX0pnvRtkmtfLklgrcNzANBgkqhkiG9w0BAQsFADAT MREwDwYDVQQDEwhDZXJ0QXV0aDAeFw0yMzAxMDMxMDIyMDdaFw0zMzAxMDMxMDMw NDVaMBQxEjAQBgNVBAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEP ADCCAQoCggEBAKbMWjMhyjMnDsq/19J9D44Y13uPSMN26NFOCfjVgV23zcqvI8W1 csosYj89gSmIRxpcL2FtX7NjIT4vaqXob/en1lYy8hstacOs2cy2LcVZHfxu/hv3 6hEKLY28tOD41L1CYZesBt3yV8vGcYIOnnAdIiG52SChnduTafBVE9Pq5P7qJ1gZ d4uBYxe8/Za0metKDvMN6FTK+THq56eD830iRwFOdSw3Z4NS/nQNeVW263E4CC4u BVxgwIHu6giqEfIoV6oVTY64y8X2YlwqvbVN/OtWNIJBLu+mN2EhR2ygpZdAyc82 1yrk/X2/Dd3OiKSrrvXL1fOuNGlLNGD+3vUCAwEAAaNxMG8wDgYDVR0PAQH/BAQD AgO4MB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4EFgQUabrE 6ATHRqEf/CDQiNWI+0e/nhIwHwYDVR0jBBgwFoAUKPyWZxHuWgH3MA/996i3V4gd aYgwDQYJKoZIhvcNAQELBQADggIBAHFtnPXxCCeeGw4RiIai3bavGtyK5qooZUia 
hN8abJp9VJKYthLwF75c0wn8W0ZMTY8z9xgmFK9afWHCBNyK+0KCpd/LdDUfwvIn 3RwR4HRFjNG+n1UZBA4l1W6X6kCq9/x7YaKLrek9aBHfxwMnoMrOeMUybm6D+B5E lSkAyJRq5VHVatM7UGmdux2MXK5IMpzlIBzz1pXddnzF3f9nfS54xt6ilWst9bMi 6mBxisJmqc51L/Fyb2SoCJoO/6kv+3V5HnRNBcZuVE8G5/Uc+WRnyy9dh996W83b jNvSJ9UpspqMtKx7DKU4fC/3xYDjRimZvZ3akfIdkf3j5GVWMtVbx+QVSZ8aKBSM Zx35p8aF0zppTjp2JvBpiQlGIXKfPkmmH4bLpU7Z7qLXFFnp+fs3CjcIng19gGgi XQldgHVsl8FtIebxgW6wc5jb2y/fXjgx9c0SKEeeA3Pp6fExH8PdQdyHHmkHKQzO ozon1tZhQbcjkNz8kXFp3x3X/0i4TsR6vsUigSFHXT7DgusBK8eAiRVOLSpbfIyp 7Ul/9DjhtYxcZjNI/xNJcECPGazNDdKh4TdLh35pnQHOsRXDWB873rr5xkJIUXbU ubo+q0VpmF7OtfPO9PrPilWAUhVDRx7CCTW3YUsWrYJkr8d6F/n6y7QPKMtB9Y2P jRJ4LDqX -----END CERTIFICATE-----` serverKey = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApsxaMyHKMycOyr/X0n0PjhjXe49Iw3bo0U4J+NWBXbfNyq8j xbVyyixiPz2BKYhHGlwvYW1fs2MhPi9qpehv96fWVjLyGy1pw6zZzLYtxVkd/G7+ G/fqEQotjby04PjUvUJhl6wG3fJXy8Zxgg6ecB0iIbnZIKGd25Np8FUT0+rk/uon WBl3i4FjF7z9lrSZ60oO8w3oVMr5Mernp4PzfSJHAU51LDdng1L+dA15VbbrcTgI Li4FXGDAge7qCKoR8ihXqhVNjrjLxfZiXCq9tU3861Y0gkEu76Y3YSFHbKCll0DJ zzbXKuT9fb8N3c6IpKuu9cvV8640aUs0YP7e9QIDAQABAoIBADbD9gG/4HH3KwYr AyPbaBYR1f59xzhWfI7sfp2zDGzHAsy/wJETyILVG9UDzrriQeZHyk7E6J0vuSR/ 0RZ0QP8hnmBjDdcajBVxVXm/fzvCzPOrRcfNGI9LtjVJdmI/kSoq93wjQYXyIh2I JJC9WAwbpK9KJB5wsjH8LtZ4OLBlcdeB8jcvO6FzGij6HwyxqyPctxetlvpcmc/w zNJhps6t+TJ8PpNtEmTpOOmx85V6HMb3QJexwmUYygRaOoiQKBKZSNaOnGoC8w1d WahyyXJk4B3OUllqG1TLUgabFGqq2PeJSP8RvYFH8DUj+fdxD78qDHAygrL8ELLZ 2O3Wi0ECgYEAyREnS/kylyIcAsyKczsKEDMIDUF9rGvm2B+QG7cLKHTu24oiNg5B Ik5nkaYmSSrC3O2/s4v47mYzMtWbLxlogiNK6ljLPpdU5/JaeHncZC+18seBoePQ 9nOW3AvY2A6ihzy8sKRMfl3FUx/1rcXLdNwkMQo0FWR7nqVPUme9QkkCgYEA1F5n lhfDptiHekagKMTf9SGw4B2UiG6SLjMWhcG2AEFeXpZlsk7Qubnuzk0krjYp+JAI brlzMOkmBXBQywKLe3SG0s0McbRGWVFbEA1SA+WZV5rwJe5PO7W6ndCF2+slyZ5T dPwOY1RybV6R07EvjtfnE8Wtdyko4X22sTkyd00CgYA5MYnuEHqVhvxUx33yfS7F oN5/dsuayi6l94R0fcLMxUZUaJyGp9NbQNYxFgP5+BHp6i8HkZ9DoQqbQSudYCrc KdHbi1p0+XMLb2LQtkk8rl2hK6LyO+1qzUJyYWRTQQZ2VY6O6I1hvKaumH636XWQ 
TjZ1RKPAGg8X94nytNOfEQKBgQC/+TL0iDjyGyykyTFAiW/WXQVSIwtBJYr5Pm9u rESFCJJxOM1nmT2vlrecQDoXTZk1O6aTyQqrPSeEpRoz2fISwKyb5IYKRyeM2DFU WmY4ZZXvjnzmHP39APNYc8Z9nZzEHF5fEvdCrXTfDy0Ny08tdlhKFFkRreBprkW3 APhwxQKBgDBdionnjdB9jdGbYHrsPaweMGdQNXkrTTCFfBA47F+qZswfon12yu4A +cBKCnQe2dQHl8AV3IeUKpmNghu4iICOASQEO9dS6OWZI5vBxZMePBm6+bjTOuf6 ozecw3yR55tKpPImt87rhrWlwp35uWuhOr9GHYBdFSwgrEkVMw++ -----END RSA PRIVATE KEY-----` caCRT = `-----BEGIN CERTIFICATE----- MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U 4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY /tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 
DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w kK8u1YiiVenmoQ== -----END CERTIFICATE-----` caKey = `-----BEGIN RSA PRIVATE KEY----- MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj 7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY 00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz +465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc 9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM 0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN +jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u 
BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 /hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz 1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN 38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ 2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== -----END RSA PRIVATE KEY-----` caCRL = `-----BEGIN X509 CRL----- MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV +jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE 
jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn w7bXJGfJcIMrrKs= -----END X509 CRL-----` client1Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr 
kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au MU3Bo0A= -----END CERTIFICATE-----` client1Key = `-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi 85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== -----END RSA PRIVATE KEY-----` // client 2 crt is revoked client2Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy 
MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB 4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv /ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR E0+r2+4= -----END CERTIFICATE-----` client2Key = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu 4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL 
VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl /owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz 3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D 8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl -----END RSA PRIVATE KEY-----` ) func TestLoadCertificate(t *testing.T) { startEventScheduler() caCrtPath := filepath.Join(os.TempDir(), "testca.crt") caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") certPath := filepath.Join(os.TempDir(), "test.crt") keyPath := filepath.Join(os.TempDir(), "test.key") err := os.WriteFile(caCrtPath, []byte(caCRT), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(certPath, []byte(serverCert), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(keyPath, []byte(serverKey), os.ModePerm) assert.NoError(t, err) keyPairs := []TLSKeyPair{ { Cert: certPath, Key: keyPath, ID: DefaultTLSKeyPaidID, }, { Cert: certPath, Key: keyPath, ID: DefaultTLSKeyPaidID, }, } certManager, err := NewCertManager(keyPairs, configDir, logSenderTest) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is duplicated") } assert.Nil(t, certManager) 
keyPairs = []TLSKeyPair{ { Cert: certPath, Key: keyPath, ID: DefaultTLSKeyPaidID, }, } certManager, err = NewCertManager(keyPairs, configDir, logSenderTest) assert.NoError(t, err) assert.True(t, certManager.HasCertificate(DefaultTLSKeyPaidID)) assert.False(t, certManager.HasCertificate("unknownID")) certFunc := certManager.GetCertificateFunc(DefaultTLSKeyPaidID) if assert.NotNil(t, certFunc) { hello := &tls.ClientHelloInfo{ ServerName: "localhost", CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}, } cert, err := certFunc(hello) assert.NoError(t, err) assert.Equal(t, certManager.certs[DefaultTLSKeyPaidID], cert) } certFunc = certManager.GetCertificateFunc("unknownID") if assert.NotNil(t, certFunc) { hello := &tls.ClientHelloInfo{ ServerName: "localhost", CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}, } _, err = certFunc(hello) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no certificate for id unknownID") } } certManager.SetCACertificates(nil) err = certManager.LoadRootCAs() assert.NoError(t, err) certManager.SetCACertificates([]string{""}) err = certManager.LoadRootCAs() assert.Error(t, err) certManager.SetCACertificates([]string{"invalid"}) err = certManager.LoadRootCAs() assert.Error(t, err) // laoding the key as root CA must fail certManager.SetCACertificates([]string{keyPath}) err = certManager.LoadRootCAs() assert.Error(t, err) certManager.SetCACertificates([]string{certPath}) err = certManager.LoadRootCAs() assert.NoError(t, err) rootCa := certManager.GetRootCAs() assert.NotNil(t, rootCa) err = certManager.Reload() assert.NoError(t, err) certManager.SetCARevocationLists(nil) err = certManager.LoadCRLs() assert.NoError(t, err) certManager.SetCARevocationLists([]string{""}) err = certManager.LoadCRLs() assert.Error(t, err) certManager.SetCARevocationLists([]string{"invalid crl"}) err = certManager.LoadCRLs() assert.Error(t, err) // this is not a crl and must fail 
certManager.SetCARevocationLists([]string{caCrtPath}) err = certManager.LoadCRLs() assert.Error(t, err) certManager.SetCARevocationLists([]string{caCrlPath}) err = certManager.LoadCRLs() assert.NoError(t, err) crt, err := tls.X509KeyPair([]byte(caCRT), []byte(caKey)) assert.NoError(t, err) x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) crt, err = tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) x509crt, err := x509.ParseCertificate(crt.Certificate[0]) if assert.NoError(t, err) { assert.False(t, certManager.IsRevoked(x509crt, x509CAcrt)) } crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) assert.NoError(t, err) x509crt, err = x509.ParseCertificate(crt.Certificate[0]) if assert.NoError(t, err) { assert.True(t, certManager.IsRevoked(x509crt, x509CAcrt)) } assert.True(t, certManager.IsRevoked(nil, nil)) err = os.Remove(caCrlPath) assert.NoError(t, err) err = certManager.Reload() assert.Error(t, err) err = os.Remove(certPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) err = certManager.Reload() assert.Error(t, err) err = os.Remove(caCrtPath) assert.NoError(t, err) stopEventScheduler() } func TestLoadInvalidCert(t *testing.T) { startEventScheduler() certManager, err := NewCertManager(nil, configDir, logSenderTest) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no key pairs defined") } assert.Nil(t, certManager) keyPairs := []TLSKeyPair{ { Cert: "test.crt", Key: "test.key", ID: DefaultTLSKeyPaidID, }, } certManager, err = NewCertManager(keyPairs, configDir, logSenderTest) assert.Error(t, err) assert.Nil(t, certManager) keyPairs = []TLSKeyPair{ { Cert: "test.crt", Key: "test.key", }, } certManager, err = NewCertManager(keyPairs, configDir, logSenderTest) if assert.Error(t, err) { assert.Contains(t, err.Error(), "TLS certificate without ID") } assert.Nil(t, certManager) stopEventScheduler() } func TestCertificateMonitor(t *testing.T) { 
startEventScheduler() defer stopEventScheduler() certPath := filepath.Join(os.TempDir(), "test.crt") keyPath := filepath.Join(os.TempDir(), "test.key") caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") err := os.WriteFile(certPath, []byte(serverCert), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(keyPath, []byte(serverKey), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) assert.NoError(t, err) keyPairs := []TLSKeyPair{ { Cert: certPath, Key: keyPath, ID: DefaultTLSKeyPaidID, }, } certManager, err := NewCertManager(keyPairs, configDir, logSenderTest) assert.NoError(t, err) assert.Len(t, certManager.monitorList, 1) require.Len(t, certManager.certsInfo, 1) info := certManager.certsInfo[certPath] require.NotNil(t, info) certManager.SetCARevocationLists([]string{caCrlPath}) err = certManager.LoadCRLs() assert.NoError(t, err) assert.Len(t, certManager.monitorList, 2) certManager.monitor() require.Len(t, certManager.certsInfo, 2) err = os.Remove(certPath) assert.NoError(t, err) certManager.monitor() time.Sleep(100 * time.Millisecond) err = os.WriteFile(certPath, []byte(serverCert), os.ModePerm) assert.NoError(t, err) certManager.monitor() require.Len(t, certManager.certsInfo, 2) newInfo := certManager.certsInfo[certPath] require.NotNil(t, newInfo) assert.Equal(t, info.Size(), newInfo.Size()) assert.NotEqual(t, info.ModTime(), newInfo.ModTime()) err = os.Remove(caCrlPath) assert.NoError(t, err) err = os.Remove(certPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) } ================================================ FILE: internal/common/transfer.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package common import ( "errors" "fmt" "io/fs" "path" "sync" "sync/atomic" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( // ErrTransferClosed defines the error returned for a closed transfer ErrTransferClosed = errors.New("transfer already closed") ) // BaseTransfer contains protocols common transfer details for an upload or a download. type BaseTransfer struct { ID int64 BytesSent atomic.Int64 BytesReceived atomic.Int64 Fs vfs.Fs File vfs.File Connection *BaseConnection cancelFn func() fsPath string effectiveFsPath string requestPath string ftpMode string start time.Time MaxWriteSize int64 MinWriteOffset int64 InitialSize int64 truncatedSize int64 isNewFile bool transferType int AbortTransfer atomic.Bool aTime time.Time mTime time.Time transferQuota dataprovider.TransferQuota metadata map[string]string sync.Mutex errAbort error ErrTransfer error } // NewBaseTransfer returns a new BaseTransfer and adds it to the given connection func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string, transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs, transferQuota dataprovider.TransferQuota, ) *BaseTransfer { t := &BaseTransfer{ ID: conn.GetTransferID(), File: file, Connection: conn, cancelFn: cancelFn, fsPath: fsPath, effectiveFsPath: effectiveFsPath, start: time.Now(), transferType: transferType, MinWriteOffset: minWriteOffset, 
InitialSize: initialSize, isNewFile: isNewFile, requestPath: requestPath, MaxWriteSize: maxWriteSize, truncatedSize: truncatedSize, transferQuota: transferQuota, Fs: fs, } t.AbortTransfer.Store(false) t.BytesSent.Store(0) t.BytesReceived.Store(0) conn.AddTransfer(t) return t } // GetTransferQuota returns data transfer quota limits func (t *BaseTransfer) GetTransferQuota() dataprovider.TransferQuota { return t.transferQuota } // SetFtpMode sets the FTP mode for the current transfer func (t *BaseTransfer) SetFtpMode(mode string) { t.ftpMode = mode } // GetID returns the transfer ID func (t *BaseTransfer) GetID() int64 { return t.ID } // GetType returns the transfer type func (t *BaseTransfer) GetType() int { return t.transferType } // GetSize returns the transferred size func (t *BaseTransfer) GetSize() int64 { if t.transferType == TransferDownload { return t.BytesSent.Load() } return t.BytesReceived.Load() } // GetDownloadedSize returns the transferred size func (t *BaseTransfer) GetDownloadedSize() int64 { return t.BytesSent.Load() } // GetUploadedSize returns the transferred size func (t *BaseTransfer) GetUploadedSize() int64 { return t.BytesReceived.Load() } // GetStartTime returns the start time func (t *BaseTransfer) GetStartTime() time.Time { return t.start } // GetAbortError returns the error to send to the client if the transfer was aborted func (t *BaseTransfer) GetAbortError() error { t.Lock() defer t.Unlock() if t.errAbort != nil { return t.errAbort } return getQuotaExceededError(t.Connection.protocol) } // SignalClose signals that the transfer should be closed after the next read/write. 
// The optional error argument allow to send a specific error, otherwise a generic // transfer aborted error is sent func (t *BaseTransfer) SignalClose(err error) { t.Lock() t.errAbort = err t.Unlock() t.AbortTransfer.Store(true) } // GetTruncatedSize returns the truncated sized if this is an upload overwriting // an existing file func (t *BaseTransfer) GetTruncatedSize() int64 { return t.truncatedSize } // HasSizeLimit returns true if there is an upload or download size limit func (t *BaseTransfer) HasSizeLimit() bool { if t.MaxWriteSize > 0 { return true } if t.transferQuota.HasSizeLimits() { return true } return false } // GetVirtualPath returns the transfer virtual path func (t *BaseTransfer) GetVirtualPath() string { return t.requestPath } // GetFsPath returns the transfer filesystem path func (t *BaseTransfer) GetFsPath() string { return t.fsPath } // SetTimes stores access and modification times if fsPath matches the current file func (t *BaseTransfer) SetTimes(fsPath string, atime time.Time, mtime time.Time) bool { if fsPath == t.GetFsPath() { t.aTime = atime t.mTime = mtime return true } return false } // GetRealFsPath returns the real transfer filesystem path. 
// If atomic uploads are enabled this differ from fsPath func (t *BaseTransfer) GetRealFsPath(fsPath string) string { if fsPath == t.GetFsPath() { if t.File != nil || vfs.IsLocalOsFs(t.Fs) { return t.effectiveFsPath } return t.fsPath } return "" } // SetMetadata sets the metadata for the file func (t *BaseTransfer) SetMetadata(val map[string]string) { t.metadata = val } // SetCancelFn sets the cancel function for the transfer func (t *BaseTransfer) SetCancelFn(cancelFn func()) { t.cancelFn = cancelFn } // ConvertError accepts an error that occurs during a read or write and // converts it into a more understandable form for the client if it is a // well-known type of error func (t *BaseTransfer) ConvertError(err error) error { var pathError *fs.PathError if errors.As(err, &pathError) { return fmt.Errorf("%s %s: %s", pathError.Op, t.GetVirtualPath(), pathError.Err.Error()) } return t.Connection.GetFsError(t.Fs, err) } // CheckRead returns an error if read if not allowed func (t *BaseTransfer) CheckRead() error { if t.transferQuota.AllowedDLSize == 0 && t.transferQuota.AllowedTotalSize == 0 { return nil } if t.transferQuota.AllowedTotalSize > 0 { if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize { return t.Connection.GetReadQuotaExceededError() } } else if t.transferQuota.AllowedDLSize > 0 { if t.BytesSent.Load() > t.transferQuota.AllowedDLSize { return t.Connection.GetReadQuotaExceededError() } } return nil } // CheckWrite returns an error if write if not allowed func (t *BaseTransfer) CheckWrite() error { if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize { return t.Connection.GetQuotaExceededError() } if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 { return nil } if t.transferQuota.AllowedTotalSize > 0 { if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize { return t.Connection.GetQuotaExceededError() } } else if t.transferQuota.AllowedULSize > 0 { if 
t.BytesReceived.Load() > t.transferQuota.AllowedULSize { return t.Connection.GetQuotaExceededError() } } return nil } // Truncate changes the size of the opened file. // Supported for local fs only func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { if fsPath == t.GetFsPath() { if t.File != nil { initialSize := t.InitialSize err := t.File.Truncate(size) if err == nil { t.Lock() t.InitialSize = size if t.MaxWriteSize > 0 { sizeDiff := initialSize - size t.MaxWriteSize += sizeDiff metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs)) if t.transferQuota.HasSizeLimits() { go func(ulSize, dlSize int64, user dataprovider.User) { dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck }(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User) } t.BytesReceived.Store(0) } t.Unlock() } t.Connection.Log(logger.LevelDebug, "file %q truncated to size %v max write size %v new initial size %v err: %v", fsPath, size, t.MaxWriteSize, t.InitialSize, err) return initialSize, err } if size == 0 && t.BytesSent.Load() == 0 { // for cloud providers the file is always truncated to zero, we don't support append/resume for uploads. // For buffered SFTP and local fs we can have buffered bytes so we returns an error if !vfs.IsBufferedLocalOrSFTPFs(t.Fs) { return 0, nil } } return 0, vfs.ErrVfsUnsupported } return 0, errTransferMismatch } // TransferError is called if there is an unexpected error. 
// For example network or client issues.
// Only the first error is recorded: if ErrTransfer is already set the call
// does nothing. Otherwise it stores err, invokes the cancel function (if any)
// and logs the failure with the bytes transferred so far. It runs while
// holding the transfer lock.
func (t *BaseTransfer) TransferError(err error) {
	t.Lock()
	defer t.Unlock()
	if t.ErrTransfer != nil {
		return
	}
	t.ErrTransfer = err
	if t.cancelFn != nil {
		t.cancelFn()
	}
	elapsed := time.Since(t.start).Nanoseconds() / 1000000
	t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %q, error: \"%v\" bytes sent: %v, "+
		"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(),
		t.BytesReceived.Load(), elapsed)
}

// getUploadFileSize returns the size of the uploaded file, the number of files
// deleted while computing it (0 or 1) and the stat error, if any.
// It returns an error without touching the filesystem when quota tracking is
// disabled (mode 0). In mode 2 it does the same when the user has no quota
// restrictions, unless the upload targets a virtual folder that is not
// included in the user quota.
// If the transfer failed on an encrypted (crypt) filesystem, the partial file
// is removed. On success the returned size, BytesReceived and MinWriteOffset
// are reset and one deleted file is reported.
func (t *BaseTransfer) getUploadFileSize() (int64, int, error) {
	var fileSize int64
	var deletedFiles int

	switch dataprovider.GetQuotaTracking() {
	case 0:
		return fileSize, deletedFiles, errors.New("quota tracking disabled")
	case 2:
		if !t.Connection.User.HasQuotaRestrictions() {
			vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
			if err != nil {
				return fileSize, deletedFiles, errors.New("quota tracking disabled for this user")
			}
			if vfolder.IsIncludedInUserQuota() {
				return fileSize, deletedFiles, errors.New("quota tracking disabled for this user and folder included in user quota")
			}
		}
	}

	info, err := t.Fs.Stat(t.fsPath)
	if err == nil {
		fileSize = info.Size()
	}
	if t.ErrTransfer != nil && vfs.IsCryptOsFs(t.Fs) {
		errDelete := t.Fs.Remove(t.fsPath, false)
		if errDelete != nil {
			t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %q: %v", t.fsPath, errDelete)
		} else {
			fileSize = 0
			deletedFiles = 1
			t.BytesReceived.Store(0)
			t.MinWriteOffset = 0
		}
	}
	// NOTE: the returned error is the Stat error, even if the file was then removed.
	return fileSize, deletedFiles, err
}

// checkUploadOutsideHomeDir handles a failed rename of an atomic upload.
// It returns 0 when err is nil. Otherwise err becomes the transfer error
// (unless one is already set). If no temp path is configured it returns 0.
// Otherwise it removes the temporary file, resets BytesReceived and
// MinWriteOffset so the quota is not updated, and returns 1 (the file is
// outside the user home dir).
func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
	if err == nil {
		return 0
	}
	if t.ErrTransfer == nil {
		t.ErrTransfer = err
	}
	if Config.TempPath == "" {
		return 0
	}
	err = t.Fs.Remove(t.effectiveFsPath, false)
	t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %q, deletion error: %v",
		t.effectiveFsPath, err)
	// the file is outside the home dir so don't update the quota
	t.BytesReceived.Store(0)
	t.MinWriteOffset = 0
	return 1
}

// Close is called when the transfer is completed.
// It records metrics and data transfer quota usage, and handles partial or
// atomic upload files. It then logs the transfer, executes the download
// notification or the upload hook, quota update and times update, and
// finally updates the first upload/download timestamps.
// On a quota exceeded error the partial file is removed. For atomic uploads
// the temporary file is renamed to the final path when there is no error or
// resume is enabled; otherwise it is deleted.
// It returns the first rename/remove error, falling back to ErrTransfer.
func (t *BaseTransfer) Close() error {
	defer t.Connection.RemoveTransfer(t)

	var err error
	numFiles := t.getUploadedFiles()
	metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.transferType, t.ErrTransfer,
		vfs.IsSFTPFs(t.Fs))
	if t.transferQuota.HasSizeLimits() {
		dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck
			t.BytesSent.Load(), false)
	}
	if (t.File != nil || vfs.IsLocalOsFs(t.Fs)) && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
		// if quota is exceeded we try to remove the partial file for uploads to local filesystem
		err = t.Fs.Remove(t.effectiveFsPath, false)
		if err == nil {
			t.BytesReceived.Store(0)
			t.MinWriteOffset = 0
		}
		t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %q, deletion error: %v",
			t.effectiveFsPath, err)
	} else if t.isAtomicUpload() {
		if t.ErrTransfer == nil || Config.UploadMode&UploadModeAtomicWithResume != 0 {
			_, _, err = t.Fs.Rename(t.effectiveFsPath, t.fsPath, 0)
			t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %q -> %q, error: %v",
				t.effectiveFsPath, t.fsPath, err)
			// the file must be removed if it is uploaded to a path outside the home dir and cannot be renamed
			t.checkUploadOutsideHomeDir(err)
		} else {
			err = t.Fs.Remove(t.effectiveFsPath, false)
			t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %q, deletion error: %v",
				t.ErrTransfer, t.effectiveFsPath, err)
			if err == nil {
				t.BytesReceived.Store(0)
				t.MinWriteOffset = 0
			}
		}
	}
	elapsed := time.Since(t.start).Nanoseconds() / 1000000
	var uploadFileSize int64
	if t.transferType == TransferDownload {
		logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username,
			t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode,
			t.ErrTransfer)
		ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck
			t.BytesSent.Load(), t.ErrTransfer, elapsed, t.metadata)
	} else {
		statSize, deletedFiles, errStat := t.getUploadFileSize()
		if errStat == nil {
			uploadFileSize = statSize
		} else {
			// No usable stat: estimate from the bytes received, or treat the
			// upload as missing if the file does not exist.
			uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
			if t.Fs.IsNotExist(errStat) {
				uploadFileSize = 0
				numFiles--
			}
		}
		numFiles -= deletedFiles
		t.Connection.Log(logger.LevelDebug, "upload file size %d, num files %d, deleted files %d, fs path %q",
			uploadFileSize, numFiles, deletedFiles, t.fsPath)
		numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize, elapsed)
		t.updateQuota(numFiles, uploadFileSize)
		t.updateTimes()
		logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username,
			t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode,
			t.ErrTransfer)
	}
	if t.ErrTransfer != nil {
		t.Connection.Log(logger.LevelError, "transfer error: %v, path: %q", t.ErrTransfer, t.fsPath)
		if err == nil {
			err = t.ErrTransfer
		}
	}
	t.updateTransferTimestamps(uploadFileSize, elapsed)
	return err
}

// isAtomicUpload reports whether this is an upload written to a path
// different from the final one.
func (t *BaseTransfer) isAtomicUpload() bool {
	return t.transferType == TransferUpload && t.effectiveFsPath != t.fsPath
}

// updateTransferTimestamps records the user's first upload or first download
// for a successful transfer. It does this only if the timestamp is unset and
// was not already recorded on this connection; downloads also need at least
// one byte sent. On success it fires the matching first-upload/first-download
// action notification.
func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize, elapsed int64) {
	if t.ErrTransfer != nil {
		return
	}
	if t.transferType == TransferUpload {
		if t.Connection.User.FirstUpload == 0 && !t.Connection.uploadDone.Load() {
			if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, true); err == nil {
				t.Connection.uploadDone.Store(true)
				ExecuteActionNotification(t.Connection, operationFirstUpload, t.fsPath, t.requestPath, "", //nolint:errcheck
					"", "", uploadFileSize, t.ErrTransfer, elapsed, t.metadata)
			}
		}
		return
	}
	if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
		if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
			t.Connection.downloadDone.Store(true)
			ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
				"", "", t.BytesSent.Load(), t.ErrTransfer, elapsed, t.metadata)
		}
	}
}

// executeUploadHook runs the upload action notification. If it fails, the
// error becomes the transfer error (unless one is already set) and the
// uploaded file is removed. When removal succeeds, the returned file count is
// decremented, the size is zeroed and the received counters are reset.
// It returns the possibly adjusted file count and size.
func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize, elapsed int64) (int, int64) {
	err := ExecuteActionNotification(t.Connection, operationUpload, t.fsPath, t.requestPath, "", "", "",
		fileSize, t.ErrTransfer, elapsed, t.metadata)
	if err != nil {
		if t.ErrTransfer == nil {
			t.ErrTransfer = err
		}
		// try to remove the uploaded file
		err = t.Fs.Remove(t.fsPath, false)
		if err == nil {
			numFiles--
			fileSize = 0
			t.BytesReceived.Store(0)
			t.MinWriteOffset = 0
		} else {
			t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)
		}
	}
	return numFiles, fileSize
}

// getUploadedFiles returns 1 if the transfer creates a new file, 0 otherwise.
func (t *BaseTransfer) getUploadedFiles() int {
	numFiles := 0
	if t.isNewFile {
		numFiles = 1
	}
	return numFiles
}

// updateTimes applies the times stored by SetTimes, if both are set.
func (t *BaseTransfer) updateTimes() {
	if !t.aTime.IsZero() && !t.mTime.IsZero() {
		err := t.Fs.Chtimes(t.fsPath, t.aTime, t.mTime, false)
		t.Connection.Log(logger.LevelDebug, "set times for file %q, atime: %v, mtime: %v, err: %v",
			t.fsPath, t.aTime, t.mTime, err)
	}
}

// updateQuota updates the quota of the virtual folder containing the upload
// or, if there is none, the user quota. It uses numFiles and the size change
// relative to InitialSize, and reports whether an update was issued.
func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
	// Uploads on some filesystem (S3 and similar) are atomic, if there is an error nothing is uploaded
	if t.File == nil && t.ErrTransfer != nil && vfs.HasImplicitAtomicUploads(t.Fs) {
		return false
	}
	sizeDiff := fileSize - t.InitialSize
	if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
		vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
		if err == nil {
			dataprovider.UpdateUserFolderQuota(&vfolder, &t.Connection.User, numFiles, sizeDiff, false)
		} else {
			dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
		}
		return true
	}
	return false
}

// HandleThrottle manages bandwidth throttling. If the user's bandwidth for
// this direction is positive, it sleeps until the elapsed time matches the
// time the transferred bytes should have taken. The computation below treats
// the bandwidth as KB/s.
func (t *BaseTransfer) HandleThrottle() {
	var wantedBandwidth int64
	var trasferredBytes int64
	if t.transferType == TransferDownload {
		wantedBandwidth = t.Connection.User.DownloadBandwidth
		trasferredBytes = t.BytesSent.Load()
	} else {
		wantedBandwidth = t.Connection.User.UploadBandwidth
		trasferredBytes = t.BytesReceived.Load()
	}
	if wantedBandwidth > 0 {
		// real and wanted elapsed as milliseconds, bytes as kilobytes
		realElapsed := time.Since(t.start).Nanoseconds() / 1000000
		// trasferredBytes / 1024 is the size in KB (integer division, so sub-KB
		// amounts are ignored); divided by KB/s and multiplied by 1000 gives ms
		wantedElapsed := 1000 * (trasferredBytes / 1024) / wantedBandwidth
		if wantedElapsed > realElapsed {
			toSleep := time.Duration(wantedElapsed - realElapsed)
			time.Sleep(toSleep * time.Millisecond)
		}
	}
}

================================================
FILE: internal/common/transfer_test.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package common import ( "errors" "fmt" "os" "path/filepath" "testing" "time" "github.com/pkg/sftp" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/vfs" ) func TestTransferUpdateQuota(t *testing.T) { conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{}) transfer := BaseTransfer{ Connection: conn, transferType: TransferUpload, Fs: vfs.NewOsFs("", os.TempDir(), "", nil), } transfer.BytesReceived.Store(123) errFake := errors.New("fake error") transfer.TransferError(errFake) err := transfer.Close() if assert.Error(t, err) { assert.EqualError(t, err, errFake.Error()) } mappedPath := filepath.Join(os.TempDir(), "vdir") vdirPath := "/vdir" conn.User.VirtualFolders = append(conn.User.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: mappedPath, }, VirtualPath: vdirPath, QuotaFiles: -1, QuotaSize: -1, }) transfer.ErrTransfer = nil transfer.BytesReceived.Store(1) transfer.requestPath = "/vdir/file" assert.True(t, transfer.updateQuota(1, 0)) err = transfer.Close() assert.NoError(t, err) transfer.ErrTransfer = errFake transfer.Fs = newMockOsFs(true, "", "", "S3Fs fake", nil) assert.False(t, transfer.updateQuota(1, 0)) } func TestTransferThrottling(t *testing.T) { u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "test", UploadBandwidth: 50, DownloadBandwidth: 40, }, } fs := vfs.NewOsFs("", os.TempDir(), "", nil) testFileSize := int64(131072) wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth // some tolerance wantedUploadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10 conn := NewBaseConnection("id", ProtocolSCP, "", "", u) transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 
0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.BytesReceived.Store(testFileSize) transfer.Connection.UpdateLastActivity() startTime := transfer.Connection.GetLastActivity() transfer.HandleThrottle() elapsed := time.Since(startTime).Nanoseconds() / 1000000 assert.GreaterOrEqual(t, elapsed, wantedUploadElapsed, "upload bandwidth throttling not respected") err := transfer.Close() assert.NoError(t, err) transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.BytesSent.Store(testFileSize) transfer.Connection.UpdateLastActivity() startTime = transfer.Connection.GetLastActivity() transfer.HandleThrottle() elapsed = time.Since(startTime).Nanoseconds() / 1000000 assert.GreaterOrEqual(t, elapsed, wantedDownloadElapsed, "download bandwidth throttling not respected") err = transfer.Close() assert.NoError(t, err) } func TestRealPath(t *testing.T) { testFile := filepath.Join(os.TempDir(), "afile.txt") fs := vfs.NewOsFs("123", os.TempDir(), "", nil) u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "user", HomeDir: os.TempDir(), }, } u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} file, err := os.Create(testFile) require.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) rPath := transfer.GetRealFsPath(testFile) assert.Equal(t, testFile, rPath) rPath = conn.getRealFsPath(testFile) assert.Equal(t, testFile, rPath) err = transfer.Close() assert.NoError(t, err) err = file.Close() assert.NoError(t, err) transfer.File = nil rPath = transfer.GetRealFsPath(testFile) assert.Equal(t, testFile, rPath) rPath = transfer.GetRealFsPath("") assert.Empty(t, rPath) err = os.Remove(testFile) assert.NoError(t, err) assert.Len(t, conn.GetTransfers(), 0) } func 
TestTruncate(t *testing.T) { testFile := filepath.Join(os.TempDir(), "transfer_test_file") fs := vfs.NewOsFs("123", os.TempDir(), "", nil) u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "user", HomeDir: os.TempDir(), }, } u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} file, err := os.Create(testFile) if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } _, err = file.Write([]byte("hello")) assert.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, 0, false, fs, dataprovider.TransferQuota{}) err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, Flags: StatAttrSize, }) assert.NoError(t, err) assert.Equal(t, int64(103), transfer.MaxWriteSize) err = transfer.Close() assert.NoError(t, err) err = file.Close() assert.NoError(t, err) fi, err := os.Stat(testFile) if assert.NoError(t, err) { assert.Equal(t, int64(2), fi.Size()) } transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, 0, true, fs, dataprovider.TransferQuota{}) // file.Stat will fail on a closed file err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, Flags: StatAttrSize, }) assert.Error(t, err) err = transfer.Close() assert.NoError(t, err) transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) _, err = transfer.Truncate("mismatch", 0) assert.EqualError(t, err, errTransferMismatch.Error()) _, err = transfer.Truncate(testFile, 0) assert.NoError(t, err) _, err = transfer.Truncate(testFile, 1) assert.EqualError(t, err, vfs.ErrVfsUnsupported.Error()) err = transfer.Close() assert.NoError(t, err) err = os.Remove(testFile) assert.NoError(t, err) assert.Len(t, conn.GetTransfers(), 0) } func TestTransferErrors(t *testing.T) { 
isCancelled := false cancelFn := func() { isCancelled = true } testFile := filepath.Join(os.TempDir(), "transfer_test_file") fs := vfs.NewOsFs("id", os.TempDir(), "", nil) u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "test", HomeDir: os.TempDir(), }, } err := os.WriteFile(testFile, []byte("test data"), os.ModePerm) assert.NoError(t, err) file, err := os.Open(testFile) if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } conn := NewBaseConnection("id", ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) pathError := &os.PathError{ Op: "test", Path: testFile, Err: os.ErrInvalid, } err = transfer.ConvertError(pathError) assert.EqualError(t, err, fmt.Sprintf("%s %s: %s", pathError.Op, "/transfer_test_file", pathError.Err.Error())) err = transfer.ConvertError(os.ErrNotExist) assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile) err = transfer.ConvertError(os.ErrPermission) assert.ErrorIs(t, err, sftp.ErrSSHFxPermissionDenied) assert.Nil(t, transfer.cancelFn) assert.Equal(t, testFile, transfer.GetFsPath()) transfer.SetMetadata(map[string]string{"key": "val"}) transfer.SetCancelFn(cancelFn) errFake := errors.New("err fake") transfer.BytesReceived.Store(9) transfer.TransferError(ErrQuotaExceeded) assert.True(t, isCancelled) transfer.TransferError(errFake) assert.Error(t, transfer.ErrTransfer, ErrQuotaExceeded.Error()) // the file is closed from the embedding struct before to call close err = file.Close() assert.NoError(t, err) err = transfer.Close() if assert.Error(t, err) { assert.Error(t, err, ErrQuotaExceeded.Error()) } assert.NoFileExists(t, testFile) err = os.WriteFile(testFile, []byte("test data"), os.ModePerm) assert.NoError(t, err) file, err = os.Open(testFile) if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } fsPath := filepath.Join(os.TempDir(), "test_file") transfer = 
NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.BytesReceived.Store(9) transfer.TransferError(errFake) assert.Error(t, transfer.ErrTransfer, errFake.Error()) // the file is closed from the embedding struct before to call close err = file.Close() assert.NoError(t, err) err = transfer.Close() if assert.Error(t, err) { assert.Error(t, err, errFake.Error()) } assert.NoFileExists(t, testFile) err = os.WriteFile(testFile, []byte("test data"), os.ModePerm) assert.NoError(t, err) file, err = os.Open(testFile) if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.BytesReceived.Store(9) // the file is closed from the embedding struct before to call close err = file.Close() assert.NoError(t, err) err = transfer.Close() assert.NoError(t, err) assert.NoFileExists(t, testFile) assert.FileExists(t, fsPath) err = os.Remove(fsPath) assert.NoError(t, err) assert.Len(t, conn.GetTransfers(), 0) } func TestRemovePartialCryptoFile(t *testing.T) { testFile := filepath.Join(os.TempDir(), "transfer_test_file") fs, err := vfs.NewCryptFs("id", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")}) require.NoError(t, err) u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "test", HomeDir: os.TempDir(), QuotaFiles: 1000000, }, } conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.ErrTransfer = errors.New("test error") _, _, err = transfer.getUploadFileSize() assert.Error(t, err) err = os.WriteFile(testFile, []byte("test data"), os.ModePerm) assert.NoError(t, err) size, deletedFiles, err := 
transfer.getUploadFileSize() assert.NoError(t, err) assert.Equal(t, int64(0), size) assert.Equal(t, 1, deletedFiles) assert.NoFileExists(t, testFile) err = transfer.Close() assert.Error(t, err) assert.Len(t, conn.GetTransfers(), 0) } func TestFTPMode(t *testing.T) { conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{}) transfer := BaseTransfer{ Connection: conn, transferType: TransferUpload, Fs: vfs.NewOsFs("", os.TempDir(), "", nil), } transfer.BytesReceived.Store(123) assert.Empty(t, transfer.ftpMode) transfer.SetFtpMode("active") assert.Equal(t, "active", transfer.ftpMode) } func TestTransferQuota(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ TotalDataTransfer: 3, UploadDataTransfer: 2, DownloadDataTransfer: 1, }, } ul, dl, total := user.GetDataTransferLimits() assert.Equal(t, int64(2*1048576), ul) assert.Equal(t, int64(1*1048576), dl) assert.Equal(t, int64(3*1048576), total) user.TotalDataTransfer = -1 user.UploadDataTransfer = -1 user.DownloadDataTransfer = -1 ul, dl, total = user.GetDataTransferLimits() assert.Equal(t, int64(0), ul) assert.Equal(t, int64(0), dl) assert.Equal(t, int64(0), total) transferQuota := dataprovider.TransferQuota{} assert.True(t, transferQuota.HasDownloadSpace()) assert.True(t, transferQuota.HasUploadSpace()) transferQuota.TotalSize = -1 transferQuota.ULSize = -1 transferQuota.DLSize = -1 assert.True(t, transferQuota.HasDownloadSpace()) assert.True(t, transferQuota.HasUploadSpace()) transferQuota.TotalSize = 100 transferQuota.AllowedTotalSize = 10 assert.True(t, transferQuota.HasDownloadSpace()) assert.True(t, transferQuota.HasUploadSpace()) transferQuota.AllowedTotalSize = 0 assert.False(t, transferQuota.HasDownloadSpace()) assert.False(t, transferQuota.HasUploadSpace()) transferQuota.TotalSize = 0 transferQuota.DLSize = 100 transferQuota.ULSize = 50 transferQuota.AllowedTotalSize = 0 assert.False(t, transferQuota.HasDownloadSpace()) assert.False(t, transferQuota.HasUploadSpace()) 
transferQuota.AllowedDLSize = 1 transferQuota.AllowedULSize = 1 assert.True(t, transferQuota.HasDownloadSpace()) assert.True(t, transferQuota.HasUploadSpace()) transferQuota.AllowedDLSize = -10 transferQuota.AllowedULSize = -1 assert.False(t, transferQuota.HasDownloadSpace()) assert.False(t, transferQuota.HasUploadSpace()) conn := NewBaseConnection("", ProtocolSFTP, "", "", user) transfer := NewBaseTransfer(nil, conn, nil, "file.txt", "file.txt", "/transfer_test_file", TransferUpload, 0, 0, 0, 0, true, vfs.NewOsFs("", os.TempDir(), "", nil), dataprovider.TransferQuota{}) err := transfer.CheckRead() assert.NoError(t, err) err = transfer.CheckWrite() assert.NoError(t, err) transfer.transferQuota = dataprovider.TransferQuota{ AllowedTotalSize: 10, } transfer.BytesReceived.Store(5) transfer.BytesSent.Store(4) err = transfer.CheckRead() assert.NoError(t, err) err = transfer.CheckWrite() assert.NoError(t, err) transfer.BytesSent.Store(6) err = transfer.CheckRead() if assert.Error(t, err) { assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) } err = transfer.CheckWrite() assert.True(t, conn.IsQuotaExceededError(err)) transferQuota = dataprovider.TransferQuota{ AllowedTotalSize: 0, AllowedULSize: 10, AllowedDLSize: 5, } transfer.transferQuota = transferQuota assert.Equal(t, transferQuota, transfer.GetTransferQuota()) err = transfer.CheckRead() if assert.Error(t, err) { assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) } err = transfer.CheckWrite() assert.NoError(t, err) transfer.BytesReceived.Store(11) err = transfer.CheckRead() if assert.Error(t, err) { assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) } err = transfer.CheckWrite() assert.True(t, conn.IsQuotaExceededError(err)) err = transfer.Close() assert.NoError(t, err) assert.Len(t, conn.GetTransfers(), 0) assert.Equal(t, int32(0), Connections.GetTotalTransfers()) } func TestUploadOutsideHomeRenameError(t *testing.T) { oldTempPath := Config.TempPath conn := NewBaseConnection("", 
ProtocolSFTP, "", "", dataprovider.User{}) transfer := BaseTransfer{ Connection: conn, transferType: TransferUpload, Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), "", nil), } transfer.BytesReceived.Store(123) fileName := filepath.Join(os.TempDir(), "_temp") err := os.WriteFile(fileName, []byte(`data`), 0644) assert.NoError(t, err) transfer.effectiveFsPath = fileName res := transfer.checkUploadOutsideHomeDir(os.ErrPermission) assert.Equal(t, 0, res) Config.TempPath = filepath.Clean(os.TempDir()) res = transfer.checkUploadOutsideHomeDir(nil) assert.Equal(t, 0, res) assert.Greater(t, transfer.BytesReceived.Load(), int64(0)) res = transfer.checkUploadOutsideHomeDir(os.ErrPermission) assert.Equal(t, 1, res) assert.Equal(t, int64(0), transfer.BytesReceived.Load()) assert.NoFileExists(t, fileName) Config.TempPath = oldTempPath } ================================================ FILE: internal/common/transferschecker.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "errors" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) type overquotaTransfer struct { ConnID string TransferID int64 TransferType int } type uploadAggregationKey struct { Username string FolderName string } // TransfersChecker defines the interface that transfer checkers must implement. // A transfer checker ensure that multiple concurrent transfers does not exceeded // the remaining user quota type TransfersChecker interface { AddTransfer(transfer dataprovider.ActiveTransfer) RemoveTransfer(ID int64, connectionID string) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) GetOverquotaTransfers() []overquotaTransfer } func getTransfersChecker(isShared int) TransfersChecker { if isShared == 1 { logger.Info(logSender, "", "using provider transfer checker") return &transfersCheckerDB{} } logger.Info(logSender, "", "using memory transfer checker") return &transfersCheckerMem{} } type baseTransferChecker struct { transfers []dataprovider.ActiveTransfer } func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize, dlSize int64, ) bool { ulQuota, dlQuota, totalQuota := user.GetDataTransferLimits() if totalQuota > 0 { allowedSize := totalQuota - (user.UsedUploadDataTransfer + user.UsedDownloadDataTransfer) if ulSize+dlSize > allowedSize { return transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0 } } if dlQuota > 0 { allowedSize := dlQuota - user.UsedDownloadDataTransfer if dlSize > allowedSize { return transfer.CurrentDLSize > 0 } } if ulQuota > 0 { allowedSize := ulQuota - user.UsedUploadDataTransfer if ulSize > allowedSize { return transfer.CurrentULSize > 0 } } return false } func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) { var result int64 if folderName != "" { for _, folder := 
range user.VirtualFolders { if folder.Name == folderName { if folder.QuotaSize > 0 { return folder.QuotaSize - folder.UsedQuotaSize, nil } } } } else { if user.QuotaSize > 0 { return user.QuotaSize - user.UsedQuotaSize, nil } } return result, errors.New("no quota limit defined") } func (t *baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool, ) (map[string]bool, map[string][]dataprovider.ActiveTransfer) { aggregations := make(map[string][]dataprovider.ActiveTransfer) for _, transfer := range t.transfers { aggregations[transfer.Username] = append(aggregations[transfer.Username], transfer) if len(aggregations[transfer.Username]) > 1 { if _, ok := usersToFetch[transfer.Username]; !ok { usersToFetch[transfer.Username] = false } } } return usersToFetch, aggregations } func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) { usersToFetch := make(map[string]bool) aggregations := make(map[int][]dataprovider.ActiveTransfer) var keys []uploadAggregationKey for _, transfer := range t.transfers { if transfer.Type != TransferUpload { continue } key := -1 for idx, k := range keys { if k.Username == transfer.Username && k.FolderName == transfer.FolderName { key = idx break } } if key == -1 { key = len(keys) } keys = append(keys, uploadAggregationKey{ Username: transfer.Username, FolderName: transfer.FolderName, }) aggregations[key] = append(aggregations[key], transfer) if len(aggregations[key]) > 1 { if transfer.FolderName != "" { usersToFetch[transfer.Username] = true } else { if _, ok := usersToFetch[transfer.Username]; !ok { usersToFetch[transfer.Username] = false } } } } return usersToFetch, aggregations } func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) { users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) if err != nil { return nil, err } usersMap := make(map[string]dataprovider.User) for _, user := range users { 
usersMap[user.Username] = user } return usersMap, nil } func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool, uploadAggregations map[int][]dataprovider.ActiveTransfer, userAggregations map[string][]dataprovider.ActiveTransfer, ) []overquotaTransfer { if len(usersToFetch) == 0 { return nil } usersMap, err := t.getUsersToCheck(usersToFetch) if err != nil { logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err) return nil } var overquotaTransfers []overquotaTransfer for _, transfers := range uploadAggregations { username := transfers[0].Username folderName := transfers[0].FolderName remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName) if err != nil { continue } var usedDiskQuota int64 for _, tr := range transfers { // We optimistically assume that a cloud transfer that replaces an existing // file will be successful usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize } logger.Debug(logSender, "", "username %q, folder %q, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v", username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota) if usedDiskQuota > remaningDiskQuota { for _, tr := range transfers { if tr.CurrentULSize > tr.TruncatedSize { overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ ConnID: tr.ConnID, TransferID: tr.ID, TransferType: tr.Type, }) } } } } for username, transfers := range userAggregations { var ulSize, dlSize int64 for _, tr := range transfers { ulSize += tr.CurrentULSize dlSize += tr.CurrentDLSize } logger.Debug(logSender, "", "username %q, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v", username, len(transfers), ulSize, dlSize) for _, tr := range transfers { if t.isDataTransferExceeded(usersMap[username], tr, ulSize, dlSize) { overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ ConnID: tr.ConnID, TransferID: tr.ID, 
TransferType: tr.Type, }) } } } return overquotaTransfers } type transfersCheckerMem struct { sync.RWMutex baseTransferChecker } func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) { t.Lock() defer t.Unlock() t.transfers = append(t.transfers, transfer) } func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) { t.Lock() defer t.Unlock() for idx, transfer := range t.transfers { if transfer.ID == ID && transfer.ConnID == connectionID { lastIdx := len(t.transfers) - 1 t.transfers[idx] = t.transfers[lastIdx] t.transfers = t.transfers[:lastIdx] return } } } func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) { t.Lock() defer t.Unlock() for idx := range t.transfers { if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID { t.transfers[idx].CurrentDLSize = dlSize t.transfers[idx].CurrentULSize = ulSize t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) return } } } func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer { t.RLock() usersToFetch, uploadAggregations := t.aggregateUploadTransfers() usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch) t.RUnlock() return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations) } type transfersCheckerDB struct { baseTransferChecker lastCleanup time.Time } func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) { dataprovider.AddActiveTransfer(transfer) } func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) { dataprovider.RemoveActiveTransfer(ID, connectionID) } func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) { dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID) } func (t *transfersCheckerDB) GetOverquotaTransfers() []overquotaTransfer { if t.lastCleanup.IsZero() || 
t.lastCleanup.Add(periodicTimeoutCheckInterval*15).Before(time.Now()) { before := time.Now().Add(-periodicTimeoutCheckInterval * 5) err := dataprovider.CleanupActiveTransfers(before) logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", err) if err == nil { t.lastCleanup = time.Now() } } var err error from := time.Now().Add(-periodicTimeoutCheckInterval * 2) t.transfers, err = dataprovider.GetActiveTransfers(from) if err != nil { logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err) return nil } usersToFetch, uploadAggregations := t.aggregateUploadTransfers() usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch) return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations) } ================================================ FILE: internal/common/transferschecker_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package common import ( "fmt" "os" "path" "path/filepath" "strconv" "strings" "testing" "time" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) func TestTransfersCheckerDiskQuota(t *testing.T) { username := "transfers_check_username" folderName := "test_transfers_folder" groupName := "test_transfers_group" vdirPath := "/vdir" group := dataprovider.Group{ BaseGroup: sdk.BaseGroup{ Name: groupName, }, UserSettings: dataprovider.GroupUserSettings{ BaseGroupUserSettings: sdk.BaseGroupUserSettings{ QuotaSize: 120, }, }, } folder := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: filepath.Join(os.TempDir(), folderName), } user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, Password: "testpwd", HomeDir: filepath.Join(os.TempDir(), username), Status: 1, QuotaSize: 0, // the quota size defined for the group is used Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, VirtualFolders: []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdirPath, QuotaSize: 100, }, }, Groups: []sdk.GroupMapping{ { Name: groupName, Type: sdk.GroupTypePrimary, }, }, } err := dataprovider.AddGroup(&group, "", "", "") assert.NoError(t, err) group, err = dataprovider.GroupExists(groupName) assert.NoError(t, err) err = dataprovider.AddFolder(&folder, "", "", "") assert.NoError(t, err) assert.Equal(t, int64(120), group.UserSettings.QuotaSize) err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) user, err = dataprovider.GetUserWithGroupSettings(username, "") assert.NoError(t, err) connID1 := xid.New().String() fsUser, err := user.GetFilesystemForPath("/file1", connID1) assert.NoError(t, err) conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user) fakeConn1 := &fakeConnection{ BaseConnection: conn1, } transfer1 := 
NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) transfer1.BytesReceived.Store(150) err = Connections.Add(fakeConn1) assert.NoError(t, err) // the transferschecker will do nothing if there is only one ongoing transfer Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) connID2 := xid.New().String() conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user) fakeConn2 := &fakeConnection{ BaseConnection: conn2, } transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{}) transfer1.BytesReceived.Store(50) transfer2.BytesReceived.Store(60) err = Connections.Add(fakeConn2) assert.NoError(t, err) connID3 := xid.New().String() conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user) fakeConn3 := &fakeConnection{ BaseConnection: conn3, } transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download err = Connections.Add(fakeConn3) assert.NoError(t, err) // the transfers are not overquota Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) transfer1.BytesReceived.Store(120) // we are now overquota // if another check is in progress nothing is done Connections.transfersCheckStatus.Store(true) Connections.checkTransfers() assert.Nil(t, 
transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) Connections.transfersCheckStatus.Store(false) Connections.checkTransfers() assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort) assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort), transfer2.errAbort) assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError())) assert.Nil(t, transfer3.errAbort) assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError())) // update the user quota size group.UserSettings.QuotaSize = 1000 err = dataprovider.UpdateGroup(&group, []string{username}, "", "", "") assert.NoError(t, err) transfer1.errAbort = nil transfer2.errAbort = nil Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) group.UserSettings.QuotaSize = 0 err = dataprovider.UpdateGroup(&group, []string{username}, "", "", "") assert.NoError(t, err) Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) // now check a public folder transfer1.BytesReceived.Store(0) transfer2.BytesReceived.Store(0) connID4 := xid.New().String() fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4) assert.NoError(t, err) conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user) fakeConn4 := &fakeConnection{ BaseConnection: conn4, } transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"), filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0, 100, 0, true, fsFolder, dataprovider.TransferQuota{}) err = Connections.Add(fakeConn4) assert.NoError(t, err) connID5 := xid.New().String() conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user) fakeConn5 := &fakeConnection{ BaseConnection: conn5, } transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, 
"file2"), filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0, 100, 0, true, fsFolder, dataprovider.TransferQuota{}) err = Connections.Add(fakeConn5) assert.NoError(t, err) transfer4.BytesReceived.Store(50) transfer5.BytesReceived.Store(40) Connections.checkTransfers() assert.Nil(t, transfer4.errAbort) assert.Nil(t, transfer5.errAbort) transfer5.BytesReceived.Store(60) Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort)) assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort)) if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName { providerConf := dataprovider.GetProviderConfig() err = dataprovider.Close() assert.NoError(t, err) transfer4.errAbort = nil transfer5.errAbort = nil Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer4.errAbort) assert.Nil(t, transfer5.errAbort) err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } err = transfer1.Close() assert.NoError(t, err) err = transfer2.Close() assert.NoError(t, err) err = transfer3.Close() assert.NoError(t, err) err = transfer4.Close() assert.NoError(t, err) err = transfer5.Close() assert.NoError(t, err) Connections.Remove(fakeConn1.GetID()) Connections.Remove(fakeConn2.GetID()) Connections.Remove(fakeConn3.GetID()) Connections.Remove(fakeConn4.GetID()) Connections.Remove(fakeConn5.GetID()) stats := Connections.GetStats("") assert.Len(t, stats, 0) assert.Equal(t, int32(0), Connections.GetTotalTransfers()) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.DeleteFolder(folderName, "", "", "") assert.NoError(t, err) err = 
os.RemoveAll(filepath.Join(os.TempDir(), folderName))
	assert.NoError(t, err)
	err = dataprovider.DeleteGroup(groupName, "", "", "")
	assert.NoError(t, err)
}

// TestTransferCheckerTransferQuota checks that Connections.checkTransfers aborts
// concurrent uploads and downloads of the same user once their combined size
// exceeds the user's data transfer limit (TotalDataTransfer: 1; judging by the
// 1024*1024 thresholds used below the unit appears to be MB — confirm in dataprovider).
func TestTransferCheckerTransferQuota(t *testing.T) {
	username := "transfers_check_username"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:          username,
			Password:          "test_pwd",
			HomeDir:           filepath.Join(os.TempDir(), username),
			Status:            1,
			TotalDataTransfer: 1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	err := dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	connID1 := xid.New().String()
	fsUser, err := user.GetFilesystemForPath("/file1", connID1)
	assert.NoError(t, err)
	conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "192.168.1.1", user)
	fakeConn1 := &fakeConnection{
		BaseConnection: conn1,
	}
	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
		"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
	transfer1.BytesReceived.Store(150)
	err = Connections.Add(fakeConn1)
	assert.NoError(t, err)
	// the transfers checker does nothing while there is only one ongoing transfer
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)

	connID2 := xid.New().String()
	conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "127.0.0.1", user)
	fakeConn2 := &fakeConnection{
		BaseConnection: conn2,
	}
	transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
		"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
	transfer2.BytesReceived.Store(150)
	err = Connections.Add(fakeConn2)
	assert.NoError(t, err)
	// two small uploads stay well below the user limit: nothing is aborted
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	// now test overquota: only transfer1 carries data beyond the limit, so only it is aborted
	transfer1.BytesReceived.Store(1024*1024 + 1)
	transfer2.BytesReceived.Store(0)
	Connections.checkTransfers()
	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	transfer1.errAbort = nil
	// once transfer2 has received data too, both uploads are aborted
	transfer1.BytesReceived.Store(1024*1024 + 1)
	transfer2.BytesReceived.Store(1024)
	Connections.checkTransfers()
	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
	assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
	// reset counters and abort errors before closing the upload transfers
	transfer1.BytesReceived.Store(0)
	transfer2.BytesReceived.Store(0)
	transfer1.errAbort = nil
	transfer2.errAbort = nil

	err = transfer1.Close()
	assert.NoError(t, err)
	err = transfer2.Close()
	assert.NoError(t, err)
	Connections.Remove(fakeConn1.GetID())
	Connections.Remove(fakeConn2.GetID())

	// same checks for downloads
	connID3 := xid.New().String()
	conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
	fakeConn3 := &fakeConnection{
		BaseConnection: conn3,
	}
	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
		"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
	transfer3.BytesSent.Store(150)
	err = Connections.Add(fakeConn3)
	assert.NoError(t, err)

	connID4 := xid.New().String()
	conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
	fakeConn4 := &fakeConnection{
		BaseConnection: conn4,
	}
	transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
		"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
	transfer4.BytesSent.Store(150)
	err = Connections.Add(fakeConn4)
	assert.NoError(t, err)

	Connections.checkTransfers()
	assert.Nil(t, transfer3.errAbort)
	assert.Nil(t, transfer4.errAbort)

	// 512 KiB + (512 KiB + 1 byte) is one byte over 1 MiB: both downloads are
	// aborted with a read quota error
	transfer3.BytesSent.Store(512 * 1024)
	transfer4.BytesSent.Store(512*1024 + 1)
	Connections.checkTransfers()
	if assert.Error(t, transfer3.errAbort) {
		assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error())
	}
	if assert.Error(t, transfer4.errAbort) {
		assert.Contains(t, transfer4.errAbort.Error(), ErrReadQuotaExceeded.Error())
	}

	err = transfer3.Close()
	assert.NoError(t, err)
	err = transfer4.Close()
	assert.NoError(t, err)

	Connections.Remove(fakeConn3.GetID())
	Connections.Remove(fakeConn4.GetID())
	stats := Connections.GetStats("")
	assert.Len(t, stats, 0)
	assert.Equal(t, int32(0), Connections.GetTotalTransfers())

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestAggregateTransfers checks how transfersCheckerMem.aggregateUploadTransfers
// groups ongoing transfers. As asserted below:
//   - download transfers do not add an aggregation;
//   - each distinct username/folder pair of uploads gets its own aggregation;
//   - a user is reported in usersToFetch only once one of its aggregations holds
//     more than one upload;
//   - the reported bool becomes true once a folder aggregation of that user also
//     holds more than one upload.
func TestAggregateTransfers(t *testing.T) {
	checker := transfersCheckerMem{}
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "1",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})
	usersToFetch, aggregations := checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 1)

	// a download for the same user does not change the upload aggregations
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferDownload,
		ConnID:        "2",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 0,
		CurrentDLSize: 100,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 1)

	// an upload to a virtual folder is aggregated separately
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "3",
		Username:      "user",
		FolderName:    "folder",
		TruncatedSize: 0,
		CurrentULSize: 10,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 2)

	// an upload from a different user gets its own aggregation
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "4",
		Username:      "user1",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 3)

	// a second upload for "user" outside folders: the user must now be fetched
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "5",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok := usersToFetch["user"]
	assert.True(t, ok)
	assert.False(t, val)
	assert.Len(t, aggregations, 3)
	aggregate, ok := aggregations[0]
	assert.True(t, ok)
	assert.Len(t, aggregate, 2)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "6",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok = usersToFetch["user"]
	assert.True(t, ok)
	assert.False(t, val)
	assert.Len(t, aggregations, 3)
	aggregate, ok = aggregations[0]
	assert.True(t, ok)
	assert.Len(t, aggregate, 3)

	// a second upload to "folder": the bool for "user" flips to true
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "7",
		Username:      "user",
		FolderName:    "folder",
		TruncatedSize: 0,
		CurrentULSize: 10,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok = usersToFetch["user"]
	assert.True(t, ok)
	assert.True(t, val)
	assert.Len(t, aggregations, 3)
	aggregate, ok = aggregations[0]
	assert.True(t, ok)
	assert.Len(t, aggregate, 3)
	aggregate, ok = aggregations[1]
	assert.True(t, ok)
	assert.Len(t, aggregate, 2)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "8",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateUploadTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok = usersToFetch["user"]
	assert.True(t, ok)
	assert.True(t, val)
	assert.Len(t, aggregations, 3)
	aggregate, ok = aggregations[0]
	assert.True(t, ok)
	assert.Len(t, aggregate, 4)
	aggregate, ok = aggregations[1]
	assert.True(t, ok)
	assert.Len(t, aggregate, 2)
}

// TestDataTransferExceeded exercises transfersCheckerMem.isDataTransferExceeded with
// a total limit, a download-only limit and an upload-only limit.
// The last two arguments appear to be the aggregated upload and download sizes
// (inferred from how they are paired with the limits below — verify in the checker).
// Per the assertions, a transfer whose current size in the limited direction is zero
// is not reported even when the stored usage already exceeds the limit.
func TestDataTransferExceeded(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			TotalDataTransfer: 1,
		},
	}
	transfer := dataprovider.ActiveTransfer{
		CurrentULSize: 0,
		CurrentDLSize: 0,
	}
	user.UsedDownloadDataTransfer = 1024 * 1024
	user.UsedUploadDataTransfer = 512 * 1024
	checker := transfersCheckerMem{}
	res := checker.isDataTransferExceeded(user, transfer, 100, 100)
	assert.False(t, res)
	transfer.CurrentULSize = 1
	res = checker.isDataTransferExceeded(user, transfer, 100, 100)
	assert.True(t, res)
	// used + aggregated sizes add up to exactly 1 MiB: not exceeded; one more byte is
	user.UsedDownloadDataTransfer = 512*1024 - 100
	user.UsedUploadDataTransfer = 512*1024 - 100
	res = checker.isDataTransferExceeded(user, transfer, 100, 100)
	assert.False(t, res)
	res = checker.isDataTransferExceeded(user, transfer, 101, 100)
	assert.True(t, res)

	// download-only limit
	user.TotalDataTransfer = 0
	user.DownloadDataTransfer = 1
	user.UsedDownloadDataTransfer = 512 * 1024
	transfer.CurrentULSize = 0
	transfer.CurrentDLSize = 100
	res = checker.isDataTransferExceeded(user, transfer, 0, 512*1024)
	assert.False(t, res)
	res = checker.isDataTransferExceeded(user, transfer, 0, 512*1024+1)
	assert.True(t, res)

	// upload-only limit
	user.DownloadDataTransfer = 0
	user.UploadDataTransfer = 1
	user.UsedUploadDataTransfer = 512 * 1024
	transfer.CurrentULSize = 0
	transfer.CurrentDLSize = 0
	res = checker.isDataTransferExceeded(user, transfer, 512*1024+1, 0)
	assert.False(t, res)
	transfer.CurrentULSize = 1
	res = checker.isDataTransferExceeded(user, transfer, 512*1024+1, 0)
	assert.True(t, res)
}

// TestGetUsersForQuotaCheck checks dataprovider.GetUsersForQuotaCheck. Only existing
// users are returned (70 are requested, 60 are created). Users mapped to true get
// their virtual folder with its quota usage loaded. With the SQL-based providers,
// users mapped to false get no virtual folders. All returned users report zero
// data transfer limits.
func TestGetUsersForQuotaCheck(t *testing.T) {
	usersToFetch := make(map[string]bool)
	for i := 0; i < 70; i++ {
		usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0
	}

	users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
	assert.NoError(t, err)
	assert.Len(t, users, 0)

	for i := 0; i < 60; i++ {
		folder := vfs.BaseVirtualFolder{
			Name:       fmt.Sprintf("f%v", i),
			MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)),
		}
		user := dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username:  fmt.Sprintf("user%v", i),
				Password:  "pwd",
				HomeDir:   filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)),
				Status:    1,
				QuotaSize: 120,
				Permissions: map[string][]string{
					"/": {dataprovider.PermAny},
				},
			},
			VirtualFolders: []vfs.VirtualFolder{
				{
					BaseVirtualFolder: vfs.BaseVirtualFolder{
						Name: folder.Name,
					},
					VirtualPath: "/vfolder",
					QuotaSize:   100,
				},
			},
		}
		err = dataprovider.AddFolder(&folder, "", "", "")
		assert.NoError(t, err)
		err = dataprovider.AddUser(&user, "", "", "")
		assert.NoError(t, err)
		// record 50 used bytes for the folder so the loaded quota usage can be verified below
		err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false)
		assert.NoError(t, err)
	}

	users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
	assert.NoError(t, err)
	assert.Len(t, users, 60)

	for _, user := range users {
		userIdxStr := strings.Replace(user.Username, "user", "", 1)
		userIdx, err := strconv.Atoi(userIdxStr)
		assert.NoError(t, err)
		if userIdx%2 == 0 {
			if assert.Len(t, user.VirtualFolders, 1, user.Username) {
				assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize)
				assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize)
			}
		} else {
			// only the SQL-based providers are asserted to skip folders here
			switch dataprovider.GetProviderStatus().Driver {
			case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
				dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
				assert.Len(t, user.VirtualFolders, 0, user.Username)
			}
		}
		ul, dl, total := user.GetDataTransferLimits()
		assert.Equal(t, int64(0), ul)
		assert.Equal(t, int64(0), dl)
		assert.Equal(t, int64(0), total)
	}

	for i := 0; i < 60; i++ {
		err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "", "")
		assert.NoError(t, err)
		err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "", "")
		assert.NoError(t, err)
	}

	users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
	assert.NoError(t, err)
	assert.Len(t, users, 0)
}

// TestDBTransferChecker reinitializes the data provider in shared mode (IsShared = 1).
// It then checks that the DB-backed transfers checker persists, updates and removes
// active transfers. The provider is reinitialized with IsShared = 0 at the end.
func TestDBTransferChecker(t *testing.T) {
	if !isDbTransferCheckerSupported() {
		t.Skip("this test is not supported with the current database provider")
	}
	providerConf := dataprovider.GetProviderConfig()
	err := dataprovider.Close()
	assert.NoError(t, err)
	providerConf.IsShared = 1
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	c := getTransfersChecker(1)
	checker, ok := c.(*transfersCheckerDB)
	assert.True(t, ok)
	assert.True(t, checker.lastCleanup.IsZero())
	transfer1 := dataprovider.ActiveTransfer{
		ID:         1,
		Type:       TransferDownload,
		ConnID:     xid.New().String(),
		Username:   "user1",
		FolderName: "folder1",
		IP:         "127.0.0.1",
	}
	checker.AddTransfer(transfer1)
	// presumably GetActiveTransfers only returns transfers updated after the given
	// time, so a cutoff in the future matches nothing — confirm in dataprovider
	transfers, err := dataprovider.GetActiveTransfers(time.Now().Add(24 * time.Hour))
	assert.NoError(t, err)
	assert.Len(t, transfers, 0)
	transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2))
	assert.NoError(t, err)
	var createdAt, updatedAt int64
	if assert.Len(t, transfers, 1) {
		transfer := transfers[0]
		assert.Equal(t, transfer1.ID, transfer.ID)
		assert.Equal(t, transfer1.Type, transfer.Type)
		assert.Equal(t, transfer1.ConnID, transfer.ConnID)
		assert.Equal(t, transfer1.Username, transfer.Username)
		assert.Equal(t, transfer1.IP, transfer.IP)
		assert.Equal(t, transfer1.FolderName, transfer.FolderName)
		assert.Greater(t, transfer.CreatedAt, int64(0))
		assert.Greater(t, transfer.UpdatedAt, int64(0))
		assert.Equal(t, int64(0), transfer.CurrentDLSize)
		assert.Equal(t, int64(0), transfer.CurrentULSize)
		createdAt = transfer.CreatedAt
		updatedAt = transfer.UpdatedAt
	}
	// wait so the updated timestamp checked below can be strictly greater
	time.Sleep(100 * time.Millisecond)
	// upload size 100, download size 150 (see the assertions below)
	checker.UpdateTransferCurrentSizes(100, 150, transfer1.ID, transfer1.ConnID)
	transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2))
	assert.NoError(t, err)
	if assert.Len(t, transfers, 1) {
		transfer := transfers[0]
		assert.Equal(t, int64(150), transfer.CurrentDLSize)
		assert.Equal(t, int64(100), transfer.CurrentULSize)
		assert.Equal(t, createdAt, transfer.CreatedAt)
		assert.Greater(t, transfer.UpdatedAt, updatedAt)
	}
	res := checker.GetOverquotaTransfers()
	assert.Len(t, res, 0)

	checker.RemoveTransfer(transfer1.ID, transfer1.ConnID)
	transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2))
	assert.NoError(t, err)
	assert.Len(t, transfers, 0)

	// with the provider closed, GetOverquotaTransfers still returns an empty result
	err = dataprovider.Close()
	assert.NoError(t, err)
	res = checker.GetOverquotaTransfers()
	assert.Len(t, res, 0)
	providerConf.IsShared = 0
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// isDbTransferCheckerSupported reports whether the current data provider driver is
// one of the SQL-based drivers that TestDBTransferChecker can run against.
func isDbTransferCheckerSupported() bool {
	// SQLite shares the implementation with the other SQL-based providers, but it
	// makes no sense to use it outside of test cases
	switch dataprovider.GetProviderStatus().Driver {
	case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
		dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
		return true
	default:
		return false
	}
}

================================================
FILE: internal/config/config.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

// Package config manages the configuration
package config

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"

	kmsplugin "github.com/sftpgo/sdk/plugin/kms"
	"github.com/spf13/viper"
	"github.com/subosito/gotenv"

	"github.com/drakkan/sftpgo/v2/internal/acme"
	"github.com/drakkan/sftpgo/v2/internal/command"
	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/ftpd"
	"github.com/drakkan/sftpgo/v2/internal/httpclient"
	"github.com/drakkan/sftpgo/v2/internal/httpd"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/mfa"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/sftpd"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/telemetry"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/webdavd"
)

const (
	// logSender presumably tags log entries emitted by this package (its uses are
	// outside this view)
	logSender = "config"
	// configName defines the name for config file.
	// This name does not include the extension, viper will search for files
	// with supported extensions such as "sftpgo.json", "sftpgo.yaml" and so on
	configName = "sftpgo"
	// configEnvPrefix is the prefix for environment variables that override
	// configuration keys; Init passes it to viper.SetEnvPrefix
	configEnvPrefix = "sftpgo"
	// envFileMaxSize is 1 MiB; presumably the maximum accepted size, in bytes, of an
	// env file loaded via gotenv — its use is outside this view, confirm there
	envFileMaxSize = 1048576
)

var (
	// globalConf holds the active configuration; Init resets it to the defaults
	// and the Get*/Set* helpers read and replace its sections
	globalConf globalConfig
	// defaultInstallCodeHint is the default httpd Setup.InstallationCodeHint set by Init
	defaultInstallCodeHint = "Installation code"
	// defaultSFTPDBinding is the single SFTP binding in the defaults built by Init
	defaultSFTPDBinding = sftpd.Binding{
		Address:          "",
		Port:             2022,
		ApplyProxyConfig: true,
	}
	// defaultFTPDBinding is the single FTP binding in the defaults built by Init
	defaultFTPDBinding = ftpd.Binding{
		Address:                    "",
		Port:                       0,
		ApplyProxyConfig:           true,
		TLSMode:                    0,
		CertificateFile:            "",
		CertificateKeyFile:         "",
		MinTLSVersion:              12,
		ForcePassiveIP:             "",
		PassiveIPOverrides:         nil,
		PassiveHost:                "",
		ClientAuthType:             0,
		TLSCipherSuites:            nil,
		PassiveConnectionsSecurity: 0,
		ActiveConnectionsSecurity:  0,
		Debug:                      false,
	}
	// defaultWebDAVDBinding is the single WebDAV binding in the defaults built by Init
	defaultWebDAVDBinding = webdavd.Binding{
		Address:              "",
		Port:                 0,
		EnableHTTPS:          false,
		CertificateFile:      "",
		CertificateKeyFile:   "",
		MinTLSVersion:        12,
		ClientAuthType:       0,
		TLSCipherSuites:      nil,
		Protocols:            nil,
		Prefix:               "",
		ProxyMode:            0,
		ProxyAllowed:         nil,
		ClientIPProxyHeader:  "",
		ClientIPHeaderDepth:  0,
		DisableWWWAuthHeader: false,
	}
	// defaultHTTPDBinding is the single HTTP binding in the defaults built by Init
	defaultHTTPDBinding = httpd.Binding{
		Address:              "",
		Port:                 8080,
		EnableWebAdmin:       true,
		EnableWebClient:      true,
		EnableRESTAPI:        true,
		EnabledLoginMethods:  0,
		DisabledLoginMethods: 0,
		EnableHTTPS:          false,
		CertificateFile:      "",
		CertificateKeyFile:   "",
		MinTLSVersion:        12,
		ClientAuthType:       0,
		TLSCipherSuites:      nil,
		Protocols:            nil,
		ProxyMode:            0,
		ProxyAllowed:         nil,
		ClientIPProxyHeader:  "",
		ClientIPHeaderDepth:  0,
		HideLoginURL:         0,
		RenderOpenAPI:        true,
		BaseURL:              "",
		Languages:            []string{"en"},
		OIDC: httpd.OIDC{
			ClientID:                   "",
			ClientSecret:               "",
			ClientSecretFile:           "",
			ConfigURL:                  "",
			RedirectBaseURL:            "",
			UsernameField:              "",
			RoleField:                  "",
			ImplicitRoles:              false,
			Scopes:                     []string{"openid", "profile", "email"},
			CustomFields:               []string{},
			InsecureSkipSignatureCheck: false,
			Debug:                      false,
		},
		Security: httpd.SecurityConf{
			Enabled:                   false,
			AllowedHosts:              nil,
			AllowedHostsAreRegex:      false,
			HostsProxyHeaders:         nil,
			HTTPSRedirect:             false,
			HTTPSHost:                 "",
			HTTPSProxyHeaders:         nil,
			STSSeconds:                0,
			STSIncludeSubdomains:      false,
			STSPreload:                false,
			ContentTypeNosniff:        false,
			ContentSecurityPolicy:     "",
			PermissionsPolicy:         "",
			CrossOriginOpenerPolicy:   "",
			CrossOriginResourcePolicy: "",
			CrossOriginEmbedderPolicy: "",
			CacheControl:              "",
		},
		Branding: httpd.Branding{},
	}
	// defaultRateLimiter is the single entry of Common.RateLimitersConfig set by Init
	defaultRateLimiter = common.RateLimiterConfig{
		Average:                0,
		Period:                 1000,
		Burst:                  1,
		Type:                   2,
		Protocols:              []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV, common.ProtocolHTTP},
		GenerateDefenderEvents: false,
		EntriesSoftLimit:       100,
		EntriesHardLimit:       150,
	}
	// defaultTOTP is the single TOTP configuration in MFAConfig set by Init
	defaultTOTP = mfa.TOTPConfig{
		Name:   "Default",
		Issuer: "SFTPGo",
		Algo:   mfa.TOTPAlgoSHA1,
	}
)

// globalConfig is the root of the configuration. The json/mapstructure tags define
// the top-level keys (for example "data_provider" for ProviderConf).
type globalConfig struct {
	Common          common.Configuration  `json:"common" mapstructure:"common"`
	ACME            acme.Configuration    `json:"acme" mapstructure:"acme"`
	SFTPD           sftpd.Configuration   `json:"sftpd" mapstructure:"sftpd"`
	FTPD            ftpd.Configuration    `json:"ftpd" mapstructure:"ftpd"`
	WebDAVD         webdavd.Configuration `json:"webdavd" mapstructure:"webdavd"`
	ProviderConf    dataprovider.Config   `json:"data_provider" mapstructure:"data_provider"`
	HTTPDConfig     httpd.Conf            `json:"httpd" mapstructure:"httpd"`
	HTTPConfig      httpclient.Config     `json:"http" mapstructure:"http"`
	CommandConfig   command.Config        `json:"command" mapstructure:"command"`
	KMSConfig       kms.Configuration     `json:"kms" mapstructure:"kms"`
	MFAConfig       mfa.Config            `json:"mfa" mapstructure:"mfa"`
	TelemetryConfig telemetry.Conf        `json:"telemetry" mapstructure:"telemetry"`
	PluginsConfig   []plugin.Config       `json:"plugins" mapstructure:"plugins"`
	SMTPConfig      smtp.Config           `json:"smtp" mapstructure:"smtp"`
}

// init populates globalConf with the defaults as soon as the package is loaded.
func init() {
	Init()
}

// Init initializes the global configuration.
// It is not supposed to be called outside of this package.
// It is exported to minimize refactoring efforts. Will eventually disappear.
//
// It resets globalConf to the built-in defaults and configures viper: the config
// name, the environment variable prefix and key mapping, and the defaults
// registered by setViperDefaults.
func Init() {
	// create a default configuration to use if no config file is provided
	globalConf = globalConfig{
		Common: common.Configuration{
			IdleTimeout: 15,
			UploadMode:  0,
			Actions: common.ProtocolActions{
				ExecuteOn:   []string{},
				ExecuteSync: []string{},
				Hook:        "",
			},
			SetstatMode:           0,
			RenameMode:            0,
			ResumeMaxSize:         0,
			TempPath:              "",
			ProxyProtocol:         0,
			ProxyAllowed:          []string{},
			ProxySkipped:          []string{},
			PostConnectHook:       "",
			PostDisconnectHook:    "",
			MaxTotalConnections:   0,
			MaxPerHostConnections: 20,
			AllowListStatus:       0,
			AllowSelfConnections:  0,
			DefenderConfig: common.DefenderConfig{
				Enabled:            false,
				Driver:             common.DefenderDriverMemory,
				BanTime:            30,
				BanTimeIncrement:   50,
				Threshold:          15,
				ScoreInvalid:       2,
				ScoreValid:         1,
				ScoreLimitExceeded: 3,
				ScoreNoAuth:        0,
				ObservationTime:    30,
				EntriesSoftLimit:   100,
				EntriesHardLimit:   150,
				LoginDelay: common.LoginDelay{
					Success:        0,
					PasswordFailed: 1000,
				},
			},
			RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter},
			Umask:              "",
			ServerVersion:      "",
			TZ:                 "",
			Metadata: common.MetadataConfig{
				Read: 0,
			},
			EventManager: common.EventManagerConfig{
				EnabledCommands: []string{},
			},
		},
		ACME: acme.Configuration{
			Email:      "",
			KeyType:    "4096",
			CertsPath:  "certs",
			CAEndpoint: "https://acme-v02.api.letsencrypt.org/directory",
			Domains:    []string{},
			RenewDays:  30,
			HTTP01Challenge: acme.HTTP01Challenge{
				Port:        80,
				WebRoot:     "",
				ProxyHeader: "",
			},
			TLSALPN01Challenge: acme.TLSALPN01Challenge{
				Port: 0,
			},
		},
		SFTPD: sftpd.Configuration{
			Bindings:                          []sftpd.Binding{defaultSFTPDBinding},
			MaxAuthTries:                      0,
			HostKeys:                          []string{},
			HostCertificates:                  []string{},
			HostKeyAlgorithms:                 []string{},
			KexAlgorithms:                     []string{},
			Ciphers:                           []string{},
			MACs:                              []string{},
			PublicKeyAlgorithms:               []string{},
			TrustedUserCAKeys:                 []string{},
			RevokedUserCertsFile:              "",
			OPKSSHPath:                        "",
			OPKSSHChecksum:                    "",
			LoginBannerFile:                   "",
			EnabledSSHCommands:                []string{},
			KeyboardInteractiveAuthentication: true,
			KeyboardInteractiveHook:           "",
			PasswordAuthentication:            true,
		},
		FTPD: ftpd.Configuration{
			Bindings:                 []ftpd.Binding{defaultFTPDBinding},
			BannerFile:               "",
			ActiveTransfersPortNon20: true,
			PassivePortRange: ftpd.PortRange{
				Start: 50000,
				End:   50100,
			},
			DisableActiveMode:  false,
			EnableSite:         false,
			HASHSupport:        0,
			CombineSupport:     0,
			CertificateFile:    "",
			CertificateKeyFile: "",
			CACertificates:     []string{},
			CARevocationLists:  []string{},
		},
		WebDAVD: webdavd.Configuration{
			Bindings:           []webdavd.Binding{defaultWebDAVDBinding},
			CertificateFile:    "",
			CertificateKeyFile: "",
			CACertificates:     []string{},
			CARevocationLists:  []string{},
			Cors: webdavd.CorsConfig{
				Enabled:              false,
				AllowedOrigins:       []string{},
				AllowedMethods:       []string{},
				AllowedHeaders:       []string{},
				ExposedHeaders:       []string{},
				AllowCredentials:     false,
				MaxAge:               0,
				OptionsPassthrough:   false,
				OptionsSuccessStatus: 0,
				AllowPrivateNetwork:  false,
			},
			Cache: webdavd.Cache{
				Users: webdavd.UsersCacheConfig{
					ExpirationTime: 0,
					MaxSize:        50,
				},
				MimeTypes: webdavd.MimeCacheConfig{
					Enabled:        true,
					MaxSize:        1000,
					CustomMappings: nil,
				},
			},
		},
		ProviderConf: dataprovider.Config{
			Driver:             "sqlite",
			Name:               "sftpgo.db",
			Host:               "",
			Port:               0,
			Username:           "",
			Password:           "",
			ConnectionString:   "",
			SQLTablesPrefix:    "",
			SSLMode:            0,
			DisableSNI:         false,
			TargetSessionAttrs: "",
			RootCert:           "",
			ClientCert:         "",
			ClientKey:          "",
			TrackQuota:         2,
			PoolSize:           0,
			UsersBaseDir:       "",
			Actions: dataprovider.ObjectsActions{
				ExecuteOn:  []string{},
				ExecuteFor: []string{},
				Hook:       "",
			},
			ExternalAuthHook:   "",
			ExternalAuthScope:  0,
			PreLoginHook:       "",
			PostLoginHook:      "",
			PostLoginScope:     0,
			CheckPasswordHook:  "",
			CheckPasswordScope: 0,
			PasswordHashing: dataprovider.PasswordHashing{
				Argon2Options: dataprovider.Argon2Options{
					Memory:      65536,
					Iterations:  1,
					Parallelism: 2,
				},
				BcryptOptions: dataprovider.BcryptOptions{
					Cost: 10,
				},
				Algo: dataprovider.HashingAlgoBcrypt,
			},
			PasswordValidation: dataprovider.PasswordValidation{
				Admins: dataprovider.PasswordValidationRules{
					MinEntropy: 0,
				},
				Users: dataprovider.PasswordValidationRules{
					MinEntropy: 0,
				},
			},
			PasswordCaching:    true,
			UpdateMode:         0,
			DelayedQuotaUpdate: 0,
			CreateDefaultAdmin: false,
			NamingRules:        1,
			IsShared:           0,
			Node: dataprovider.NodeConfig{
				Host:  "",
				Port:  0,
				Proto: "http",
			},
			BackupsPath: "backups",
		},
		HTTPDConfig: httpd.Conf{
			Bindings:              []httpd.Binding{defaultHTTPDBinding},
			TemplatesPath:         "templates",
			StaticFilesPath:       "static",
			OpenAPIPath:           "openapi",
			WebRoot:               "",
			CertificateFile:       "",
			CertificateKeyFile:    "",
			CACertificates:        nil,
			CARevocationLists:     nil,
			SigningPassphrase:     "",
			SigningPassphraseFile: "",
			TokenValidation:       0,
			CookieLifetime:        20,
			ShareCookieLifetime:   120,
			JWTLifetime:           20,
			MaxUploadFileSize:     0,
			Cors: httpd.CorsConfig{
				Enabled:              false,
				AllowedOrigins:       []string{},
				AllowedMethods:       []string{},
				AllowedHeaders:       []string{},
				ExposedHeaders:       []string{},
				AllowCredentials:     false,
				MaxAge:               0,
				OptionsPassthrough:   false,
				OptionsSuccessStatus: 0,
				AllowPrivateNetwork:  false,
			},
			Setup: httpd.SetupConfig{
				InstallationCode:     "",
				InstallationCodeHint: defaultInstallCodeHint,
			},
			HideSupportLink: false,
		},
		HTTPConfig: httpclient.Config{
			Timeout:        20,
			RetryWaitMin:   2,
			RetryWaitMax:   30,
			RetryMax:       3,
			CACertificates: nil,
			Certificates:   nil,
			SkipTLSVerify:  false,
			Headers:        nil,
		},
		CommandConfig: command.Config{
			Timeout:  30,
			Env:      nil,
			Commands: nil,
		},
		KMSConfig: kms.Configuration{
			Secrets: kms.Secrets{
				URL:             "",
				MasterKeyString: "",
				MasterKeyPath:   "",
			},
		},
		MFAConfig: mfa.Config{
			TOTP: []mfa.TOTPConfig{defaultTOTP},
		},
		TelemetryConfig: telemetry.Conf{
			BindPort:           0,
			BindAddress:        "127.0.0.1",
			EnableProfiler:     false,
			AuthUserFile:       "",
			CertificateFile:    "",
			CertificateKeyFile: "",
			MinTLSVersion:      12,
			TLSCipherSuites:    nil,
			Protocols:          nil,
		},
		SMTPConfig: smtp.Config{
			Host:          "",
			Port:          587,
			From:          "",
			User:          "",
			Password:      "",
			AuthType:      0,
			Encryption:    0,
			Domain:        "",
			TemplatesPath: "templates",
		},
		PluginsConfig: nil,
	}

	// environment variables are named <PREFIX>_<KEY>, with "." in nested keys
	// replaced by "__" (e.g. the key common.idle_timeout maps to
	// SFTPGO_COMMON__IDLE_TIMEOUT, viper upper-cases env keys)
	viper.SetEnvPrefix(configEnvPrefix)
	replacer := strings.NewReplacer(".", "__")
	viper.SetEnvKeyReplacer(replacer)
	viper.SetConfigName(configName)
	// setViperDefaults is defined elsewhere in this file; presumably it registers
	// the defaults above with viper — confirm there
	setViperDefaults()
	viper.AutomaticEnv()
	// treat set-but-empty environment variables as valid values instead of ignoring them
	viper.AllowEmptyEnv(true)
}
GetCommonConfig returns the common protocols configuration func GetCommonConfig() common.Configuration { return globalConf.Common } // SetCommonConfig sets the common protocols configuration func SetCommonConfig(config common.Configuration) { globalConf.Common = config } // GetSFTPDConfig returns the configuration for the SFTP server func GetSFTPDConfig() sftpd.Configuration { return globalConf.SFTPD } // SetSFTPDConfig sets the configuration for the SFTP server func SetSFTPDConfig(config sftpd.Configuration) { globalConf.SFTPD = config } // GetFTPDConfig returns the configuration for the FTP server func GetFTPDConfig() ftpd.Configuration { return globalConf.FTPD } // SetFTPDConfig sets the configuration for the FTP server func SetFTPDConfig(config ftpd.Configuration) { globalConf.FTPD = config } // GetWebDAVDConfig returns the configuration for the WebDAV server func GetWebDAVDConfig() webdavd.Configuration { return globalConf.WebDAVD } // SetWebDAVDConfig sets the configuration for the WebDAV server func SetWebDAVDConfig(config webdavd.Configuration) { globalConf.WebDAVD = config } // GetHTTPDConfig returns the configuration for the HTTP server func GetHTTPDConfig() httpd.Conf { return globalConf.HTTPDConfig } // SetHTTPDConfig sets the configuration for the HTTP server func SetHTTPDConfig(config httpd.Conf) { globalConf.HTTPDConfig = config } // GetProviderConf returns the configuration for the data provider func GetProviderConf() dataprovider.Config { return globalConf.ProviderConf } // SetProviderConf sets the configuration for the data provider func SetProviderConf(config dataprovider.Config) { globalConf.ProviderConf = config } // GetHTTPConfig returns the configuration for HTTP clients func GetHTTPConfig() httpclient.Config { return globalConf.HTTPConfig } // GetCommandConfig returns the configuration for external commands func GetCommandConfig() command.Config { return globalConf.CommandConfig } // GetKMSConfig returns the KMS configuration func 
GetKMSConfig() kms.Configuration { return globalConf.KMSConfig } // SetKMSConfig sets the kms configuration func SetKMSConfig(config kms.Configuration) { globalConf.KMSConfig = config } // GetTelemetryConfig returns the telemetry configuration func GetTelemetryConfig() telemetry.Conf { return globalConf.TelemetryConfig } // SetTelemetryConfig sets the telemetry configuration func SetTelemetryConfig(config telemetry.Conf) { globalConf.TelemetryConfig = config } // GetPluginsConfig returns the plugins configuration func GetPluginsConfig() []plugin.Config { return globalConf.PluginsConfig } // SetPluginsConfig sets the plugin configuration func SetPluginsConfig(config []plugin.Config) { globalConf.PluginsConfig = config } // HasKMSPlugin returns true if at least one KMS plugin is configured. func HasKMSPlugin() bool { for _, c := range globalConf.PluginsConfig { if c.Type == kmsplugin.PluginName { return true } } return false } // GetMFAConfig returns multi-factor authentication config func GetMFAConfig() mfa.Config { return globalConf.MFAConfig } // GetSMTPConfig returns the SMTP configuration func GetSMTPConfig() smtp.Config { return globalConf.SMTPConfig } // GetACMEConfig returns the ACME configuration func GetACMEConfig() acme.Configuration { return globalConf.ACME } // HasServicesToStart returns true if the config defines at least a service to start. 
// Supported services are SFTP, FTP and WebDAV func HasServicesToStart() bool { if globalConf.SFTPD.ShouldBind() { return true } if globalConf.FTPD.ShouldBind() { return true } if globalConf.WebDAVD.ShouldBind() { return true } if globalConf.HTTPDConfig.ShouldBind() { return true } return false } func getRedactedPassword(value string) string { if value == "" { return value } return "[redacted]" } func getRedactedGlobalConf() globalConfig { conf := globalConf conf.Common.Actions.Hook = util.GetRedactedURL(conf.Common.Actions.Hook) conf.Common.StartupHook = util.GetRedactedURL(conf.Common.StartupHook) conf.Common.PostConnectHook = util.GetRedactedURL(conf.Common.PostConnectHook) conf.Common.PostDisconnectHook = util.GetRedactedURL(conf.Common.PostDisconnectHook) conf.SFTPD.KeyboardInteractiveHook = util.GetRedactedURL(conf.SFTPD.KeyboardInteractiveHook) conf.HTTPDConfig.SigningPassphrase = getRedactedPassword(conf.HTTPDConfig.SigningPassphrase) conf.HTTPDConfig.Setup.InstallationCode = getRedactedPassword(conf.HTTPDConfig.Setup.InstallationCode) conf.ProviderConf.Password = getRedactedPassword(conf.ProviderConf.Password) conf.ProviderConf.Actions.Hook = util.GetRedactedURL(conf.ProviderConf.Actions.Hook) conf.ProviderConf.ExternalAuthHook = util.GetRedactedURL(conf.ProviderConf.ExternalAuthHook) conf.ProviderConf.PreLoginHook = util.GetRedactedURL(conf.ProviderConf.PreLoginHook) conf.ProviderConf.PostLoginHook = util.GetRedactedURL(conf.ProviderConf.PostLoginHook) conf.ProviderConf.CheckPasswordHook = util.GetRedactedURL(conf.ProviderConf.CheckPasswordHook) conf.SMTPConfig.Password = getRedactedPassword(conf.SMTPConfig.Password) conf.HTTPDConfig.Bindings = nil for _, binding := range globalConf.HTTPDConfig.Bindings { binding.OIDC.ClientID = getRedactedPassword(binding.OIDC.ClientID) binding.OIDC.ClientSecret = getRedactedPassword(binding.OIDC.ClientSecret) conf.HTTPDConfig.Bindings = append(conf.HTTPDConfig.Bindings, binding) } conf.KMSConfig.Secrets.MasterKeyString 
= getRedactedPassword(conf.KMSConfig.Secrets.MasterKeyString) conf.PluginsConfig = nil for _, plugin := range globalConf.PluginsConfig { var args []string for _, arg := range plugin.Args { args = append(args, getRedactedPassword(arg)) } plugin.Args = args conf.PluginsConfig = append(conf.PluginsConfig, plugin) } return conf } func setConfigFile(configDir, configFile string) { if configFile == "" { return } if !filepath.IsAbs(configFile) && util.IsFileInputValid(configFile) { configFile = filepath.Join(configDir, configFile) } viper.SetConfigFile(configFile) } // readEnvFiles reads files inside the "env.d" directory relative to configDir // and then export the valid variables into environment variables if they do // not exist func readEnvFiles(configDir string) { envd := filepath.Join(configDir, "env.d") entries, err := os.ReadDir(envd) if err != nil { logger.Info(logSender, "", "unable to read env files from %q: %v", envd, err) return } for _, entry := range entries { info, err := entry.Info() if err == nil && info.Mode().IsRegular() { envFile := filepath.Join(envd, entry.Name()) if info.Size() > envFileMaxSize { logger.Info(logSender, "", "env file %q too big: %s, skipping", entry.Name(), util.ByteCountIEC(info.Size())) continue } err = gotenv.Load(envFile) if err != nil { logger.Error(logSender, "", "unable to load env vars from file %q, err: %v", envFile, err) } else { logger.Info(logSender, "", "set env vars from file %q", envFile) } } } } func checkOverrideDefaultSettings() { // for slices we need to set the defaults to nil if the key is set in the config file, // otherwise the values are merged and not replaced as expected rateLimiters := viper.Get("common.rate_limiters") if val, ok := rateLimiters.([]any); ok { if len(val) > 0 { if rl, ok := val[0].(map[string]any); ok { if _, ok := rl["protocols"]; ok { globalConf.Common.RateLimitersConfig[0].Protocols = nil } } } } httpdBindings := viper.Get("httpd.bindings") if val, ok := httpdBindings.([]any); ok { if 
len(val) > 0 { if binding, ok := val[0].(map[string]any); ok { if val, ok := binding["oidc"]; ok { if oidc, ok := val.(map[string]any); ok { if _, ok := oidc["scopes"]; ok { globalConf.HTTPDConfig.Bindings[0].OIDC.Scopes = nil } } } } } } if slices.Contains(viper.AllKeys(), "mfa.totp") { globalConf.MFAConfig.TOTP = nil } } // LoadConfig loads the configuration // configDir will be added to the configuration search paths. // The search path contains by default the current directory and on linux it contains // $HOME/.config/sftpgo and /etc/sftpgo too. // configFile is an absolute or relative path (to the config dir) to the configuration file. func LoadConfig(configDir, configFile string) error { var err error readEnvFiles(configDir) viper.AddConfigPath(configDir) setViperAdditionalConfigPaths() viper.AddConfigPath(".") setConfigFile(configDir, configFile) if err = viper.ReadInConfig(); err != nil { // if the user specify a configuration file we get os.ErrNotExist. // viper.ConfigFileNotFoundError is returned if viper is unable // to find sftpgo.{json,yaml, etc..} in any of the search paths if errors.As(err, &viper.ConfigFileNotFoundError{}) { logger.Debug(logSender, "", "no configuration file found") } else { logger.Warn(logSender, "", "error loading configuration file: %v", err) logger.WarnToConsole("error loading configuration file: %v", err) return err } } checkOverrideDefaultSettings() err = viper.Unmarshal(&globalConf) if err != nil { logger.Warn(logSender, "", "error parsing configuration file: %v", err) logger.WarnToConsole("error parsing configuration file: %v", err) return err } // viper only supports slice of strings from env vars, so we use our custom method loadBindingsFromEnv() loadWebDAVCacheMappingsFromEnv() resetInvalidConfigs() logger.Debug(logSender, "", "config file used: '%q', config loaded: %+v", viper.ConfigFileUsed(), getRedactedGlobalConf()) return nil } func isProxyProtocolValid() bool { return globalConf.Common.ProxyProtocol >= 0 && 
globalConf.Common.ProxyProtocol <= 2 } func isExternalAuthScopeValid() bool { return globalConf.ProviderConf.ExternalAuthScope >= 0 && globalConf.ProviderConf.ExternalAuthScope <= 15 } func resetInvalidConfigs() { if strings.TrimSpace(globalConf.HTTPDConfig.Setup.InstallationCodeHint) == "" { globalConf.HTTPDConfig.Setup.InstallationCodeHint = defaultInstallCodeHint } if globalConf.ProviderConf.UsersBaseDir != "" && !util.IsFileInputValid(globalConf.ProviderConf.UsersBaseDir) { warn := fmt.Sprintf("invalid users base dir %q will be ignored", globalConf.ProviderConf.UsersBaseDir) globalConf.ProviderConf.UsersBaseDir = "" logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) logger.WarnToConsole("Non-fatal configuration error: %v", warn) } if !isProxyProtocolValid() { warn := fmt.Sprintf("invalid proxy_protocol 0, 1 and 2 are supported, configured: %v reset proxy_protocol to 0", globalConf.Common.ProxyProtocol) globalConf.Common.ProxyProtocol = 0 logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) logger.WarnToConsole("Non-fatal configuration error: %v", warn) } if !isExternalAuthScopeValid() { warn := fmt.Sprintf("invalid external_auth_scope: %v reset to 0", globalConf.ProviderConf.ExternalAuthScope) globalConf.ProviderConf.ExternalAuthScope = 0 logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) logger.WarnToConsole("Non-fatal configuration error: %v", warn) } if globalConf.Common.DefenderConfig.Enabled && globalConf.Common.DefenderConfig.Driver == common.DefenderDriverProvider { if !globalConf.ProviderConf.IsDefenderSupported() { warn := fmt.Sprintf("provider based defender is not supported with data provider %q, "+ "the memory defender implementation will be used. 
If you want to use the provider defender "+ "implementation please switch to a shared/distributed data provider", globalConf.ProviderConf.Driver) globalConf.Common.DefenderConfig.Driver = common.DefenderDriverMemory logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) logger.WarnToConsole("Non-fatal configuration error: %v", warn) } } if globalConf.Common.RenameMode < 0 || globalConf.Common.RenameMode > 1 { warn := fmt.Sprintf("invalid rename mode %d, reset to 0", globalConf.Common.RenameMode) globalConf.Common.RenameMode = 0 logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) logger.WarnToConsole("Non-fatal configuration error: %v", warn) } } func loadBindingsFromEnv() { for idx := 0; idx < 10; idx++ { getTOTPFromEnv(idx) getRateLimitersFromEnv(idx) getPluginsFromEnv(idx) getSFTPDBindindFromEnv(idx) getFTPDBindingFromEnv(idx) getWebDAVDBindingFromEnv(idx) getHTTPDBindingFromEnv(idx) getHTTPClientCertificatesFromEnv(idx) getHTTPClientHeadersFromEnv(idx) getCommandConfigsFromEnv(idx) } } func getTOTPFromEnv(idx int) { totpConfig := defaultTOTP if len(globalConf.MFAConfig.TOTP) > idx { totpConfig = globalConf.MFAConfig.TOTP[idx] } isSet := false name, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_MFA__TOTP__%v__NAME", idx)) if ok { totpConfig.Name = name isSet = true } issuer, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_MFA__TOTP__%v__ISSUER", idx)) if ok { totpConfig.Issuer = issuer isSet = true } algo, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_MFA__TOTP__%v__ALGO", idx)) if ok { totpConfig.Algo = algo isSet = true } if isSet { if len(globalConf.MFAConfig.TOTP) > idx { globalConf.MFAConfig.TOTP[idx] = totpConfig } else { globalConf.MFAConfig.TOTP = append(globalConf.MFAConfig.TOTP, totpConfig) } } } func getRateLimitersFromEnv(idx int) { rtlConfig := defaultRateLimiter if len(globalConf.Common.RateLimitersConfig) > idx { rtlConfig = globalConf.Common.RateLimitersConfig[idx] } isSet := false average, ok := 
lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__AVERAGE", idx), 64) if ok { rtlConfig.Average = average isSet = true } period, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__PERIOD", idx), 64) if ok { rtlConfig.Period = period isSet = true } burst, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__BURST", idx), 32) if ok { rtlConfig.Burst = int(burst) isSet = true } rtlType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__TYPE", idx), 32) if ok { rtlConfig.Type = int(rtlType) isSet = true } protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__PROTOCOLS", idx)) if ok { rtlConfig.Protocols = protocols isSet = true } generateEvents, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__GENERATE_DEFENDER_EVENTS", idx)) if ok { rtlConfig.GenerateDefenderEvents = generateEvents isSet = true } softLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_SOFT_LIMIT", idx), 32) if ok { rtlConfig.EntriesSoftLimit = int(softLimit) isSet = true } hardLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_HARD_LIMIT", idx), 32) if ok { rtlConfig.EntriesHardLimit = int(hardLimit) isSet = true } if isSet { if len(globalConf.Common.RateLimitersConfig) > idx { globalConf.Common.RateLimitersConfig[idx] = rtlConfig } else { globalConf.Common.RateLimitersConfig = append(globalConf.Common.RateLimitersConfig, rtlConfig) } } } func getKMSPluginFromEnv(idx int, pluginConfig *plugin.Config) bool { isSet := false kmsScheme, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__KMS_OPTIONS__SCHEME", idx)) if ok { pluginConfig.KMSOptions.Scheme = kmsScheme isSet = true } kmsEncStatus, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__KMS_OPTIONS__ENCRYPTED_STATUS", idx)) if ok { pluginConfig.KMSOptions.EncryptedStatus = kmsEncStatus isSet = true } return isSet } func getAuthPluginFromEnv(idx int, 
pluginConfig *plugin.Config) bool { isSet := false authScope, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__AUTH_OPTIONS__SCOPE", idx), 32) if ok { pluginConfig.AuthOptions.Scope = int(authScope) isSet = true } return isSet } func getNotifierPluginFromEnv(idx int, pluginConfig *plugin.Config) bool { isSet := false notifierFsEvents, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__FS_EVENTS", idx)) if ok { pluginConfig.NotifierOptions.FsEvents = notifierFsEvents isSet = true } notifierProviderEvents, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__PROVIDER_EVENTS", idx)) if ok { pluginConfig.NotifierOptions.ProviderEvents = notifierProviderEvents isSet = true } notifierProviderObjects, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__PROVIDER_OBJECTS", idx)) if ok { pluginConfig.NotifierOptions.ProviderObjects = notifierProviderObjects isSet = true } notifierLogEventsString, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__LOG_EVENTS", idx)) if ok { var notifierLogEvents []int for _, e := range notifierLogEventsString { ev, err := strconv.Atoi(e) if err == nil { notifierLogEvents = append(notifierLogEvents, ev) } } if len(notifierLogEvents) > 0 { pluginConfig.NotifierOptions.LogEvents = notifierLogEvents isSet = true } } notifierRetryMaxTime, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_MAX_TIME", idx), 32) if ok { pluginConfig.NotifierOptions.RetryMaxTime = int(notifierRetryMaxTime) isSet = true } notifierRetryQueueMaxSize, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE", idx), 32) if ok { pluginConfig.NotifierOptions.RetryQueueMaxSize = int(notifierRetryQueueMaxSize) isSet = true } return isSet } func getPluginsFromEnv(idx int) { pluginConfig := plugin.Config{} if len(globalConf.PluginsConfig) > idx { pluginConfig = globalConf.PluginsConfig[idx] 
} isSet := false pluginType, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__TYPE", idx)) if ok { pluginConfig.Type = pluginType isSet = true } if getNotifierPluginFromEnv(idx, &pluginConfig) { isSet = true } if getKMSPluginFromEnv(idx, &pluginConfig) { isSet = true } if getAuthPluginFromEnv(idx, &pluginConfig) { isSet = true } cmd, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__CMD", idx)) if ok { pluginConfig.Cmd = cmd isSet = true } cmdArgs, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__ARGS", idx)) if ok { pluginConfig.Args = cmdArgs isSet = true } pluginHash, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__SHA256SUM", idx)) if ok { pluginConfig.SHA256Sum = pluginHash isSet = true } autoMTLS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__AUTO_MTLS", idx)) if ok { pluginConfig.AutoMTLS = autoMTLS isSet = true } envPrefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__ENV_PREFIX", idx)) if ok { pluginConfig.EnvPrefix = envPrefix isSet = true } envVars, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__ENV_VARS", idx)) if ok { pluginConfig.EnvVars = envVars isSet = true } if isSet { if len(globalConf.PluginsConfig) > idx { globalConf.PluginsConfig[idx] = pluginConfig } else { globalConf.PluginsConfig = append(globalConf.PluginsConfig, pluginConfig) } } } func getSFTPDBindindFromEnv(idx int) { binding := defaultSFTPDBinding if len(globalConf.SFTPD.Bindings) > idx { binding = globalConf.SFTPD.Bindings[idx] } isSet := false port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__PORT", idx), 32) if ok { binding.Port = int(port) isSet = true } address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__ADDRESS", idx)) if ok { binding.Address = address isSet = true } applyProxyConfig, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__APPLY_PROXY_CONFIG", idx)) if ok { binding.ApplyProxyConfig = applyProxyConfig isSet = true } if isSet { if len(globalConf.SFTPD.Bindings) > 
idx { globalConf.SFTPD.Bindings[idx] = binding } else { globalConf.SFTPD.Bindings = append(globalConf.SFTPD.Bindings, binding) } } } func getFTPDPassiveIPOverridesFromEnv(idx int) []ftpd.PassiveIPOverride { var overrides []ftpd.PassiveIPOverride if len(globalConf.FTPD.Bindings) > idx { overrides = globalConf.FTPD.Bindings[idx].PassiveIPOverrides } for subIdx := 0; subIdx < 10; subIdx++ { var override ftpd.PassiveIPOverride var replace bool if len(globalConf.FTPD.Bindings) > idx && len(globalConf.FTPD.Bindings[idx].PassiveIPOverrides) > subIdx { override = globalConf.FTPD.Bindings[idx].PassiveIPOverrides[subIdx] replace = true } ip, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_IP_OVERRIDES__%v__IP", idx, subIdx)) if ok { override.IP = ip } networks, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_IP_OVERRIDES__%v__NETWORKS", idx, subIdx)) if ok { override.Networks = networks } if len(override.Networks) > 0 { if replace { overrides[subIdx] = override } else { overrides = append(overrides, override) } } } return overrides } func getDefaultFTPDBinding(idx int) ftpd.Binding { binding := defaultFTPDBinding if len(globalConf.FTPD.Bindings) > idx { binding = globalConf.FTPD.Bindings[idx] } return binding } func getFTPDBindingSecurityFromEnv(idx int, binding *ftpd.Binding) bool { isSet := false certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CERTIFICATE_FILE", idx)) if ok { binding.CertificateFile = certificateFile isSet = true } certificateKeyFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CERTIFICATE_KEY_FILE", idx)) if ok { binding.CertificateKeyFile = certificateKeyFile isSet = true } tlsMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_MODE", idx), 32) if ok { binding.TLSMode = int(tlsMode) isSet = true } tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32) if ok { binding.MinTLSVersion = int(tlsVer) isSet 
= true } tlsCiphers, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_CIPHER_SUITES", idx)) if ok { binding.TLSCipherSuites = tlsCiphers isSet = true } clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32) if ok { binding.ClientAuthType = int(clientAuthType) isSet = true } pasvSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_CONNECTIONS_SECURITY", idx), 32) if ok { binding.PassiveConnectionsSecurity = int(pasvSecurity) isSet = true } activeSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__ACTIVE_CONNECTIONS_SECURITY", idx), 32) if ok { binding.ActiveConnectionsSecurity = int(activeSecurity) isSet = true } return isSet } func getFTPDBindingFromEnv(idx int) { binding := getDefaultFTPDBinding(idx) isSet := false port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PORT", idx), 32) if ok { binding.Port = int(port) isSet = true } address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__ADDRESS", idx)) if ok { binding.Address = address isSet = true } applyProxyConfig, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__APPLY_PROXY_CONFIG", idx)) if ok { binding.ApplyProxyConfig = applyProxyConfig isSet = true } passiveIP, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__FORCE_PASSIVE_IP", idx)) if ok { binding.ForcePassiveIP = passiveIP isSet = true } passiveIPOverrides := getFTPDPassiveIPOverridesFromEnv(idx) if len(passiveIPOverrides) > 0 { binding.PassiveIPOverrides = passiveIPOverrides isSet = true } passiveHost, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_HOST", idx)) if ok { binding.PassiveHost = passiveHost isSet = true } debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__DEBUG", idx)) if ok { binding.Debug = debug isSet = true } if getFTPDBindingSecurityFromEnv(idx, &binding) { isSet = true } applyFTPDBindingFromEnv(idx, isSet, binding) } 
// applyFTPDBindingFromEnv stores binding into globalConf.FTPD.Bindings when
// isSet is true. An existing entry at idx is replaced. Otherwise the binding is
// appended at the end, which may not be position idx.
func applyFTPDBindingFromEnv(idx int, isSet bool, binding ftpd.Binding) {
	if isSet {
		if len(globalConf.FTPD.Bindings) > idx {
			globalConf.FTPD.Bindings[idx] = binding
		} else {
			globalConf.FTPD.Bindings = append(globalConf.FTPD.Bindings, binding)
		}
	}
}

// getWebDAVBindingHTTPSConfigsFromEnv applies the HTTPS/TLS related
// SFTPGO_WEBDAVD__BINDINGS__<idx>__* variables to binding and reports whether
// any of them was set.
func getWebDAVBindingHTTPSConfigsFromEnv(idx int, binding *webdavd.Binding) bool {
	isSet := false

	enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx))
	if ok {
		binding.EnableHTTPS = enableHTTPS
		isSet = true
	}

	certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_FILE", idx))
	if ok {
		binding.CertificateFile = certificateFile
		isSet = true
	}

	certificateKeyFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_KEY_FILE", idx))
	if ok {
		binding.CertificateKeyFile = certificateKeyFile
		isSet = true
	}

	tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32)
	if ok {
		binding.MinTLSVersion = int(tlsVer)
		isSet = true
	}

	clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32)
	if ok {
		binding.ClientAuthType = int(clientAuthType)
		isSet = true
	}

	tlsCiphers, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__TLS_CIPHER_SUITES", idx))
	if ok {
		binding.TLSCipherSuites = tlsCiphers
		isSet = true
	}

	// the env var is TLS_PROTOCOLS while the struct field is Protocols
	protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%d__TLS_PROTOCOLS", idx))
	if ok {
		binding.Protocols = protocols
		isSet = true
	}

	return isSet
}

// getWebDAVDBindingProxyConfigsFromEnv applies the proxy related
// SFTPGO_WEBDAVD__BINDINGS__<idx>__* variables to binding and reports whether
// any of them was set.
func getWebDAVDBindingProxyConfigsFromEnv(idx int, binding *webdavd.Binding) bool {
	isSet := false

	proxyMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_MODE", idx), 32)
	if ok {
		binding.ProxyMode = int(proxyMode)
		isSet = true
	}

	proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_ALLOWED", idx))
	if ok {
		binding.ProxyAllowed = proxyAllowed
		isSet = true
	}

	clientIPProxyHeader, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_IP_PROXY_HEADER", idx))
	if ok {
		binding.ClientIPProxyHeader = clientIPProxyHeader
		isSet = true
	}

	clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 32)
	if ok {
		binding.ClientIPHeaderDepth = int(clientIPHeaderDepth)
		isSet = true
	}

	return isSet
}

// loadWebDAVCacheMappingsFromEnv merges custom MIME mappings from the
// SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__<idx>__{EXT,MIME}
// variables for idx 0-29. A mapping is applied only when both EXT and MIME
// are set. The resulting mappings are also returned.
func loadWebDAVCacheMappingsFromEnv() []webdavd.CustomMimeMapping {
	for idx := 0; idx < 30; idx++ {
		ext, extOK := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__%d__EXT", idx))
		mime, mimeOK := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__%d__MIME", idx))
		if extOK && mimeOK {
			if len(globalConf.WebDAVD.Cache.MimeTypes.CustomMappings) > idx {
				globalConf.WebDAVD.Cache.MimeTypes.CustomMappings[idx].Ext = ext
				globalConf.WebDAVD.Cache.MimeTypes.CustomMappings[idx].Mime = mime
			} else {
				// appended at the end, which may not be position idx
				globalConf.WebDAVD.Cache.MimeTypes.CustomMappings = append(globalConf.WebDAVD.Cache.MimeTypes.CustomMappings,
					webdavd.CustomMimeMapping{
						Ext:  ext,
						Mime: mime,
					})
			}
		}
	}
	return globalConf.WebDAVD.Cache.MimeTypes.CustomMappings
}

// getWebDAVDBindingFromEnv overrides the WebDAV binding at index idx from the
// SFTPGO_WEBDAVD__BINDINGS__<idx>__* variables. It starts from the configured
// binding, or from defaultWebDAVDBinding.
func getWebDAVDBindingFromEnv(idx int) {
	binding := defaultWebDAVDBinding
	if len(globalConf.WebDAVD.Bindings) > idx {
		binding = globalConf.WebDAVD.Bindings[idx]
	}

	isSet := false

	port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PORT", idx), 32)
	if ok {
		binding.Port = int(port)
		isSet = true
	}

	address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ADDRESS", idx))
	if ok {
		binding.Address = address
		isSet = true
	}

	if getWebDAVBindingHTTPSConfigsFromEnv(idx, &binding) {
		isSet = true
	}

	prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
	if ok {
		binding.Prefix = prefix
		isSet = true
	}

	if getWebDAVDBindingProxyConfigsFromEnv(idx, &binding) {
		isSet = true
	}

	disableWWWAuth, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__DISABLE_WWW_AUTH_HEADER", idx))
	if ok {
		binding.DisableWWWAuthHeader = disableWWWAuth
		isSet = true
	}

	if isSet {
		if len(globalConf.WebDAVD.Bindings) > idx {
			globalConf.WebDAVD.Bindings[idx] = binding
		} else {
			globalConf.WebDAVD.Bindings = append(globalConf.WebDAVD.Bindings, binding)
		}
	}
}

// getHTTPDSecurityProxyHeadersFromEnv returns the HTTPS proxy headers for the
// HTTPD binding at index idx, merged with the
// SFTPGO_HTTPD__BINDINGS__<idx>__SECURITY__HTTPS_PROXY_HEADERS__<subIdx>__*
// variables for subIdx 0-9. A header is kept only if both its key and value
// are non-empty. Existing entries are written in place, so the returned slice
// may share its backing array with globalConf.
func getHTTPDSecurityProxyHeadersFromEnv(idx int) []httpd.HTTPSProxyHeader {
	var httpsProxyHeaders []httpd.HTTPSProxyHeader
	if len(globalConf.HTTPDConfig.Bindings) > idx {
		httpsProxyHeaders = globalConf.HTTPDConfig.Bindings[idx].Security.HTTPSProxyHeaders
	}

	for subIdx := 0; subIdx < 10; subIdx++ {
		var httpsProxyHeader httpd.HTTPSProxyHeader
		var replace bool
		if len(globalConf.HTTPDConfig.Bindings) > idx &&
			len(globalConf.HTTPDConfig.Bindings[idx].Security.HTTPSProxyHeaders) > subIdx {
			httpsProxyHeader = httpsProxyHeaders[subIdx]
			replace = true
		}
		proxyKey, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_PROXY_HEADERS__%v__KEY",
			idx, subIdx))
		if ok {
			httpsProxyHeader.Key = proxyKey
		}
		proxyVal, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_PROXY_HEADERS__%v__VALUE",
			idx, subIdx))
		if ok {
			httpsProxyHeader.Value = proxyVal
		}
		if httpsProxyHeader.Key != "" && httpsProxyHeader.Value != "" {
			if replace {
				httpsProxyHeaders[subIdx] = httpsProxyHeader
			} else {
				httpsProxyHeaders = append(httpsProxyHeaders, httpsProxyHeader)
			}
		}
	}
	return httpsProxyHeaders
}

// getHTTPDSecurityConfFromEnv builds the security config for the HTTPD
// binding at index idx. It starts from the configured value, or from
// defaultHTTPDBinding.Security, and applies the
// SFTPGO_HTTPD__BINDINGS__<idx>__SECURITY__* variables. The bool result
// reports whether any variable was set.
func getHTTPDSecurityConfFromEnv(idx int) (httpd.SecurityConf, bool) { //nolint:gocyclo
	result := defaultHTTPDBinding.Security
	if len(globalConf.HTTPDConfig.Bindings) > idx {
		result = globalConf.HTTPDConfig.Bindings[idx].Security
	}
	isSet := false

	enabled, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__ENABLED", idx))
	if ok {
		result.Enabled = enabled
		isSet = true
	}

	allowedHosts, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__ALLOWED_HOSTS", idx))
	if ok {
		result.AllowedHosts = allowedHosts
		isSet = true
	}

	allowedHostsAreRegex, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__ALLOWED_HOSTS_ARE_REGEX", idx))
	if ok {
		result.AllowedHostsAreRegex = allowedHostsAreRegex
		isSet = true
	}

	hostsProxyHeaders, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HOSTS_PROXY_HEADERS", idx))
	if ok {
		result.HostsProxyHeaders = hostsProxyHeaders
		isSet = true
	}

	httpsRedirect, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_REDIRECT", idx))
	if ok {
		result.HTTPSRedirect = httpsRedirect
		isSet = true
	}

	httpsHost, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_HOST", idx))
	if ok {
		result.HTTPSHost = httpsHost
		isSet = true
	}

	httpsProxyHeaders := getHTTPDSecurityProxyHeadersFromEnv(idx)
	if len(httpsProxyHeaders) > 0 {
		result.HTTPSProxyHeaders = httpsProxyHeaders
		isSet = true
	}

	stsSeconds, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__STS_SECONDS", idx), 64)
	if ok {
		result.STSSeconds = stsSeconds
		isSet = true
	}

	stsIncludeSubDomains, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__STS_INCLUDE_SUBDOMAINS", idx))
	if ok {
		result.STSIncludeSubdomains = stsIncludeSubDomains
		isSet = true
	}

	stsPreload, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__STS_PRELOAD", idx))
	if ok {
		result.STSPreload = stsPreload
		isSet = true
	}

	contentTypeNosniff, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CONTENT_TYPE_NOSNIFF", idx))
	if ok {
		result.ContentTypeNosniff = contentTypeNosniff
		isSet = true
	}

	contentSecurityPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CONTENT_SECURITY_POLICY", idx))
	if ok {
		result.ContentSecurityPolicy = contentSecurityPolicy
		isSet = true
	}

	permissionsPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__PERMISSIONS_POLICY", idx))
	if ok {
		result.PermissionsPolicy = permissionsPolicy
		isSet = true
	}

	crossOriginOpenerPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CROSS_ORIGIN_OPENER_POLICY", idx))
	if ok {
		result.CrossOriginOpenerPolicy = crossOriginOpenerPolicy
		isSet = true
	}

	crossOriginResourcePolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CROSS_ORIGIN_RESOURCE_POLICY", idx))
	if ok {
		result.CrossOriginResourcePolicy = crossOriginResourcePolicy
		isSet = true
	}

	crossOriginEmbedderPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CROSS_ORIGIN_EMBEDDER_POLICY", idx))
	if ok {
		result.CrossOriginEmbedderPolicy = crossOriginEmbedderPolicy
		isSet = true
	}

	referredPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__REFERRER_POLICY", idx))
	if ok {
		result.ReferrerPolicy = referredPolicy
		isSet = true
	}

	cacheControl, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CACHE_CONTROL", idx))
	if ok {
		result.CacheControl = cacheControl
		isSet = true
	}

	return result, isSet
}

// getHTTPDOIDCFromEnv builds the OIDC config for the HTTPD binding at index
// idx. It starts from the configured value, or from defaultHTTPDBinding.OIDC,
// and applies the SFTPGO_HTTPD__BINDINGS__<idx>__OIDC__* variables. The bool
// result reports whether any variable was set.
func getHTTPDOIDCFromEnv(idx int) (httpd.OIDC, bool) {
	result := defaultHTTPDBinding.OIDC
	if len(globalConf.HTTPDConfig.Bindings) > idx {
		result = globalConf.HTTPDConfig.Bindings[idx].OIDC
	}

	isSet := false

	clientID, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CLIENT_ID", idx))
	if ok {
		result.ClientID = clientID
		isSet = true
	}

	clientSecret, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CLIENT_SECRET", idx))
	if ok {
		result.ClientSecret = clientSecret
		isSet = true
	}

	clientSecretFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CLIENT_SECRET_FILE", idx))
	if ok {
		result.ClientSecretFile = clientSecretFile
		isSet = true
	}

	configURL, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CONFIG_URL", idx))
	if ok {
		result.ConfigURL = configURL
		isSet = true
	}

	redirectBaseURL, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__REDIRECT_BASE_URL", idx))
	if ok {
		result.RedirectBaseURL = redirectBaseURL
		isSet = true
	}

	usernameField, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__USERNAME_FIELD", idx))
	if ok {
		result.UsernameField = usernameField
		isSet = true
	}

	scopes, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__SCOPES", idx))
	if ok {
		result.Scopes = scopes
		isSet = true
	}

	roleField, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__ROLE_FIELD", idx))
	if ok {
		result.RoleField = roleField
		isSet = true
	}

	implicitRoles, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__IMPLICIT_ROLES", idx))
	if ok {
		result.ImplicitRoles = implicitRoles
		isSet = true
	}

	customFields, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CUSTOM_FIELDS", idx))
	if ok {
		result.CustomFields = customFields
		isSet = true
	}

	skipSignatureCheck, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", idx))
	if ok {
		result.InsecureSkipSignatureCheck = skipSignatureCheck
		isSet = true
	}

	debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__DEBUG", idx))
	if ok {
		result.Debug = debug
		isSet = true
	}

	return result, isSet
}

// getHTTPDUIBrandingFromEnv applies the <prefix>__* branding variables to a
// copy of branding. It returns the updated copy and whether any variable was
// set.
func getHTTPDUIBrandingFromEnv(prefix string, branding httpd.UIBranding) (httpd.UIBranding, bool) {
	isSet := false

	name, ok := os.LookupEnv(fmt.Sprintf("%s__NAME", prefix))
	if ok {
		branding.Name = name
		isSet = true
	}

	shortName, ok := os.LookupEnv(fmt.Sprintf("%s__SHORT_NAME", prefix))
	if ok {
		branding.ShortName = shortName
		isSet = true
	}

	faviconPath, ok := os.LookupEnv(fmt.Sprintf("%s__FAVICON_PATH", prefix))
	if ok {
		branding.FaviconPath = faviconPath
		isSet = true
	}

	logoPath, ok := os.LookupEnv(fmt.Sprintf("%s__LOGO_PATH", prefix))
	if ok {
		branding.LogoPath = logoPath
		isSet = true
	}

	disclaimerName, ok := os.LookupEnv(fmt.Sprintf("%s__DISCLAIMER_NAME", prefix))
	if ok {
		branding.DisclaimerName = disclaimerName
		isSet = true
	}

	disclaimerPath, ok := os.LookupEnv(fmt.Sprintf("%s__DISCLAIMER_PATH", prefix))
	if ok {
		branding.DisclaimerPath = disclaimerPath
		isSet = true
	}

	defaultCSSPath, ok := lookupStringListFromEnv(fmt.Sprintf("%s__DEFAULT_CSS", prefix))
	if ok {
		branding.DefaultCSS = defaultCSSPath
		isSet = true
	}

	extraCSS, ok := lookupStringListFromEnv(fmt.Sprintf("%s__EXTRA_CSS", prefix))
	if ok {
		branding.ExtraCSS = extraCSS
		isSet = true
	}

	return branding, isSet
}

// getHTTPDBrandingFromEnv builds the WebAdmin and WebClient branding for the
// HTTPD binding at index idx. It starts from the configured value, or from
// defaultHTTPDBinding.Branding.
func getHTTPDBrandingFromEnv(idx int) (httpd.Branding, bool) {
	result := defaultHTTPDBinding.Branding
	if len(globalConf.HTTPDConfig.Bindings) > idx {
		result = globalConf.HTTPDConfig.Bindings[idx].Branding
	}
	isSet := false

	webAdmin, ok := getHTTPDUIBrandingFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__BRANDING__WEB_ADMIN", idx), result.WebAdmin)
	if ok {
		result.WebAdmin = webAdmin
		isSet = true
	}

	webClient, ok := getHTTPDUIBrandingFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__BRANDING__WEB_CLIENT", idx), result.WebClient)
	if ok {
		result.WebClient = webClient
		isSet = true
	}

	return result, isSet
}

// getDefaultHTTPBinding returns the configured HTTPD binding at index idx, or
// defaultHTTPDBinding if there is none.
func getDefaultHTTPBinding(idx int) httpd.Binding {
	binding := defaultHTTPDBinding
	if len(globalConf.HTTPDConfig.Bindings) > idx {
		binding = globalConf.HTTPDConfig.Bindings[idx]
	}
	return binding
}

// getHTTPDNestedObjectsFromEnv applies the OIDC, security and branding env
// overrides to binding and reports whether any of them was set.
func getHTTPDNestedObjectsFromEnv(idx int, binding *httpd.Binding) bool {
	isSet := false

	oidc, ok := getHTTPDOIDCFromEnv(idx)
	if ok {
		binding.OIDC = oidc
		isSet = true
	}

	securityConf, ok := getHTTPDSecurityConfFromEnv(idx)
	if ok {
		binding.Security = securityConf
		isSet = true
	}

	brandingConf, ok := getHTTPDBrandingFromEnv(idx)
	if ok {
		binding.Branding = brandingConf
		isSet = true
	}

	return isSet
}

// getHTTPDBindingProxyConfigsFromEnv applies the proxy related
// SFTPGO_HTTPD__BINDINGS__<idx>__* variables to binding and reports whether
// any of them was set.
func getHTTPDBindingProxyConfigsFromEnv(idx int, binding *httpd.Binding) bool {
	isSet := false

	proxyMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_MODE", idx), 32)
	if ok {
		binding.ProxyMode = int(proxyMode)
		isSet = true
	}

	proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_ALLOWED", idx))
	if ok {
		binding.ProxyAllowed = proxyAllowed
		isSet = true
	}

	clientIPProxyHeader, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_IP_PROXY_HEADER", idx))
	if ok {
		binding.ClientIPProxyHeader = clientIPProxyHeader
		isSet = true
	}

	clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 32)
	if ok {
		binding.ClientIPHeaderDepth = int(clientIPHeaderDepth)
		isSet = true
	}

	return isSet
}

// getHTTPDBindingFromEnv overrides fields of the HTTPD binding at index idx
// from the SFTPGO_HTTPD__BINDINGS__<idx>__* variables. It starts from the
// binding returned by getDefaultHTTPBinding.
func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
	binding := getDefaultHTTPBinding(idx)
	isSet := false

	port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PORT", idx), 32)
	if ok {
		binding.Port = int(port)
		isSet = true
	}

	address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ADDRESS", idx))
	if ok {
		binding.Address = address
		isSet = true
	}

	certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CERTIFICATE_FILE", idx))
	if ok {
		binding.CertificateFile = certificateFile
		isSet = true
	}

	certificateKeyFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CERTIFICATE_KEY_FILE", idx))
	if ok {
		binding.CertificateKeyFile = certificateKeyFile
		isSet = true
	}

	enableWebAdmin, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_WEB_ADMIN", idx))
	if ok {
		binding.EnableWebAdmin = enableWebAdmin
		isSet = true
	}

	enableWebClient, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_WEB_CLIENT", idx))
	if ok {
		binding.EnableWebClient = enableWebClient
		isSet = true
	}

	enableRESTAPI, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_REST_API", idx))
	if ok {
		binding.EnableRESTAPI = enableRESTAPI
		isSet = true
	}

	enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx), 32)
	if ok {
		binding.EnabledLoginMethods = int(enabledLoginMethods)
		isSet = true
	}

	disabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__DISABLED_LOGIN_METHODS", idx), 32)
	if ok {
		binding.DisabledLoginMethods = int(disabledLoginMethods)
		isSet = true
} renderOpenAPI, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__RENDER_OPENAPI", idx)) if ok { binding.RenderOpenAPI = renderOpenAPI isSet = true } baseURL, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__BASE_URL", idx)) if ok { binding.BaseURL = baseURL isSet = true } languages, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__LANGUAGES", idx)) if ok { binding.Languages = languages isSet = true } enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_HTTPS", idx)) if ok { binding.EnableHTTPS = enableHTTPS isSet = true } tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32) if ok { binding.MinTLSVersion = int(tlsVer) isSet = true } clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32) if ok { binding.ClientAuthType = int(clientAuthType) isSet = true } tlsCiphers, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__TLS_CIPHER_SUITES", idx)) if ok { binding.TLSCipherSuites = tlsCiphers isSet = true } protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__TLS_PROTOCOLS", idx)) if ok { binding.Protocols = protocols isSet = true } if getHTTPDBindingProxyConfigsFromEnv(idx, &binding) { isSet = true } hideLoginURL, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__HIDE_LOGIN_URL", idx), 32) if ok { binding.HideLoginURL = int(hideLoginURL) isSet = true } if getHTTPDNestedObjectsFromEnv(idx, &binding) { isSet = true } setHTTPDBinding(isSet, binding, idx) } func setHTTPDBinding(isSet bool, binding httpd.Binding, idx int) { if isSet { if len(globalConf.HTTPDConfig.Bindings) > idx { globalConf.HTTPDConfig.Bindings[idx] = binding } else { globalConf.HTTPDConfig.Bindings = append(globalConf.HTTPDConfig.Bindings, binding) } } } func getHTTPClientCertificatesFromEnv(idx int) { tlsCert := httpclient.TLSKeyPair{} if 
len(globalConf.HTTPConfig.Certificates) > idx { tlsCert = globalConf.HTTPConfig.Certificates[idx] } cert, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__CERTIFICATES__%v__CERT", idx)) if ok { tlsCert.Cert = cert } key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__CERTIFICATES__%v__KEY", idx)) if ok { tlsCert.Key = key } if tlsCert.Cert != "" && tlsCert.Key != "" { if len(globalConf.HTTPConfig.Certificates) > idx { globalConf.HTTPConfig.Certificates[idx] = tlsCert } else { globalConf.HTTPConfig.Certificates = append(globalConf.HTTPConfig.Certificates, tlsCert) } } } func getHTTPClientHeadersFromEnv(idx int) { header := httpclient.Header{} if len(globalConf.HTTPConfig.Headers) > idx { header = globalConf.HTTPConfig.Headers[idx] } key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__KEY", idx)) if ok { header.Key = key } value, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__VALUE", idx)) if ok { header.Value = value } url, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__URL", idx)) if ok { header.URL = url } if header.Key != "" && header.Value != "" { if len(globalConf.HTTPConfig.Headers) > idx { globalConf.HTTPConfig.Headers[idx] = header } else { globalConf.HTTPConfig.Headers = append(globalConf.HTTPConfig.Headers, header) } } } func getCommandConfigsFromEnv(idx int) { cfg := command.Command{} if len(globalConf.CommandConfig.Commands) > idx { cfg = globalConf.CommandConfig.Commands[idx] } path, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__PATH", idx)) if ok { cfg.Path = path } timeout, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__TIMEOUT", idx), 32) if ok { cfg.Timeout = int(timeout) } env, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__ENV", idx)) if ok { cfg.Env = env } args, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__ARGS", idx)) if ok { cfg.Args = args } if cfg.Path != "" { if len(globalConf.CommandConfig.Commands) > idx { 
globalConf.CommandConfig.Commands[idx] = cfg } else { globalConf.CommandConfig.Commands = append(globalConf.CommandConfig.Commands, cfg) } } } func setViperDefaults() { viper.SetDefault("common.idle_timeout", globalConf.Common.IdleTimeout) viper.SetDefault("common.upload_mode", globalConf.Common.UploadMode) viper.SetDefault("common.actions.execute_on", globalConf.Common.Actions.ExecuteOn) viper.SetDefault("common.actions.execute_sync", globalConf.Common.Actions.ExecuteSync) viper.SetDefault("common.actions.hook", globalConf.Common.Actions.Hook) viper.SetDefault("common.setstat_mode", globalConf.Common.SetstatMode) viper.SetDefault("common.rename_mode", globalConf.Common.RenameMode) viper.SetDefault("common.resume_max_size", globalConf.Common.ResumeMaxSize) viper.SetDefault("common.temp_path", globalConf.Common.TempPath) viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol) viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed) viper.SetDefault("common.proxy_skipped", globalConf.Common.ProxySkipped) viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook) viper.SetDefault("common.post_disconnect_hook", globalConf.Common.PostDisconnectHook) viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections) viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections) viper.SetDefault("common.allowlist_status", globalConf.Common.AllowListStatus) viper.SetDefault("common.allow_self_connections", globalConf.Common.AllowSelfConnections) viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled) viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver) viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime) viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement) viper.SetDefault("common.defender.threshold", 
globalConf.Common.DefenderConfig.Threshold) viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid) viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid) viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded) viper.SetDefault("common.defender.score_no_auth", globalConf.Common.DefenderConfig.ScoreNoAuth) viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime) viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit) viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit) viper.SetDefault("common.defender.login_delay.success", globalConf.Common.DefenderConfig.LoginDelay.Success) viper.SetDefault("common.defender.login_delay.password_failed", globalConf.Common.DefenderConfig.LoginDelay.PasswordFailed) viper.SetDefault("common.umask", globalConf.Common.Umask) viper.SetDefault("common.server_version", globalConf.Common.ServerVersion) viper.SetDefault("common.tz", globalConf.Common.TZ) viper.SetDefault("common.metadata.read", globalConf.Common.Metadata.Read) viper.SetDefault("common.event_manager.enabled_commands", globalConf.Common.EventManager.EnabledCommands) viper.SetDefault("acme.email", globalConf.ACME.Email) viper.SetDefault("acme.key_type", globalConf.ACME.KeyType) viper.SetDefault("acme.certs_path", globalConf.ACME.CertsPath) viper.SetDefault("acme.ca_endpoint", globalConf.ACME.CAEndpoint) viper.SetDefault("acme.domains", globalConf.ACME.Domains) viper.SetDefault("acme.renew_days", globalConf.ACME.RenewDays) viper.SetDefault("acme.http01_challenge.port", globalConf.ACME.HTTP01Challenge.Port) viper.SetDefault("acme.http01_challenge.webroot", globalConf.ACME.HTTP01Challenge.WebRoot) viper.SetDefault("acme.http01_challenge.proxy_header", globalConf.ACME.HTTP01Challenge.ProxyHeader) 
viper.SetDefault("acme.tls_alpn01_challenge.port", globalConf.ACME.TLSALPN01Challenge.Port) viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries) viper.SetDefault("sftpd.host_keys", globalConf.SFTPD.HostKeys) viper.SetDefault("sftpd.host_certificates", globalConf.SFTPD.HostCertificates) viper.SetDefault("sftpd.host_key_algorithms", globalConf.SFTPD.HostKeyAlgorithms) viper.SetDefault("sftpd.kex_algorithms", globalConf.SFTPD.KexAlgorithms) viper.SetDefault("sftpd.ciphers", globalConf.SFTPD.Ciphers) viper.SetDefault("sftpd.macs", globalConf.SFTPD.MACs) viper.SetDefault("sftpd.public_key_algorithms", globalConf.SFTPD.PublicKeyAlgorithms) viper.SetDefault("sftpd.trusted_user_ca_keys", globalConf.SFTPD.TrustedUserCAKeys) viper.SetDefault("sftpd.revoked_user_certs_file", globalConf.SFTPD.RevokedUserCertsFile) viper.SetDefault("sftpd.opkssh_path", globalConf.SFTPD.OPKSSHPath) viper.SetDefault("sftpd.opkssh_checksum", globalConf.SFTPD.OPKSSHChecksum) viper.SetDefault("sftpd.login_banner_file", globalConf.SFTPD.LoginBannerFile) viper.SetDefault("sftpd.enabled_ssh_commands", sftpd.GetDefaultSSHCommands()) viper.SetDefault("sftpd.keyboard_interactive_authentication", globalConf.SFTPD.KeyboardInteractiveAuthentication) viper.SetDefault("sftpd.keyboard_interactive_auth_hook", globalConf.SFTPD.KeyboardInteractiveHook) viper.SetDefault("sftpd.password_authentication", globalConf.SFTPD.PasswordAuthentication) viper.SetDefault("ftpd.banner_file", globalConf.FTPD.BannerFile) viper.SetDefault("ftpd.active_transfers_port_non_20", globalConf.FTPD.ActiveTransfersPortNon20) viper.SetDefault("ftpd.passive_port_range.start", globalConf.FTPD.PassivePortRange.Start) viper.SetDefault("ftpd.passive_port_range.end", globalConf.FTPD.PassivePortRange.End) viper.SetDefault("ftpd.disable_active_mode", globalConf.FTPD.DisableActiveMode) viper.SetDefault("ftpd.enable_site", globalConf.FTPD.EnableSite) viper.SetDefault("ftpd.hash_support", globalConf.FTPD.HASHSupport) 
viper.SetDefault("ftpd.combine_support", globalConf.FTPD.CombineSupport) viper.SetDefault("ftpd.certificate_file", globalConf.FTPD.CertificateFile) viper.SetDefault("ftpd.certificate_key_file", globalConf.FTPD.CertificateKeyFile) viper.SetDefault("ftpd.ca_certificates", globalConf.FTPD.CACertificates) viper.SetDefault("ftpd.ca_revocation_lists", globalConf.FTPD.CARevocationLists) viper.SetDefault("webdavd.certificate_file", globalConf.WebDAVD.CertificateFile) viper.SetDefault("webdavd.certificate_key_file", globalConf.WebDAVD.CertificateKeyFile) viper.SetDefault("webdavd.ca_certificates", globalConf.WebDAVD.CACertificates) viper.SetDefault("webdavd.ca_revocation_lists", globalConf.WebDAVD.CARevocationLists) viper.SetDefault("webdavd.cors.enabled", globalConf.WebDAVD.Cors.Enabled) viper.SetDefault("webdavd.cors.allowed_origins", globalConf.WebDAVD.Cors.AllowedOrigins) viper.SetDefault("webdavd.cors.allowed_methods", globalConf.WebDAVD.Cors.AllowedMethods) viper.SetDefault("webdavd.cors.allowed_headers", globalConf.WebDAVD.Cors.AllowedHeaders) viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders) viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials) viper.SetDefault("webdavd.cors.options_passthrough", globalConf.WebDAVD.Cors.OptionsPassthrough) viper.SetDefault("webdavd.cors.options_success_status", globalConf.WebDAVD.Cors.OptionsSuccessStatus) viper.SetDefault("webdavd.cors.allow_private_network", globalConf.WebDAVD.Cors.AllowPrivateNetwork) viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge) viper.SetDefault("webdavd.cache.users.expiration_time", globalConf.WebDAVD.Cache.Users.ExpirationTime) viper.SetDefault("webdavd.cache.users.max_size", globalConf.WebDAVD.Cache.Users.MaxSize) viper.SetDefault("webdavd.cache.mime_types.enabled", globalConf.WebDAVD.Cache.MimeTypes.Enabled) viper.SetDefault("webdavd.cache.mime_types.max_size", globalConf.WebDAVD.Cache.MimeTypes.MaxSize) 
viper.SetDefault("webdavd.cache.mime_types.custom_mappings", globalConf.WebDAVD.Cache.MimeTypes.CustomMappings) viper.SetDefault("data_provider.driver", globalConf.ProviderConf.Driver) viper.SetDefault("data_provider.name", globalConf.ProviderConf.Name) viper.SetDefault("data_provider.host", globalConf.ProviderConf.Host) viper.SetDefault("data_provider.port", globalConf.ProviderConf.Port) viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username) viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password) viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode) viper.SetDefault("data_provider.disable_sni", globalConf.ProviderConf.DisableSNI) viper.SetDefault("data_provider.target_session_attrs", globalConf.ProviderConf.TargetSessionAttrs) viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert) viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert) viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey) viper.SetDefault("data_provider.connection_string", globalConf.ProviderConf.ConnectionString) viper.SetDefault("data_provider.sql_tables_prefix", globalConf.ProviderConf.SQLTablesPrefix) viper.SetDefault("data_provider.track_quota", globalConf.ProviderConf.TrackQuota) viper.SetDefault("data_provider.pool_size", globalConf.ProviderConf.PoolSize) viper.SetDefault("data_provider.users_base_dir", globalConf.ProviderConf.UsersBaseDir) viper.SetDefault("data_provider.actions.execute_on", globalConf.ProviderConf.Actions.ExecuteOn) viper.SetDefault("data_provider.actions.execute_for", globalConf.ProviderConf.Actions.ExecuteFor) viper.SetDefault("data_provider.actions.hook", globalConf.ProviderConf.Actions.Hook) viper.SetDefault("data_provider.external_auth_hook", globalConf.ProviderConf.ExternalAuthHook) viper.SetDefault("data_provider.external_auth_scope", globalConf.ProviderConf.ExternalAuthScope) viper.SetDefault("data_provider.pre_login_hook", 
globalConf.ProviderConf.PreLoginHook) viper.SetDefault("data_provider.post_login_hook", globalConf.ProviderConf.PostLoginHook) viper.SetDefault("data_provider.post_login_scope", globalConf.ProviderConf.PostLoginScope) viper.SetDefault("data_provider.check_password_hook", globalConf.ProviderConf.CheckPasswordHook) viper.SetDefault("data_provider.check_password_scope", globalConf.ProviderConf.CheckPasswordScope) viper.SetDefault("data_provider.password_hashing.bcrypt_options.cost", globalConf.ProviderConf.PasswordHashing.BcryptOptions.Cost) viper.SetDefault("data_provider.password_hashing.argon2_options.memory", globalConf.ProviderConf.PasswordHashing.Argon2Options.Memory) viper.SetDefault("data_provider.password_hashing.argon2_options.iterations", globalConf.ProviderConf.PasswordHashing.Argon2Options.Iterations) viper.SetDefault("data_provider.password_hashing.argon2_options.parallelism", globalConf.ProviderConf.PasswordHashing.Argon2Options.Parallelism) viper.SetDefault("data_provider.password_hashing.algo", globalConf.ProviderConf.PasswordHashing.Algo) viper.SetDefault("data_provider.password_validation.admins.min_entropy", globalConf.ProviderConf.PasswordValidation.Admins.MinEntropy) viper.SetDefault("data_provider.password_validation.users.min_entropy", globalConf.ProviderConf.PasswordValidation.Users.MinEntropy) viper.SetDefault("data_provider.password_caching", globalConf.ProviderConf.PasswordCaching) viper.SetDefault("data_provider.update_mode", globalConf.ProviderConf.UpdateMode) viper.SetDefault("data_provider.delayed_quota_update", globalConf.ProviderConf.DelayedQuotaUpdate) viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin) viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules) viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared) viper.SetDefault("data_provider.node.host", globalConf.ProviderConf.Node.Host) viper.SetDefault("data_provider.node.port", 
globalConf.ProviderConf.Node.Port) viper.SetDefault("data_provider.node.proto", globalConf.ProviderConf.Node.Proto) viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath) viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath) viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath) viper.SetDefault("httpd.openapi_path", globalConf.HTTPDConfig.OpenAPIPath) viper.SetDefault("httpd.web_root", globalConf.HTTPDConfig.WebRoot) viper.SetDefault("httpd.certificate_file", globalConf.HTTPDConfig.CertificateFile) viper.SetDefault("httpd.certificate_key_file", globalConf.HTTPDConfig.CertificateKeyFile) viper.SetDefault("httpd.ca_certificates", globalConf.HTTPDConfig.CACertificates) viper.SetDefault("httpd.ca_revocation_lists", globalConf.HTTPDConfig.CARevocationLists) viper.SetDefault("httpd.signing_passphrase", globalConf.HTTPDConfig.SigningPassphrase) viper.SetDefault("httpd.signing_passphrase_file", globalConf.HTTPDConfig.SigningPassphraseFile) viper.SetDefault("httpd.token_validation", globalConf.HTTPDConfig.TokenValidation) viper.SetDefault("httpd.cookie_lifetime", globalConf.HTTPDConfig.CookieLifetime) viper.SetDefault("httpd.share_cookie_lifetime", globalConf.HTTPDConfig.ShareCookieLifetime) viper.SetDefault("httpd.jwt_lifetime", globalConf.HTTPDConfig.JWTLifetime) viper.SetDefault("httpd.max_upload_file_size", globalConf.HTTPDConfig.MaxUploadFileSize) viper.SetDefault("httpd.cors.enabled", globalConf.HTTPDConfig.Cors.Enabled) viper.SetDefault("httpd.cors.allowed_origins", globalConf.HTTPDConfig.Cors.AllowedOrigins) viper.SetDefault("httpd.cors.allowed_methods", globalConf.HTTPDConfig.Cors.AllowedMethods) viper.SetDefault("httpd.cors.allowed_headers", globalConf.HTTPDConfig.Cors.AllowedHeaders) viper.SetDefault("httpd.cors.exposed_headers", globalConf.HTTPDConfig.Cors.ExposedHeaders) viper.SetDefault("httpd.cors.allow_credentials", globalConf.HTTPDConfig.Cors.AllowCredentials) 
viper.SetDefault("httpd.cors.max_age", globalConf.HTTPDConfig.Cors.MaxAge) viper.SetDefault("httpd.cors.options_passthrough", globalConf.HTTPDConfig.Cors.OptionsPassthrough) viper.SetDefault("httpd.cors.options_success_status", globalConf.HTTPDConfig.Cors.OptionsSuccessStatus) viper.SetDefault("httpd.cors.allow_private_network", globalConf.HTTPDConfig.Cors.AllowPrivateNetwork) viper.SetDefault("httpd.setup.installation_code", globalConf.HTTPDConfig.Setup.InstallationCode) viper.SetDefault("httpd.setup.installation_code_hint", globalConf.HTTPDConfig.Setup.InstallationCodeHint) viper.SetDefault("httpd.hide_support_link", globalConf.HTTPDConfig.HideSupportLink) viper.SetDefault("http.timeout", globalConf.HTTPConfig.Timeout) viper.SetDefault("http.retry_wait_min", globalConf.HTTPConfig.RetryWaitMin) viper.SetDefault("http.retry_wait_max", globalConf.HTTPConfig.RetryWaitMax) viper.SetDefault("http.retry_max", globalConf.HTTPConfig.RetryMax) viper.SetDefault("http.ca_certificates", globalConf.HTTPConfig.CACertificates) viper.SetDefault("http.skip_tls_verify", globalConf.HTTPConfig.SkipTLSVerify) viper.SetDefault("command.timeout", globalConf.CommandConfig.Timeout) viper.SetDefault("command.env", globalConf.CommandConfig.Env) viper.SetDefault("kms.secrets.url", globalConf.KMSConfig.Secrets.URL) viper.SetDefault("kms.secrets.master_key", globalConf.KMSConfig.Secrets.MasterKeyString) viper.SetDefault("kms.secrets.master_key_path", globalConf.KMSConfig.Secrets.MasterKeyPath) viper.SetDefault("telemetry.bind_port", globalConf.TelemetryConfig.BindPort) viper.SetDefault("telemetry.bind_address", globalConf.TelemetryConfig.BindAddress) viper.SetDefault("telemetry.enable_profiler", globalConf.TelemetryConfig.EnableProfiler) viper.SetDefault("telemetry.auth_user_file", globalConf.TelemetryConfig.AuthUserFile) viper.SetDefault("telemetry.certificate_file", globalConf.TelemetryConfig.CertificateFile) viper.SetDefault("telemetry.certificate_key_file", 
globalConf.TelemetryConfig.CertificateKeyFile) viper.SetDefault("telemetry.min_tls_version", globalConf.TelemetryConfig.MinTLSVersion) viper.SetDefault("telemetry.tls_cipher_suites", globalConf.TelemetryConfig.TLSCipherSuites) viper.SetDefault("telemetry.tls_protocols", globalConf.TelemetryConfig.Protocols) viper.SetDefault("smtp.host", globalConf.SMTPConfig.Host) viper.SetDefault("smtp.port", globalConf.SMTPConfig.Port) viper.SetDefault("smtp.from", globalConf.SMTPConfig.From) viper.SetDefault("smtp.user", globalConf.SMTPConfig.User) viper.SetDefault("smtp.password", globalConf.SMTPConfig.Password) viper.SetDefault("smtp.auth_type", globalConf.SMTPConfig.AuthType) viper.SetDefault("smtp.encryption", globalConf.SMTPConfig.Encryption) viper.SetDefault("smtp.domain", globalConf.SMTPConfig.Domain) viper.SetDefault("smtp.templates_path", globalConf.SMTPConfig.TemplatesPath) } func lookupBoolFromEnv(envName string) (bool, bool) { value, ok := os.LookupEnv(envName) if ok { converted, err := strconv.ParseBool(strings.TrimSpace(value)) if err == nil { return converted, ok } } return false, false } func lookupIntFromEnv(envName string, bitSize int) (int64, bool) { value, ok := os.LookupEnv(envName) if ok { converted, err := strconv.ParseInt(strings.TrimSpace(value), 10, bitSize) if err == nil { return converted, ok } } return 0, false } func lookupStringListFromEnv(envName string) ([]string, bool) { value, ok := os.LookupEnv(envName) if ok { var result []string for v := range strings.SplitSeq(value, ",") { val := strings.TrimSpace(v) if val != "" { result = append(result, val) } } return result, true } return nil, false } ================================================ FILE: internal/config/config_darwin.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software 
Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build darwin package config import "github.com/spf13/viper" // setViperAdditionalConfigPaths registers /usr/local/etc/sftpgo as an additional directory in which viper searches for the configuration file (macOS builds only). func setViperAdditionalConfigPaths() { viper.AddConfigPath("/usr/local/etc/sftpgo") } ================================================ FILE: internal/config/config_fallback.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !linux && !darwin package config // setViperAdditionalConfigPaths is a no-op for builds other than Linux and macOS: no additional config search paths are registered. func setViperAdditionalConfigPaths() {} ================================================ FILE: internal/config/config_linux.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build linux package config import "github.com/spf13/viper" // setViperAdditionalConfigPaths registers $HOME/.config/sftpgo, /etc/sftpgo and /usr/local/etc/sftpgo, added in that order, as additional directories in which viper searches for the configuration file (Linux builds only). func setViperAdditionalConfigPaths() { viper.AddConfigPath("$HOME/.config/sftpgo") viper.AddConfigPath("/etc/sftpgo") viper.AddConfigPath("/usr/local/etc/sftpgo") } ================================================ FILE: internal/config/config_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see .
package config_test import ( "crypto/rand" "encoding/json" "os" "path/filepath" "slices" "testing" "github.com/sftpgo/sdk/kms" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) const ( tempConfigName = "temp" ) var ( configDir = filepath.Join(".", "..", "..") ) func reset() { viper.Reset() config.Init() } func TestLoadConfigTest(t *testing.T) { reset() err := config.LoadConfig(configDir, "") assert.NoError(t, err) assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig()) assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf()) assert.NotEqual(t, sftpd.Configuration{}, config.GetSFTPDConfig()) assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig()) assert.NotEqual(t, smtp.Config{}, config.GetSMTPConfig()) confName := tempConfigName + ".json" //nolint:goconst configFilePath := filepath.Join(configDir, confName) err = config.LoadConfig(configDir, confName) assert.Error(t, err) err = os.WriteFile(configFilePath, []byte("{invalid json}"), os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.Error(t, err) err = os.WriteFile(configFilePath, []byte(`{"sftpd": {"max_auth_tries": "a"}}`), os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.Error(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestLoadConfigFileNotFound(t *testing.T) { reset() viper.SetConfigName("configfile") err := 
config.LoadConfig(os.TempDir(), "") require.NoError(t, err) mfaConf := config.GetMFAConfig() require.Len(t, mfaConf.TOTP, 1) require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1) require.Len(t, config.GetCommonConfig().RateLimitersConfig[0].Protocols, 4) require.Len(t, config.GetHTTPDConfig().Bindings, 1) require.Len(t, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes, 3) } func TestReadEnvFiles(t *testing.T) { reset() envd := filepath.Join(configDir, "env.d") err := os.Mkdir(envd, os.ModePerm) assert.NoError(t, err) content := make([]byte, 1048576+1) _, err = rand.Read(content) assert.NoError(t, err) err = os.WriteFile(filepath.Join(envd, "env1"), []byte("SFTPGO_SFTPD__MAX_AUTH_TRIES = 10"), 0666) assert.NoError(t, err) err = os.WriteFile(filepath.Join(envd, "env2"), []byte(`{"invalid env": "value"}`), 0666) assert.NoError(t, err) err = os.WriteFile(filepath.Join(envd, "env3"), content, 0666) assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) assert.Equal(t, 10, config.GetSFTPDConfig().MaxAuthTries) _, ok := os.LookupEnv("SFTPGO_SFTPD__MAX_AUTH_TRIES") assert.True(t, ok) err = os.Unsetenv("SFTPGO_SFTPD__MAX_AUTH_TRIES") assert.NoError(t, err) os.RemoveAll(envd) } func TestEnabledSSHCommands(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err := config.LoadConfig(configDir, "") assert.NoError(t, err) reset() sftpdConf := config.GetSFTPDConfig() sftpdConf.EnabledSSHCommands = []string{"scp"} c := make(map[string]sftpd.Configuration) c["sftpd"] = sftpdConf jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) sftpdConf = config.GetSFTPDConfig() if assert.Len(t, sftpdConf.EnabledSSHCommands, 1) { assert.Equal(t, "scp", sftpdConf.EnabledSSHCommands[0]) } err = os.Remove(configFilePath) assert.NoError(t, err) } func 
TestInvalidExternalAuthScope(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err := config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.ExternalAuthScope = 100 c := make(map[string]dataprovider.Config) c["data_provider"] = providerConf jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) assert.Equal(t, 0, config.GetProviderConf().ExternalAuthScope) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidProxyProtocol(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err := config.LoadConfig(configDir, "") assert.NoError(t, err) commonConf := config.GetCommonConfig() commonConf.ProxyProtocol = 10 c := make(map[string]common.Configuration) c["common"] = commonConf jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) assert.Equal(t, 0, config.GetCommonConfig().ProxyProtocol) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidUsersBaseDir(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err := config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.UsersBaseDir = "." 
c := make(map[string]dataprovider.Config) c["data_provider"] = providerConf jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) assert.Empty(t, config.GetProviderConf().UsersBaseDir) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidInstallationHint(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err := config.LoadConfig(configDir, "") assert.NoError(t, err) httpdConfig := config.GetHTTPDConfig() httpdConfig.Setup = httpd.SetupConfig{ InstallationCode: "abc", InstallationCodeHint: " ", } c := make(map[string]httpd.Conf) c["httpd"] = httpdConfig jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) httpdConfig = config.GetHTTPDConfig() assert.Equal(t, "abc", httpdConfig.Setup.InstallationCode) assert.Equal(t, "Installation code", httpdConfig.Setup.InstallationCodeHint) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidRenameMode(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) commonConfig := config.GetCommonConfig() commonConfig.RenameMode = 10 c := make(map[string]any) c["common"] = commonConfig jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) assert.Equal(t, 0, config.GetCommonConfig().RenameMode) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestDefenderProviderDriver(t *testing.T) { if config.GetProviderConf().Driver != dataprovider.SQLiteDataProviderName { t.Skip("this test is not supported with the current database 
provider") } reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) providerConf := config.GetProviderConf() providerConf.Driver = dataprovider.BoltDataProviderName commonConfig := config.GetCommonConfig() commonConfig.DefenderConfig.Enabled = true commonConfig.DefenderConfig.Driver = common.DefenderDriverProvider c := make(map[string]any) c["common"] = commonConfig c["data_provider"] = providerConf jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) assert.Equal(t, dataprovider.BoltDataProviderName, config.GetProviderConf().Driver) assert.Equal(t, common.DefenderDriverMemory, config.GetCommonConfig().DefenderConfig.Driver) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestSetGetConfig(t *testing.T) { reset() sftpdConf := config.GetSFTPDConfig() sftpdConf.MaxAuthTries = 10 config.SetSFTPDConfig(sftpdConf) assert.Equal(t, sftpdConf.MaxAuthTries, config.GetSFTPDConfig().MaxAuthTries) dataProviderConf := config.GetProviderConf() dataProviderConf.Host = "test host" config.SetProviderConf(dataProviderConf) assert.Equal(t, dataProviderConf.Host, config.GetProviderConf().Host) httpdConf := config.GetHTTPDConfig() httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{Address: "0.0.0.0"}) config.SetHTTPDConfig(httpdConf) assert.Equal(t, httpdConf.Bindings[0].Address, config.GetHTTPDConfig().Bindings[0].Address) commonConf := config.GetCommonConfig() commonConf.IdleTimeout = 10 config.SetCommonConfig(commonConf) assert.Equal(t, commonConf.IdleTimeout, config.GetCommonConfig().IdleTimeout) ftpdConf := config.GetFTPDConfig() ftpdConf.CertificateFile = "cert" ftpdConf.CertificateKeyFile = "key" config.SetFTPDConfig(ftpdConf) assert.Equal(t, ftpdConf.CertificateFile, config.GetFTPDConfig().CertificateFile) assert.Equal(t, ftpdConf.CertificateKeyFile, 
config.GetFTPDConfig().CertificateKeyFile) webDavConf := config.GetWebDAVDConfig() webDavConf.CertificateFile = "dav_cert" webDavConf.CertificateKeyFile = "dav_key" config.SetWebDAVDConfig(webDavConf) assert.Equal(t, webDavConf.CertificateFile, config.GetWebDAVDConfig().CertificateFile) assert.Equal(t, webDavConf.CertificateKeyFile, config.GetWebDAVDConfig().CertificateKeyFile) kmsConf := config.GetKMSConfig() kmsConf.Secrets.MasterKeyPath = "apath" kmsConf.Secrets.URL = "aurl" config.SetKMSConfig(kmsConf) assert.Equal(t, kmsConf.Secrets.MasterKeyPath, config.GetKMSConfig().Secrets.MasterKeyPath) assert.Equal(t, kmsConf.Secrets.URL, config.GetKMSConfig().Secrets.URL) telemetryConf := config.GetTelemetryConfig() telemetryConf.BindPort = 10001 telemetryConf.BindAddress = "0.0.0.0" config.SetTelemetryConfig(telemetryConf) assert.Equal(t, telemetryConf.BindPort, config.GetTelemetryConfig().BindPort) assert.Equal(t, telemetryConf.BindAddress, config.GetTelemetryConfig().BindAddress) pluginConf := []plugin.Config{ { Type: "eventsearcher", }, } config.SetPluginsConfig(pluginConf) if assert.Len(t, config.GetPluginsConfig(), 1) { assert.Equal(t, pluginConf[0].Type, config.GetPluginsConfig()[0].Type) } assert.False(t, config.HasKMSPlugin()) pluginConf = []plugin.Config{ { Type: "notifier", }, { Type: "kms", }, } config.SetPluginsConfig(pluginConf) assert.Len(t, config.GetPluginsConfig(), 2) assert.True(t, config.HasKMSPlugin()) } func TestServiceToStart(t *testing.T) { reset() err := config.LoadConfig(configDir, "") assert.NoError(t, err) assert.True(t, config.HasServicesToStart()) sftpdConf := config.GetSFTPDConfig() sftpdConf.Bindings[0].Port = 0 config.SetSFTPDConfig(sftpdConf) // httpd service is enabled assert.True(t, config.HasServicesToStart()) httpdConf := config.GetHTTPDConfig() httpdConf.Bindings[0].Port = 0 assert.False(t, config.HasServicesToStart()) ftpdConf := config.GetFTPDConfig() ftpdConf.Bindings[0].Port = 2121 config.SetFTPDConfig(ftpdConf) assert.True(t, 
config.HasServicesToStart()) ftpdConf.Bindings[0].Port = 0 config.SetFTPDConfig(ftpdConf) webdavdConf := config.GetWebDAVDConfig() webdavdConf.Bindings[0].Port = 9000 config.SetWebDAVDConfig(webdavdConf) assert.True(t, config.HasServicesToStart()) webdavdConf.Bindings[0].Port = 0 config.SetWebDAVDConfig(webdavdConf) assert.False(t, config.HasServicesToStart()) sftpdConf.Bindings[0].Port = 2022 config.SetSFTPDConfig(sftpdConf) assert.True(t, config.HasServicesToStart()) } func TestSSHCommandsFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS", "cd,scp") t.Cleanup(func() { os.Unsetenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS") }) err := config.LoadConfig(configDir, "") assert.NoError(t, err) sftpdConf := config.GetSFTPDConfig() if assert.Len(t, sftpdConf.EnabledSSHCommands, 2) { assert.Equal(t, "cd", sftpdConf.EnabledSSHCommands[0]) assert.Equal(t, "scp", sftpdConf.EnabledSSHCommands[1]) } } func TestSMTPFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_SMTP__HOST", "smtp.example.com") os.Setenv("SFTPGO_SMTP__PORT", "587") t.Cleanup(func() { os.Unsetenv("SFTPGO_SMTP__HOST") os.Unsetenv("SFTPGO_SMTP__PORT") }) err := config.LoadConfig(configDir, "") assert.NoError(t, err) smtpConfig := config.GetSMTPConfig() assert.Equal(t, "smtp.example.com", smtpConfig.Host) assert.Equal(t, 587, smtpConfig.Port) } func TestMFAFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_MFA__TOTP__0__NAME", "main") os.Setenv("SFTPGO_MFA__TOTP__1__NAME", "additional_name") os.Setenv("SFTPGO_MFA__TOTP__1__ISSUER", "additional_issuer") os.Setenv("SFTPGO_MFA__TOTP__1__ALGO", "sha256") t.Cleanup(func() { os.Unsetenv("SFTPGO_MFA__TOTP__0__NAME") os.Unsetenv("SFTPGO_MFA__TOTP__1__NAME") os.Unsetenv("SFTPGO_MFA__TOTP__1__ISSUER") os.Unsetenv("SFTPGO_MFA__TOTP__1__ALGO") }) err := config.LoadConfig(configDir, "") assert.NoError(t, err) mfaConf := config.GetMFAConfig() require.Len(t, mfaConf.TOTP, 2) require.Equal(t, "main", mfaConf.TOTP[0].Name) require.Equal(t, "SFTPGo", 
mfaConf.TOTP[0].Issuer) require.Equal(t, "sha1", mfaConf.TOTP[0].Algo) require.Equal(t, "additional_name", mfaConf.TOTP[1].Name) require.Equal(t, "additional_issuer", mfaConf.TOTP[1].Issuer) require.Equal(t, "sha256", mfaConf.TOTP[1].Algo) } func TestDisabledMFAConfig(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err := config.LoadConfig(configDir, "") assert.NoError(t, err) mfaConf := config.GetMFAConfig() assert.Len(t, mfaConf.TOTP, 1) reset() c := make(map[string]mfa.Config) c["mfa"] = mfa.Config{} jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) mfaConf = config.GetMFAConfig() assert.Len(t, mfaConf.TOTP, 0) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestOverrideSliceValues(t *testing.T) { reset() confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) c := make(map[string]any) c["common"] = common.Configuration{ RateLimitersConfig: []common.RateLimiterConfig{ { Type: 1, Protocols: []string{"HTTP"}, }, }, } jsonConf, err := json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1) require.Equal(t, []string{"HTTP"}, config.GetCommonConfig().RateLimitersConfig[0].Protocols) reset() // empty ratelimiters, default value should be used c["common"] = common.Configuration{} jsonConf, err = json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1) rl := config.GetCommonConfig().RateLimitersConfig[0] require.Equal(t, 
[]string{"SSH", "FTP", "DAV", "HTTP"}, rl.Protocols) require.Equal(t, int64(1000), rl.Period) reset() c = make(map[string]any) c["httpd"] = httpd.Conf{ Bindings: []httpd.Binding{ { OIDC: httpd.OIDC{ Scopes: []string{"scope1"}, }, }, }, } jsonConf, err = json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) require.Len(t, config.GetHTTPDConfig().Bindings, 1) require.Equal(t, []string{"scope1"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes) reset() c = make(map[string]any) c["httpd"] = httpd.Conf{ Bindings: nil, } jsonConf, err = json.Marshal(c) assert.NoError(t, err) err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) err = config.LoadConfig(configDir, confName) assert.NoError(t, err) require.Len(t, config.GetHTTPDConfig().Bindings, 1) require.Equal(t, []string{"openid", "profile", "email"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes) } func TestFTPDOverridesFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP", "192.168.1.1") os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__NETWORKS", "192.168.1.0/24, 192.168.3.0/25") os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__IP", "192.168.2.1") os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__NETWORKS", "192.168.2.0/24") cleanup := func() { os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP") os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__NETWORKS") os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__IP") os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__NETWORKS") } t.Cleanup(cleanup) err := config.LoadConfig(configDir, "") assert.NoError(t, err) ftpdConf := config.GetFTPDConfig() require.Len(t, ftpdConf.Bindings, 1) require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides, 2) require.Equal(t, "192.168.1.1", 
ftpdConf.Bindings[0].PassiveIPOverrides[0].IP) require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[0].Networks, 2) require.Equal(t, "192.168.2.1", ftpdConf.Bindings[0].PassiveIPOverrides[1].IP) require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[1].Networks, 1) cleanup() cfg := make(map[string]any) cfg["ftpd"] = ftpdConf configAsJSON, err := json.Marshal(cfg) require.NoError(t, err) confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err = os.WriteFile(configFilePath, configAsJSON, os.ModePerm) assert.NoError(t, err) os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP", "192.168.1.2") os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__NETWORKS", "192.168.2.0/24,192.168.4.0/25") err = config.LoadConfig(configDir, confName) assert.NoError(t, err) ftpdConf = config.GetFTPDConfig() require.Len(t, ftpdConf.Bindings, 1) require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides, 2) require.Equal(t, "192.168.1.2", ftpdConf.Bindings[0].PassiveIPOverrides[0].IP) require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[0].Networks, 2) require.Equal(t, "192.168.2.1", ftpdConf.Bindings[0].PassiveIPOverrides[1].IP) require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[1].Networks, 2) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestHTTPDSubObjectsFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__KEY", "X-Forwarded-Proto") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__VALUE", "https") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_ID", "client_id") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET", "client_secret") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET_FILE", "client_secret_file") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CONFIG_URL", "config_url") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__REDIRECT_BASE_URL", "redirect_base_url") 
os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__USERNAME_FIELD", "email") cleanup := func() { os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__KEY") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__VALUE") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_ID") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET_FILE") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CONFIG_URL") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__REDIRECT_BASE_URL") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__USERNAME_FIELD") } t.Cleanup(cleanup) err := config.LoadConfig(configDir, "") assert.NoError(t, err) httpdConf := config.GetHTTPDConfig() require.Len(t, httpdConf.Bindings, 1) require.Len(t, httpdConf.Bindings[0].Security.HTTPSProxyHeaders, 1) require.Equal(t, "client_id", httpdConf.Bindings[0].OIDC.ClientID) require.Equal(t, "client_secret", httpdConf.Bindings[0].OIDC.ClientSecret) require.Equal(t, "client_secret_file", httpdConf.Bindings[0].OIDC.ClientSecretFile) require.Equal(t, "config_url", httpdConf.Bindings[0].OIDC.ConfigURL) require.Equal(t, "redirect_base_url", httpdConf.Bindings[0].OIDC.RedirectBaseURL) require.Equal(t, "email", httpdConf.Bindings[0].OIDC.UsernameField) cleanup() cfg := make(map[string]any) cfg["httpd"] = httpdConf configAsJSON, err := json.Marshal(cfg) require.NoError(t, err) confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err = os.WriteFile(configFilePath, configAsJSON, os.ModePerm) assert.NoError(t, err) os.Setenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__VALUE", "http") os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET", "new_client_secret") err = config.LoadConfig(configDir, confName) assert.NoError(t, err) httpdConf = config.GetHTTPDConfig() require.Len(t, httpdConf.Bindings, 1) require.Len(t, httpdConf.Bindings[0].Security.HTTPSProxyHeaders, 1) require.Equal(t, 
"http", httpdConf.Bindings[0].Security.HTTPSProxyHeaders[0].Value) require.Equal(t, "client_id", httpdConf.Bindings[0].OIDC.ClientID) require.Equal(t, "new_client_secret", httpdConf.Bindings[0].OIDC.ClientSecret) require.Equal(t, "config_url", httpdConf.Bindings[0].OIDC.ConfigURL) require.Equal(t, "redirect_base_url", httpdConf.Bindings[0].OIDC.RedirectBaseURL) require.Equal(t, "email", httpdConf.Bindings[0].OIDC.UsernameField) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestPluginsFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_PLUGINS__0__TYPE", "notifier") os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__FS_EVENTS", "upload,download") os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_EVENTS", "add,update") os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_OBJECTS", "user,admin") os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__LOG_EVENTS", "a,1,2") os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_MAX_TIME", "2") os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE", "1000") os.Setenv("SFTPGO_PLUGINS__0__CMD", "plugin_start_cmd") os.Setenv("SFTPGO_PLUGINS__0__ARGS", "arg1,arg2") os.Setenv("SFTPGO_PLUGINS__0__SHA256SUM", "0a71ded61fccd59c4f3695b51c1b3d180da8d2d77ea09ccee20dac242675c193") os.Setenv("SFTPGO_PLUGINS__0__AUTO_MTLS", "1") os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__SCHEME", kms.SchemeAWS) os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__ENCRYPTED_STATUS", kms.SecretStatusAWS) os.Setenv("SFTPGO_PLUGINS__0__AUTH_OPTIONS__SCOPE", "14") os.Setenv("SFTPGO_PLUGINS__0__ENV_PREFIX", "prefix_") os.Setenv("SFTPGO_PLUGINS__0__ENV_VARS", "a, b") t.Cleanup(func() { os.Unsetenv("SFTPGO_PLUGINS__0__TYPE") os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__FS_EVENTS") os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_EVENTS") os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_OBJECTS") os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__LOG_EVENTS") 
os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_MAX_TIME") os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE") os.Unsetenv("SFTPGO_PLUGINS__0__CMD") os.Unsetenv("SFTPGO_PLUGINS__0__ARGS") os.Unsetenv("SFTPGO_PLUGINS__0__SHA256SUM") os.Unsetenv("SFTPGO_PLUGINS__0__AUTO_MTLS") os.Unsetenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__SCHEME") os.Unsetenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__ENCRYPTED_STATUS") os.Unsetenv("SFTPGO_PLUGINS__0__AUTH_OPTIONS__SCOPE") os.Unsetenv("SFTPGO_PLUGINS__0__ENV_PREFIX") os.Unsetenv("SFTPGO_PLUGINS__0__ENV_VARS") }) err := config.LoadConfig(configDir, "") assert.NoError(t, err) pluginsConf := config.GetPluginsConfig() require.Len(t, pluginsConf, 1) pluginConf := pluginsConf[0] require.Equal(t, "notifier", pluginConf.Type) require.Len(t, pluginConf.NotifierOptions.FsEvents, 2) require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload")) require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download")) require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2) require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0]) require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1]) require.Len(t, pluginConf.NotifierOptions.ProviderObjects, 2) require.Equal(t, "user", pluginConf.NotifierOptions.ProviderObjects[0]) require.Equal(t, "admin", pluginConf.NotifierOptions.ProviderObjects[1]) require.Len(t, pluginConf.NotifierOptions.LogEvents, 2) require.Equal(t, 1, pluginConf.NotifierOptions.LogEvents[0]) require.Equal(t, 2, pluginConf.NotifierOptions.LogEvents[1]) require.Equal(t, 2, pluginConf.NotifierOptions.RetryMaxTime) require.Equal(t, 1000, pluginConf.NotifierOptions.RetryQueueMaxSize) require.Equal(t, "plugin_start_cmd", pluginConf.Cmd) require.Len(t, pluginConf.Args, 2) require.Equal(t, "arg1", pluginConf.Args[0]) require.Equal(t, "arg2", pluginConf.Args[1]) require.Equal(t, "0a71ded61fccd59c4f3695b51c1b3d180da8d2d77ea09ccee20dac242675c193", pluginConf.SHA256Sum) 
require.True(t, pluginConf.AutoMTLS) require.Equal(t, kms.SchemeAWS, pluginConf.KMSOptions.Scheme) require.Equal(t, kms.SecretStatusAWS, pluginConf.KMSOptions.EncryptedStatus) require.Equal(t, 14, pluginConf.AuthOptions.Scope) require.Equal(t, "prefix_", pluginConf.EnvPrefix) require.Len(t, pluginConf.EnvVars, 2) assert.Equal(t, "a", pluginConf.EnvVars[0]) assert.Equal(t, "b", pluginConf.EnvVars[1]) cfg := make(map[string]any) cfg["plugins"] = pluginConf configAsJSON, err := json.Marshal(cfg) require.NoError(t, err) confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err = os.WriteFile(configFilePath, configAsJSON, os.ModePerm) assert.NoError(t, err) os.Setenv("SFTPGO_PLUGINS__0__CMD", "plugin_start_cmd1") os.Setenv("SFTPGO_PLUGINS__0__ARGS", "") os.Setenv("SFTPGO_PLUGINS__0__AUTO_MTLS", "0") os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__SCHEME", kms.SchemeVaultTransit) os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__ENCRYPTED_STATUS", kms.SecretStatusVaultTransit) os.Setenv("SFTPGO_PLUGINS__0__ENV_PREFIX", "") os.Setenv("SFTPGO_PLUGINS__0__ENV_VARS", "") err = config.LoadConfig(configDir, confName) assert.NoError(t, err) pluginsConf = config.GetPluginsConfig() require.Len(t, pluginsConf, 1) pluginConf = pluginsConf[0] require.Equal(t, "notifier", pluginConf.Type) require.Len(t, pluginConf.NotifierOptions.FsEvents, 2) require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload")) require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download")) require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2) require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0]) require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1]) require.Len(t, pluginConf.NotifierOptions.ProviderObjects, 2) require.Equal(t, "user", pluginConf.NotifierOptions.ProviderObjects[0]) require.Equal(t, "admin", pluginConf.NotifierOptions.ProviderObjects[1]) require.Equal(t, 2, 
pluginConf.NotifierOptions.RetryMaxTime) require.Equal(t, 1000, pluginConf.NotifierOptions.RetryQueueMaxSize) require.Equal(t, "plugin_start_cmd1", pluginConf.Cmd) require.Len(t, pluginConf.Args, 0) require.Equal(t, "0a71ded61fccd59c4f3695b51c1b3d180da8d2d77ea09ccee20dac242675c193", pluginConf.SHA256Sum) require.False(t, pluginConf.AutoMTLS) require.Equal(t, kms.SchemeVaultTransit, pluginConf.KMSOptions.Scheme) require.Equal(t, kms.SecretStatusVaultTransit, pluginConf.KMSOptions.EncryptedStatus) require.Equal(t, 14, pluginConf.AuthOptions.Scope) assert.Empty(t, pluginConf.EnvPrefix) assert.Len(t, pluginConf.EnvVars, 0) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestRateLimitersFromEnv(t *testing.T) { reset() os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__AVERAGE", "100") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__PERIOD", "2000") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__BURST", "10") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__TYPE", "2") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__PROTOCOLS", "SSH, FTP") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__GENERATE_DEFENDER_EVENTS", "1") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_SOFT_LIMIT", "50") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_HARD_LIMIT", "100") os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__8__AVERAGE", "50") t.Cleanup(func() { os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__AVERAGE") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__PERIOD") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__BURST") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__TYPE") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__PROTOCOLS") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__GENERATE_DEFENDER_EVENTS") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_SOFT_LIMIT") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_HARD_LIMIT") os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__8__AVERAGE") }) err := config.LoadConfig(configDir, "") assert.NoError(t, err) limiters := config.GetCommonConfig().RateLimitersConfig 
require.Len(t, limiters, 2) require.Equal(t, int64(100), limiters[0].Average) require.Equal(t, int64(2000), limiters[0].Period) require.Equal(t, 10, limiters[0].Burst) require.Equal(t, 2, limiters[0].Type) protocols := limiters[0].Protocols require.Len(t, protocols, 2) require.True(t, slices.Contains(protocols, common.ProtocolFTP)) require.True(t, slices.Contains(protocols, common.ProtocolSSH)) require.True(t, limiters[0].GenerateDefenderEvents) require.Equal(t, 50, limiters[0].EntriesSoftLimit) require.Equal(t, 100, limiters[0].EntriesHardLimit) require.Equal(t, int64(50), limiters[1].Average) // we check the default values here require.Equal(t, int64(1000), limiters[1].Period) require.Equal(t, 1, limiters[1].Burst) require.Equal(t, 2, limiters[1].Type) protocols = limiters[1].Protocols require.Len(t, protocols, 4) require.True(t, slices.Contains(protocols, common.ProtocolFTP)) require.True(t, slices.Contains(protocols, common.ProtocolSSH)) require.True(t, slices.Contains(protocols, common.ProtocolWebDAV)) require.True(t, slices.Contains(protocols, common.ProtocolHTTP)) require.False(t, limiters[1].GenerateDefenderEvents) require.Equal(t, 100, limiters[1].EntriesSoftLimit) require.Equal(t, 150, limiters[1].EntriesHardLimit) }

// TestSFTPDBindingsFromEnv verifies that SFTPGO_SFTPD__BINDINGS__<index>__*
// variables define SFTP bindings. The sparse indexes used here (0 and 3) are
// expected to produce two consecutive bindings, and a field not set via env
// (ApplyProxyConfig for index 3) keeps its default value.
//
// Environment variables are set with t.Setenv, which restores the previous
// value (or unsets the variable) when the test ends, so no hand-written
// os.Unsetenv cleanup list can drift out of sync with the variables set.
// t.Setenv must not be used in parallel tests; none of these tests call
// t.Parallel.
func TestSFTPDBindingsFromEnv(t *testing.T) {
	reset()

	t.Setenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS", "127.0.0.1")
	t.Setenv("SFTPGO_SFTPD__BINDINGS__0__PORT", "2200")
	t.Setenv("SFTPGO_SFTPD__BINDINGS__0__APPLY_PROXY_CONFIG", "false")
	t.Setenv("SFTPGO_SFTPD__BINDINGS__3__ADDRESS", "127.0.1.1")
	t.Setenv("SFTPGO_SFTPD__BINDINGS__3__PORT", "2203")

	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	bindings := config.GetSFTPDConfig().Bindings
	require.Len(t, bindings, 2)
	require.Equal(t, 2200, bindings[0].Port)
	require.Equal(t, "127.0.0.1", bindings[0].Address)
	require.False(t, bindings[0].ApplyProxyConfig)
	require.Equal(t, 2203, bindings[1].Port)
	require.Equal(t, "127.0.1.1", bindings[1].Address)
	require.True(t, bindings[1].ApplyProxyConfig) // default value
}

// TestCommandsFromEnv verifies that SFTPGO_COMMAND__* variables override the
// command settings loaded from a JSON config file. The global timeout and env
// are replaced, and per-command fields set via env replace the file values.
// A per-command field that is not set via env (Env of command 0) keeps the
// value read from the file.
func TestCommandsFromEnv(t *testing.T) {
	reset()

	confName := tempConfigName + ".json"
	configFilePath := filepath.Join(configDir, confName)
	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	commandConfig := config.GetCommandConfig()
	commandConfig.Commands = append(commandConfig.Commands, command.Command{
		Path:    "cmd",
		Timeout: 10,
		Env:     []string{"a=a"},
	})
	c := make(map[string]command.Config)
	c["command"] = commandConfig
	jsonConf, err := json.Marshal(c)
	require.NoError(t, err)
	err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
	require.NoError(t, err)
	err = config.LoadConfig(configDir, confName)
	require.NoError(t, err)
	commandConfig = config.GetCommandConfig()
	require.Equal(t, 30, commandConfig.Timeout)
	require.Len(t, commandConfig.Env, 0)
	require.Len(t, commandConfig.Commands, 1)
	require.Equal(t, "cmd", commandConfig.Commands[0].Path)
	require.Equal(t, 10, commandConfig.Commands[0].Timeout)
	require.Equal(t, []string{"a=a"}, commandConfig.Commands[0].Env)

	t.Setenv("SFTPGO_COMMAND__TIMEOUT", "25")
	t.Setenv("SFTPGO_COMMAND__ENV", "a=b,c=d")
	t.Setenv("SFTPGO_COMMAND__COMMANDS__0__PATH", "cmd1")
	t.Setenv("SFTPGO_COMMAND__COMMANDS__0__TIMEOUT", "11")
	t.Setenv("SFTPGO_COMMAND__COMMANDS__1__PATH", "cmd2")
	t.Setenv("SFTPGO_COMMAND__COMMANDS__1__TIMEOUT", "20")
	t.Setenv("SFTPGO_COMMAND__COMMANDS__1__ENV", "e=f")
	t.Setenv("SFTPGO_COMMAND__COMMANDS__1__ARGS", "arg1, arg2")

	err = config.LoadConfig(configDir, confName)
	assert.NoError(t, err)
	commandConfig = config.GetCommandConfig()
	require.Equal(t, 25, commandConfig.Timeout)
	require.Equal(t, []string{"a=b", "c=d"}, commandConfig.Env)
	require.Len(t, commandConfig.Commands, 2)
	require.Equal(t, "cmd1", commandConfig.Commands[0].Path)
	require.Equal(t, 11, commandConfig.Commands[0].Timeout)
	require.Equal(t, []string{"a=a"}, commandConfig.Commands[0].Env)
	require.Equal(t, "cmd2", commandConfig.Commands[1].Path)
	require.Equal(t, 20, commandConfig.Commands[1].Timeout)
	require.Equal(t, []string{"e=f"}, commandConfig.Commands[1].Env)
	require.Equal(t, []string{"arg1", "arg2"}, commandConfig.Commands[1].Args)

	err = os.Remove(configFilePath)
	assert.NoError(t, err)
}

// TestFTPDBindingsFromEnv verifies that SFTPGO_FTPD__BINDINGS__<index>__*
// variables define FTP bindings for the sparse indexes 0 and 9, which are
// expected to produce two consecutive bindings.
// Binding 0 sets a passive IP override with only an IP and no networks; that
// override is expected to be discarded. The override of binding 9 sets both
// IP and networks and is kept.
func TestFTPDBindingsFromEnv(t *testing.T) {
	reset()

	t.Setenv("SFTPGO_FTPD__BINDINGS__0__ADDRESS", "127.0.0.1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__PORT", "2200")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__APPLY_PROXY_CONFIG", "f")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__TLS_MODE", "2")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__FORCE_PASSIVE_IP", "127.0.1.2")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP", "172.16.1.1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_HOST", "127.0.1.3")
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__TLS_CIPHER_SUITES", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
	// Previously this variable was never unset (the cleanup list unset
	// ACTIVE_CONNECTIONS_SECURITY for index 0 instead); t.Setenv restores it.
	t.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_CONNECTIONS_SECURITY", "1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__ADDRESS", "127.0.1.1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__PORT", "2203")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__TLS_MODE", "1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__MIN_TLS_VERSION", "13")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__FORCE_PASSIVE_IP", "127.0.1.1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__PASSIVE_IP_OVERRIDES__3__IP", "192.168.1.1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__PASSIVE_IP_OVERRIDES__3__NETWORKS", "192.168.1.0/24, 192.168.3.0/25")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__CLIENT_AUTH_TYPE", "2")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__DEBUG", "1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__ACTIVE_CONNECTIONS_SECURITY", "1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__IGNORE_ASCII_TRANSFER_TYPE", "1")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_FILE", "cert.crt")
	t.Setenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_KEY_FILE", "cert.key")

	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	bindings := config.GetFTPDConfig().Bindings
	require.Len(t, bindings, 2)
	require.Equal(t, 2200, bindings[0].Port)
	require.Equal(t, "127.0.0.1", bindings[0].Address)
	require.False(t, bindings[0].ApplyProxyConfig)
	require.Equal(t, 2, bindings[0].TLSMode)
	require.Equal(t, 12, bindings[0].MinTLSVersion)
	require.Equal(t, "127.0.1.2", bindings[0].ForcePassiveIP)
	require.Len(t, bindings[0].PassiveIPOverrides, 0)
	require.Equal(t, "127.0.1.3", bindings[0].PassiveHost)
	require.Equal(t, 0, bindings[0].ClientAuthType)
	require.Len(t, bindings[0].TLSCipherSuites, 2)
	require.Equal(t, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", bindings[0].TLSCipherSuites[0])
	require.Equal(t, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", bindings[0].TLSCipherSuites[1])
	require.False(t, bindings[0].Debug)
	require.Equal(t, 1, bindings[0].PassiveConnectionsSecurity)
	require.Equal(t, 0, bindings[0].ActiveConnectionsSecurity)
	require.Equal(t, 2203, bindings[1].Port)
	require.Equal(t, "127.0.1.1", bindings[1].Address)
	require.True(t, bindings[1].ApplyProxyConfig) // default value
	require.Equal(t, 1, bindings[1].TLSMode)
	require.Equal(t, 13, bindings[1].MinTLSVersion)
	require.Equal(t, "127.0.1.1", bindings[1].ForcePassiveIP)
	require.Empty(t, bindings[1].PassiveHost)
	require.Len(t, bindings[1].PassiveIPOverrides, 1)
	require.Equal(t, "192.168.1.1", bindings[1].PassiveIPOverrides[0].IP)
	require.Len(t, bindings[1].PassiveIPOverrides[0].Networks, 2)
	require.Equal(t, "192.168.1.0/24", bindings[1].PassiveIPOverrides[0].Networks[0])
	require.Equal(t, "192.168.3.0/25", bindings[1].PassiveIPOverrides[0].Networks[1])
	require.Equal(t, 2, bindings[1].ClientAuthType)
	require.Nil(t, bindings[1].TLSCipherSuites)
	require.Equal(t, 0, bindings[1].PassiveConnectionsSecurity)
	require.Equal(t, 1, bindings[1].ActiveConnectionsSecurity)
	require.True(t, bindings[1].Debug)
	require.Equal(t, "cert.crt", bindings[1].CertificateFile)
	require.Equal(t, "cert.key", bindings[1].CertificateKeyFile)
}

// TestWebDAVMimeCache verifies how WebDAV custom MIME mappings from a config
// file are combined with SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__*
// variables. An env entry at a new index (1) is appended to the file entry,
// and an env entry at an existing index (0) overrides the file entry.
func TestWebDAVMimeCache(t *testing.T) {
	reset()

	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	webdavdConf := config.GetWebDAVDConfig()
	webdavdConf.Cache.MimeTypes.CustomMappings = []webdavd.CustomMimeMapping{
		{
			Ext:  ".custom",
			Mime: "application/custom",
		},
	}
	cfg := map[string]any{
		"webdavd": webdavdConf,
	}
	data, err := json.Marshal(cfg)
	assert.NoError(t, err)
	confName := tempConfigName + ".json"
	configFilePath := filepath.Join(configDir, confName)
	err = os.WriteFile(configFilePath, data, 0666)
	assert.NoError(t, err)
	reset()

	err = config.LoadConfig(configDir, confName)
	assert.NoError(t, err)
	mappings := config.GetWebDAVDConfig().Cache.MimeTypes.CustomMappings
	if assert.Len(t, mappings, 1) {
		assert.Equal(t, ".custom", mappings[0].Ext)
		assert.Equal(t, "application/custom", mappings[0].Mime)
	}

	// now add from env
	t.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__1__EXT", ".custom1")
	t.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__1__MIME", "application/custom1")
	reset()

	err = config.LoadConfig(configDir, confName)
	assert.NoError(t, err)
	mappings = config.GetWebDAVDConfig().Cache.MimeTypes.CustomMappings
	if assert.Len(t, mappings, 2) {
		assert.Equal(t, ".custom", mappings[0].Ext)
		assert.Equal(t, "application/custom", mappings[0].Mime)
		assert.Equal(t, ".custom1", mappings[1].Ext)
		assert.Equal(t, "application/custom1", mappings[1].Mime)
	}

	// override from env
	t.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__EXT", ".custom0")
	t.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__MIME", "application/custom0")
	reset()

	err = config.LoadConfig(configDir, confName)
	assert.NoError(t, err)
	mappings = config.GetWebDAVDConfig().Cache.MimeTypes.CustomMappings
	if assert.Len(t, mappings, 2) {
		assert.Equal(t, ".custom0", mappings[0].Ext)
		assert.Equal(t, "application/custom0", mappings[0].Mime)
		assert.Equal(t, ".custom1", mappings[1].Ext)
		assert.Equal(t, "application/custom1", mappings[1].Mime)
	}

	err = os.Remove(configFilePath)
	assert.NoError(t, err)
}

// TestWebDAVBindingsFromEnv verifies that SFTPGO_WEBDAVD__BINDINGS__<index>__*
// variables for indexes 1 and 2 add WebDAV bindings after binding 0. Binding 0
// is not set via env and is expected to keep the values asserted below.
// Values with trailing spaces (cipher suite, protocol) are expected to be
// trimmed.
func TestWebDAVBindingsFromEnv(t *testing.T) {
	reset()

	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ADDRESS", "127.0.0.1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_PROTOCOLS", "http/1.1 ")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_MODE", "1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED", "192.168.10.1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_PROXY_HEADER", "X-Forwarded-For")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_HEADER_DEPTH", "2")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS", "127.0.1.1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT", "9000")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS", "1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__MIN_TLS_VERSION", "13")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CLIENT_AUTH_TYPE", "1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX", "/dav2")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE", "webdav.crt")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE", "webdav.key")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER", "1")

	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	bindings := config.GetWebDAVDConfig().Bindings
	require.Len(t, bindings, 3)
	require.Equal(t, 0, bindings[0].Port)
	require.Empty(t, bindings[0].Address)
	require.False(t, bindings[0].EnableHTTPS)
	require.Equal(t, 12, bindings[0].MinTLSVersion)
	require.Len(t, bindings[0].TLSCipherSuites, 0)
	require.Len(t, bindings[0].Protocols, 0)
	require.Equal(t, 0, bindings[0].ProxyMode)
	require.Empty(t, bindings[0].Prefix)
	require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
	require.False(t, bindings[0].DisableWWWAuthHeader)
	require.Equal(t, 8000, bindings[1].Port)
	require.Equal(t, "127.0.0.1", bindings[1].Address)
	require.False(t, bindings[1].EnableHTTPS)
	require.Equal(t, 12, bindings[1].MinTLSVersion)
	require.Equal(t, 0, bindings[1].ClientAuthType)
	require.Len(t, bindings[1].TLSCipherSuites, 1)
	require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0])
	require.Len(t, bindings[1].Protocols, 1)
	assert.Equal(t, "http/1.1", bindings[1].Protocols[0])
	require.Equal(t, 1, bindings[1].ProxyMode)
	require.Equal(t, "192.168.10.1", bindings[1].ProxyAllowed[0])
	require.Equal(t, "X-Forwarded-For", bindings[1].ClientIPProxyHeader)
	require.Equal(t, 2, bindings[1].ClientIPHeaderDepth)
	require.Empty(t, bindings[1].Prefix)
	require.False(t, bindings[1].DisableWWWAuthHeader)
	require.Equal(t, 9000, bindings[2].Port)
	require.Equal(t, "127.0.1.1", bindings[2].Address)
	require.True(t, bindings[2].EnableHTTPS)
	require.Equal(t, 13, bindings[2].MinTLSVersion)
	require.Equal(t, 1, bindings[2].ClientAuthType)
	require.Equal(t, 0, bindings[2].ProxyMode)
	require.Nil(t, bindings[2].TLSCipherSuites)
	require.Equal(t, "/dav2", bindings[2].Prefix)
	require.Equal(t, "webdav.crt", bindings[2].CertificateFile)
	require.Equal(t, "webdav.key", bindings[2].CertificateKeyFile)
	require.Equal(t, 0, bindings[2].ClientIPHeaderDepth)
	require.True(t, bindings[2].DisableWWWAuthHeader)
}

// TestHTTPDBindingsFromEnv verifies that SFTPGO_HTTPD__BINDINGS__<index>__*
// variables configure HTTPD bindings, including nested OIDC, security and
// branding settings and comma-separated list values.
//   - Binding 0 uses a filesystem path (the cleaned os.TempDir) as address
//     with port 0.
//   - Binding 1 sets a few scalar and branding fields.
//   - Binding 2 sets nearly every supported field.
//
// Values with surrounding spaces are expected to be trimmed. The sparse
// HTTPS_PROXY_HEADERS index 1 is expected to produce a single header.
func TestHTTPDBindingsFromEnv(t *testing.T) {
	reset()

	sockPath := filepath.Clean(os.TempDir())

	t.Setenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS", sockPath)
	t.Setenv("SFTPGO_HTTPD__BINDINGS__0__PORT", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__0__TLS_CIPHER_SUITES", " TLS_AES_128_GCM_SHA256")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__1__ADDRESS", "127.0.0.1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__1__PORT", "8000")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__1__ENABLE_HTTPS", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__1__HIDE_LOGIN_URL", " 1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__1__BRANDING__WEB_ADMIN__NAME", "Web Admin")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__1__BRANDING__WEB_CLIENT__SHORT_NAME", "WebClient")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__ADDRESS", "127.0.1.1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__PORT", "9000")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS", "3")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__DISABLED_LOGIN_METHODS", "12")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BASE_URL", "https://example.com")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__LANGUAGES", "en,es")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1 ")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION", "13")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_PROTOCOLS", "h2, http/1.1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_MODE", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED", " 192.168.9.1 , 172.16.25.0/24")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_PROXY_HEADER", "X-Real-IP")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_HEADER_DEPTH", "2")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__HIDE_LOGIN_URL", "3")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CLIENT_ID", "client id")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CLIENT_SECRET", "client secret")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CONFIG_URL", "config url")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__REDIRECT_BASE_URL", "redirect base url")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__USERNAME_FIELD", "preferred_username")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__ROLE_FIELD", "sftpgo_role")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES", "openid")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS", "field1,field2")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED", "true")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS", "*.example.com,*.example.net")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS_ARE_REGEX", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HOSTS_PROXY_HEADERS", "X-Forwarded-Host")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_REDIRECT", "1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_HOST", "www.example.com")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_PROXY_HEADERS__1__KEY", "X-Forwarded-Proto")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_PROXY_HEADERS__1__VALUE", "https")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_SECONDS", "31536000")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_INCLUDE_SUBDOMAINS", "false")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_PRELOAD", "0")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CONTENT_TYPE_NOSNIFF", "t")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CONTENT_SECURITY_POLICY", "script-src $NONCE")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__PERMISSIONS_POLICY", "fullscreen=(), geolocation=()")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_OPENER_POLICY", "same-origin")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_RESOURCE_POLICY", "same-site")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_EMBEDDER_POLICY", "require-corp")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CACHE_CONTROL", "private")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__REFERRER_POLICY", "no-referrer")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__EXTRA_CSS__0__PATH", "path1")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__EXTRA_CSS__1__PATH", "path2")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_ADMIN__FAVICON_PATH", "favicon.ico")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__LOGO_PATH", "logo.png")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__DISCLAIMER_NAME", "disclaimer")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_ADMIN__DISCLAIMER_PATH", "disclaimer.html")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__DEFAULT_CSS", "default.css")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__EXTRA_CSS", "1.css,2.css")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_FILE", "httpd.crt")
	t.Setenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_KEY_FILE", "httpd.key")

	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	bindings := config.GetHTTPDConfig().Bindings
	require.Len(t, bindings, 3)
	require.Equal(t, 0, bindings[0].Port)
	require.Equal(t, sockPath, bindings[0].Address)
	require.False(t, bindings[0].EnableHTTPS)
	require.Len(t, bindings[0].Protocols, 0)
	require.Equal(t, 12, bindings[0].MinTLSVersion)
	require.True(t, bindings[0].EnableWebAdmin)
	require.True(t, bindings[0].EnableWebClient)
	require.True(t, bindings[0].EnableRESTAPI)
	require.Equal(t, 0, bindings[0].EnabledLoginMethods)
	require.Equal(t, 0, bindings[0].DisabledLoginMethods)
	require.True(t, bindings[0].RenderOpenAPI)
	require.Empty(t, bindings[0].BaseURL)
	require.Len(t, bindings[0].Languages, 1)
	assert.Contains(t, bindings[0].Languages, "en")
	require.Len(t, bindings[0].TLSCipherSuites, 1)
	require.Equal(t, 0, bindings[0].ProxyMode)
	require.Empty(t, bindings[0].OIDC.ConfigURL)
	require.Equal(t, "TLS_AES_128_GCM_SHA256", bindings[0].TLSCipherSuites[0])
	require.Equal(t, 0, bindings[0].HideLoginURL)
	require.False(t, bindings[0].Security.Enabled)
	require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
	require.Len(t, bindings[0].OIDC.Scopes, 3)
	require.False(t, bindings[0].OIDC.InsecureSkipSignatureCheck)
	require.False(t, bindings[0].OIDC.Debug)
	require.Empty(t, bindings[0].Security.ReferrerPolicy)
	require.Equal(t, 8000, bindings[1].Port)
	require.Equal(t, "127.0.0.1", bindings[1].Address)
	require.False(t, bindings[1].EnableHTTPS)
	// this used to check bindings[0] again (copy-paste slip)
	require.Equal(t, 12, bindings[1].MinTLSVersion)
	require.True(t, bindings[1].EnableWebAdmin)
	require.True(t, bindings[1].EnableWebClient)
	require.True(t, bindings[1].EnableRESTAPI)
	require.Equal(t, 0, bindings[1].EnabledLoginMethods)
	require.Equal(t, 0, bindings[1].DisabledLoginMethods)
	require.True(t, bindings[1].RenderOpenAPI)
	require.Empty(t, bindings[1].BaseURL)
	require.Len(t, bindings[1].Languages, 1)
	assert.Contains(t, bindings[1].Languages, "en")
	require.Nil(t, bindings[1].TLSCipherSuites)
	require.Equal(t, 1, bindings[1].HideLoginURL)
	require.Empty(t, bindings[1].OIDC.ClientID)
	require.Len(t, bindings[1].OIDC.Scopes, 3)
	require.False(t, bindings[1].OIDC.InsecureSkipSignatureCheck)
	require.False(t, bindings[1].OIDC.Debug)
	require.False(t, bindings[1].Security.Enabled)
	require.Equal(t, "Web Admin", bindings[1].Branding.WebAdmin.Name)
	require.Equal(t, "WebClient", bindings[1].Branding.WebClient.ShortName)
	require.Equal(t, 0, bindings[1].ProxyMode)
	require.Equal(t, 0, bindings[1].ClientIPHeaderDepth)
	require.Equal(t, 9000, bindings[2].Port)
	require.Equal(t, "127.0.1.1", bindings[2].Address)
	require.True(t, bindings[2].EnableHTTPS)
	require.Equal(t, 13, bindings[2].MinTLSVersion)
	require.False(t, bindings[2].EnableWebAdmin)
	require.False(t, bindings[2].EnableWebClient)
	require.False(t, bindings[2].EnableRESTAPI)
	require.Equal(t, 3, bindings[2].EnabledLoginMethods)
	require.Equal(t, 12, bindings[2].DisabledLoginMethods)
	require.False(t, bindings[2].RenderOpenAPI)
	require.Equal(t, "https://example.com", bindings[2].BaseURL)
	require.Len(t, bindings[2].Languages, 2)
	assert.Contains(t, bindings[2].Languages, "en")
	assert.Contains(t, bindings[2].Languages, "es")
	require.Equal(t, 1, bindings[2].ClientAuthType)
	require.Len(t, bindings[2].TLSCipherSuites, 2)
	require.Equal(t, "TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0])
	require.Equal(t, "TLS_CHACHA20_POLY1305_SHA256", bindings[2].TLSCipherSuites[1])
	require.Len(t, bindings[2].Protocols, 2)
	require.Equal(t, "h2", bindings[2].Protocols[0])
	require.Equal(t, "http/1.1", bindings[2].Protocols[1])
	require.Equal(t, 1, bindings[2].ProxyMode)
	require.Len(t, bindings[2].ProxyAllowed, 2)
	require.Equal(t, "192.168.9.1", bindings[2].ProxyAllowed[0])
	require.Equal(t, "172.16.25.0/24", bindings[2].ProxyAllowed[1])
	require.Equal(t, "X-Real-IP", bindings[2].ClientIPProxyHeader)
	require.Equal(t, 2, bindings[2].ClientIPHeaderDepth)
	require.Equal(t, 3, bindings[2].HideLoginURL)
	require.Equal(t, "client id", bindings[2].OIDC.ClientID)
	require.Equal(t, "client secret", bindings[2].OIDC.ClientSecret)
	require.Equal(t, "config url", bindings[2].OIDC.ConfigURL)
	require.Equal(t, "redirect base url", bindings[2].OIDC.RedirectBaseURL)
	require.Equal(t, "preferred_username", bindings[2].OIDC.UsernameField)
	require.Equal(t, "sftpgo_role", bindings[2].OIDC.RoleField)
	require.Len(t, bindings[2].OIDC.Scopes, 1)
	require.Equal(t, "openid", bindings[2].OIDC.Scopes[0])
	require.True(t, bindings[2].OIDC.ImplicitRoles)
	require.Len(t, bindings[2].OIDC.CustomFields, 2)
	require.Equal(t, "field1", bindings[2].OIDC.CustomFields[0])
	require.Equal(t, "field2", bindings[2].OIDC.CustomFields[1])
	require.True(t, bindings[2].OIDC.InsecureSkipSignatureCheck)
	require.True(t, bindings[2].OIDC.Debug)
	require.True(t, bindings[2].Security.Enabled)
	require.Len(t, bindings[2].Security.AllowedHosts, 2)
	require.Equal(t, "*.example.com", bindings[2].Security.AllowedHosts[0])
	require.Equal(t, "*.example.net", bindings[2].Security.AllowedHosts[1])
	require.True(t, bindings[2].Security.AllowedHostsAreRegex)
	require.Len(t, bindings[2].Security.HostsProxyHeaders, 1)
	require.Equal(t, "X-Forwarded-Host", bindings[2].Security.HostsProxyHeaders[0])
	require.True(t, bindings[2].Security.HTTPSRedirect)
	require.Equal(t, "www.example.com", bindings[2].Security.HTTPSHost)
	require.Len(t, bindings[2].Security.HTTPSProxyHeaders, 1)
	require.Equal(t, "X-Forwarded-Proto", bindings[2].Security.HTTPSProxyHeaders[0].Key)
	require.Equal(t, "https", bindings[2].Security.HTTPSProxyHeaders[0].Value)
	require.Equal(t, int64(31536000), bindings[2].Security.STSSeconds)
	require.False(t, bindings[2].Security.STSIncludeSubdomains)
	require.False(t, bindings[2].Security.STSPreload)
	require.True(t, bindings[2].Security.ContentTypeNosniff)
	require.Equal(t, "script-src $NONCE", bindings[2].Security.ContentSecurityPolicy)
	require.Equal(t, "fullscreen=(), geolocation=()", bindings[2].Security.PermissionsPolicy)
	require.Equal(t, "same-origin", bindings[2].Security.CrossOriginOpenerPolicy)
	require.Equal(t, "same-site", bindings[2].Security.CrossOriginResourcePolicy)
	require.Equal(t, "require-corp", bindings[2].Security.CrossOriginEmbedderPolicy)
	require.Equal(t, "private", bindings[2].Security.CacheControl)
	require.Equal(t, "no-referrer", bindings[2].Security.ReferrerPolicy)
	require.Equal(t, "favicon.ico", bindings[2].Branding.WebAdmin.FaviconPath)
	require.Equal(t, "logo.png", bindings[2].Branding.WebClient.LogoPath)
	require.Equal(t, "disclaimer", bindings[2].Branding.WebClient.DisclaimerName)
	require.Equal(t, "disclaimer.html", bindings[2].Branding.WebAdmin.DisclaimerPath)
	require.Equal(t, []string{"default.css"}, bindings[2].Branding.WebClient.DefaultCSS)
	require.Len(t, bindings[2].Branding.WebClient.ExtraCSS, 2)
	require.Equal(t, "1.css", bindings[2].Branding.WebClient.ExtraCSS[0])
	require.Equal(t, "2.css", bindings[2].Branding.WebClient.ExtraCSS[1])
	require.Equal(t, "httpd.crt", bindings[2].CertificateFile)
	require.Equal(t, "httpd.key", bindings[2].CertificateKeyFile)
}

// TestHTTPClientCertificatesFromEnv verifies that
// SFTPGO_HTTP__CERTIFICATES__<index>__* variables replace the HTTP client
// certificates read from a config file. Index 8 sets only CERT and is
// expected to be discarded. The same two certificates must also be present
// after the config file is removed and the configuration is re-initialized,
// because they now come from the environment alone.
func TestHTTPClientCertificatesFromEnv(t *testing.T) {
	reset()

	confName := tempConfigName + ".json"
	configFilePath := filepath.Join(configDir, confName)
	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	httpConf := config.GetHTTPConfig()
	httpConf.Certificates = append(httpConf.Certificates, httpclient.TLSKeyPair{
		Cert: "cert",
		Key:  "key",
	})
	c := make(map[string]httpclient.Config)
	c["http"] = httpConf
	jsonConf, err := json.Marshal(c)
	require.NoError(t, err)
	err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
	require.NoError(t, err)
	err = config.LoadConfig(configDir, confName)
	require.NoError(t, err)
	require.Len(t, config.GetHTTPConfig().Certificates, 1)
	require.Equal(t, "cert", config.GetHTTPConfig().Certificates[0].Cert)
	require.Equal(t, "key", config.GetHTTPConfig().Certificates[0].Key)

	t.Setenv("SFTPGO_HTTP__CERTIFICATES__0__CERT", "cert0")
	t.Setenv("SFTPGO_HTTP__CERTIFICATES__0__KEY", "key0")
	t.Setenv("SFTPGO_HTTP__CERTIFICATES__8__CERT", "cert8")
	t.Setenv("SFTPGO_HTTP__CERTIFICATES__9__CERT", "cert9")
	t.Setenv("SFTPGO_HTTP__CERTIFICATES__9__KEY", "key9")

	err = config.LoadConfig(configDir, confName)
	require.NoError(t, err)
	require.Len(t, config.GetHTTPConfig().Certificates, 2)
	require.Equal(t, "cert0", config.GetHTTPConfig().Certificates[0].Cert)
	require.Equal(t, "key0", config.GetHTTPConfig().Certificates[0].Key)
	require.Equal(t, "cert9", config.GetHTTPConfig().Certificates[1].Cert)
	require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key)

	err = os.Remove(configFilePath)
	assert.NoError(t, err)

	config.Init()
	err = config.LoadConfig(configDir, "")
	require.NoError(t, err)
	require.Len(t, config.GetHTTPConfig().Certificates, 2)
	require.Equal(t, "cert0", config.GetHTTPConfig().Certificates[0].Cert)
	require.Equal(t, "key0", config.GetHTTPConfig().Certificates[0].Key)
	require.Equal(t, "cert9", config.GetHTTPConfig().Certificates[1].Cert)
	require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key)
}

// TestHTTPClientHeadersFromEnv verifies that SFTPGO_HTTP__HEADERS__<index>__*
// variables replace the HTTP client headers read from a config file. Index 8
// sets only KEY and is expected to be discarded. The same two headers must
// also be present after the config file is removed and the configuration is
// re-initialized.
func TestHTTPClientHeadersFromEnv(t *testing.T) {
	reset()

	confName := tempConfigName + ".json"
	configFilePath := filepath.Join(configDir, confName)
	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	httpConf := config.GetHTTPConfig()
	httpConf.Headers = append(httpConf.Headers, httpclient.Header{
		Key:   "key",
		Value: "value",
		URL:   "url",
	})
	c := make(map[string]httpclient.Config)
	c["http"] = httpConf
	jsonConf, err := json.Marshal(c)
	require.NoError(t, err)
	err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
	require.NoError(t, err)
	err = config.LoadConfig(configDir, confName)
	require.NoError(t, err)
	require.Len(t, config.GetHTTPConfig().Headers, 1)
	require.Equal(t, "key", config.GetHTTPConfig().Headers[0].Key)
	require.Equal(t, "value", config.GetHTTPConfig().Headers[0].Value)
	require.Equal(t, "url", config.GetHTTPConfig().Headers[0].URL)

	t.Setenv("SFTPGO_HTTP__HEADERS__0__KEY", "key0")
	t.Setenv("SFTPGO_HTTP__HEADERS__0__VALUE", "value0")
	t.Setenv("SFTPGO_HTTP__HEADERS__0__URL", "url0")
	t.Setenv("SFTPGO_HTTP__HEADERS__8__KEY", "key8")
	t.Setenv("SFTPGO_HTTP__HEADERS__9__KEY", "key9")
	t.Setenv("SFTPGO_HTTP__HEADERS__9__VALUE", "value9")
	t.Setenv("SFTPGO_HTTP__HEADERS__9__URL", "url9")

	err = config.LoadConfig(configDir, confName)
	require.NoError(t, err)
	require.Len(t, config.GetHTTPConfig().Headers, 2)
	require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key)
	require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value)
	require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL)
	require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key)
	require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value)
	require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL)

	err = os.Remove(configFilePath)
	assert.NoError(t, err)

	config.Init()
	err = config.LoadConfig(configDir, "")
	require.NoError(t, err)
	require.Len(t, config.GetHTTPConfig().Headers, 2)
	require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key)
	require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value)
	require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL)
	require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key)
	require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value)
	require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL)
}

// TestConfigFromEnv verifies that environment variables override scalar and
// list settings across several top-level sections: sftpd, webdavd, data
// provider, kms, telemetry, httpd setup and acme.
func TestConfigFromEnv(t *testing.T) {
	reset()

	t.Setenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS", "127.0.0.1")
	t.Setenv("SFTPGO_WEBDAVD__BINDINGS__0__PORT", "12000")
	t.Setenv("SFTPGO_DATA_PROVIDER__PASSWORD_HASHING__ARGON2_OPTIONS__ITERATIONS", "41")
	t.Setenv("SFTPGO_DATA_PROVIDER__POOL_SIZE", "10")
	t.Setenv("SFTPGO_DATA_PROVIDER__IS_SHARED", "1")
	t.Setenv("SFTPGO_DATA_PROVIDER__ACTIONS__EXECUTE_ON", "add")
	t.Setenv("SFTPGO_KMS__SECRETS__URL", "local")
	t.Setenv("SFTPGO_KMS__SECRETS__MASTER_KEY_PATH", "path")
	t.Setenv("SFTPGO_TELEMETRY__TLS_CIPHER_SUITES", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA")
	t.Setenv("SFTPGO_TELEMETRY__TLS_PROTOCOLS", "h2")
	t.Setenv("SFTPGO_HTTPD__SETUP__INSTALLATION_CODE", "123")
	// Previously the cleanup unset "SFTPGO_ACME__HTTP01_CHALLENGE_PORT" (single
	// underscore), so this variable leaked; t.Setenv restores the right name.
	t.Setenv("SFTPGO_ACME__HTTP01_CHALLENGE__PORT", "5002")

	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	sftpdConfig := config.GetSFTPDConfig()
	assert.Equal(t, "127.0.0.1", sftpdConfig.Bindings[0].Address)
	assert.Equal(t, 12000, config.GetWebDAVDConfig().Bindings[0].Port)
	dataProviderConf := config.GetProviderConf()
	assert.Equal(t, uint32(41), dataProviderConf.PasswordHashing.Argon2Options.Iterations)
	assert.Equal(t, 10, dataProviderConf.PoolSize)
	assert.Equal(t, 1, dataProviderConf.IsShared)
	assert.Len(t, dataProviderConf.Actions.ExecuteOn, 1)
	assert.Contains(t, dataProviderConf.Actions.ExecuteOn, "add")
	kmsConfig := config.GetKMSConfig()
	assert.Equal(t, "local", kmsConfig.Secrets.URL)
	assert.Equal(t, "path", kmsConfig.Secrets.MasterKeyPath)
	telemetryConfig := config.GetTelemetryConfig()
	require.Len(t, telemetryConfig.TLSCipherSuites, 2)
	assert.Equal(t, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", telemetryConfig.TLSCipherSuites[0])
	assert.Equal(t, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", telemetryConfig.TLSCipherSuites[1])
	require.Len(t, telemetryConfig.Protocols, 1)
	assert.Equal(t, "h2", telemetryConfig.Protocols[0])
	assert.Equal(t, "123", config.GetHTTPDConfig().Setup.InstallationCode)
	acmeConfig := config.GetACMEConfig()
	assert.Equal(t, 5002, acmeConfig.HTTP01Challenge.Port)
}
================================================ FILE: internal/dataprovider/actions.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see .
package dataprovider import ( "bytes" "context" "fmt" "net/url" "os/exec" "path/filepath" "slices" "strings" "time" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( // ActionExecutorSelf is used as username for self action, for example a user/admin that updates itself ActionExecutorSelf = "__self__" // ActionExecutorSystem is used as username for actions with no explicit executor associated, for example // adding/updating a user/admin by loading initial data ActionExecutorSystem = "__system__" ) const ( actionObjectUser = "user" actionObjectFolder = "folder" actionObjectGroup = "group" actionObjectAdmin = "admin" actionObjectAPIKey = "api_key" actionObjectShare = "share" actionObjectEventAction = "event_action" actionObjectEventRule = "event_rule" actionObjectRole = "role" actionObjectIPListEntry = "ip_list_entry" actionObjectConfigs = "configs" ) var ( actionsConcurrencyGuard = make(chan struct{}, 100) reservedUsers = []string{ActionExecutorSelf, ActionExecutorSystem} ) func executeAction(operation, executor, ip, objectType, objectName, role string, object plugin.Renderer) { if plugin.Handler.HasNotifiers() { plugin.Handler.NotifyProviderEvent(¬ifier.ProviderEvent{ Action: operation, Username: executor, ObjectType: objectType, ObjectName: objectName, IP: ip, Role: role, Timestamp: time.Now().UnixNano(), }, object) } if fnHandleRuleForProviderEvent != nil { fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, role, object) } if config.Actions.Hook == "" { return } if !slices.Contains(config.Actions.ExecuteOn, operation) || !slices.Contains(config.Actions.ExecuteFor, objectType) { return } go func() { actionsConcurrencyGuard <- struct{}{} defer func() { <-actionsConcurrencyGuard }() dataAsJSON, err := 
object.RenderAsJSON(operation != operationDelete) if err != nil { providerLog(logger.LevelError, "unable to serialize user as JSON for operation %q: %v", operation, err) return } if strings.HasPrefix(config.Actions.Hook, "http") { var url *url.URL url, err := url.Parse(config.Actions.Hook) if err != nil { providerLog(logger.LevelError, "Invalid http_notification_url %q for operation %q: %v", config.Actions.Hook, operation, err) return } q := url.Query() q.Add("action", operation) q.Add("username", executor) q.Add("ip", ip) q.Add("object_type", objectType) q.Add("object_name", objectName) if role != "" { q.Add("role", role) } q.Add("timestamp", fmt.Sprintf("%d", time.Now().UnixNano())) url.RawQuery = q.Encode() startTime := time.Now() resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(dataAsJSON)) respCode := 0 if err == nil { respCode = resp.StatusCode resp.Body.Close() } providerLog(logger.LevelDebug, "notified operation %q to URL: %s status code: %d, elapsed: %s err: %v", operation, url.Redacted(), respCode, time.Since(startTime), err) return } executeNotificationCommand(operation, executor, ip, objectType, objectName, role, dataAsJSON) //nolint:errcheck // the error is used in test cases only }() } func executeNotificationCommand(operation, executor, ip, objectType, objectName, role string, objectAsJSON []byte) error { if !filepath.IsAbs(config.Actions.Hook) { err := fmt.Errorf("invalid notification command %q", config.Actions.Hook) logger.Warn(logSender, "", "unable to execute notification command: %v", err) return err } timeout, env, args := command.GetConfig(config.Actions.Hook, command.HookProviderActions) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() cmd := exec.CommandContext(ctx, config.Actions.Hook, args...) 
cmd.Env = append(env, fmt.Sprintf("SFTPGO_PROVIDER_ACTION=%s", operation), fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_TYPE=%s", objectType), fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_NAME=%s", objectName), fmt.Sprintf("SFTPGO_PROVIDER_USERNAME=%s", executor), fmt.Sprintf("SFTPGO_PROVIDER_IP=%s", ip), fmt.Sprintf("SFTPGO_PROVIDER_ROLE=%s", role), fmt.Sprintf("SFTPGO_PROVIDER_TIMESTAMP=%d", util.GetTimeAsMsSinceEpoch(time.Now())), fmt.Sprintf("SFTPGO_PROVIDER_OBJECT=%s", objectAsJSON)) startTime := time.Now() err := cmd.Run() providerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v", config.Actions.Hook, time.Since(startTime), err) return err } ================================================ FILE: internal/dataprovider/admin.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider import ( "encoding/json" "errors" "fmt" "net" "os" "slices" "strconv" "strings" "github.com/alexedwards/argon2id" "github.com/sftpgo/sdk" passwordvalidator "github.com/wagslane/go-password-validator" "golang.org/x/crypto/bcrypt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/util" ) // Available permissions for SFTPGo admins const ( PermAdminAny = "*" PermAdminAddUsers = "add_users" PermAdminChangeUsers = "edit_users" PermAdminDeleteUsers = "del_users" PermAdminViewUsers = "view_users" PermAdminViewConnections = "view_conns" PermAdminCloseConnections = "close_conns" PermAdminViewServerStatus = "view_status" PermAdminManageGroups = "manage_groups" PermAdminManageFolders = "manage_folders" PermAdminQuotaScans = "quota_scans" PermAdminManageDefender = "manage_defender" PermAdminViewDefender = "view_defender" PermAdminViewEvents = "view_events" PermAdminDisableMFA = "disable_mfa" ) const ( // GroupAddToUsersAsMembership defines that the admin's group will be added as membership group for new users GroupAddToUsersAsMembership = iota // GroupAddToUsersAsPrimary defines that the admin's group will be added as primary group for new users GroupAddToUsersAsPrimary // GroupAddToUsersAsSecondary defines that the admin's group will be added as secondary group for new users GroupAddToUsersAsSecondary ) var ( validAdminPerms = []string{PermAdminAny, PermAdminAddUsers, PermAdminChangeUsers, PermAdminDeleteUsers, PermAdminViewUsers, PermAdminManageFolders, PermAdminManageGroups, PermAdminViewConnections, PermAdminCloseConnections, PermAdminViewServerStatus, PermAdminQuotaScans, PermAdminManageDefender, PermAdminViewDefender, PermAdminViewEvents, PermAdminDisableMFA} forbiddenPermsForRoleAdmins = []string{PermAdminAny} ) // AdminTOTPConfig defines the time-based one time password configuration type AdminTOTPConfig struct { Enabled bool 
`json:"enabled,omitempty"` ConfigName string `json:"config_name,omitempty"` Secret *kms.Secret `json:"secret,omitempty"` } func (c *AdminTOTPConfig) validate(username string) error { if !c.Enabled { c.ConfigName = "" c.Secret = kms.NewEmptySecret() return nil } if c.ConfigName == "" { return util.NewValidationError("totp: config name is mandatory") } if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) { return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName)) } if c.Secret.IsEmpty() { return util.NewValidationError("totp: secret is mandatory") } if c.Secret.IsPlain() { c.Secret.SetAdditionalData(username) if err := c.Secret.Encrypt(); err != nil { return util.NewValidationError(fmt.Sprintf("totp: unable to encrypt secret: %v", err)) } } return nil } // AdminPreferences defines the admin preferences type AdminPreferences struct { // Allow to hide some sections from the user page. // These are not security settings and are not enforced server side // in any way. They are only intended to simplify the user page in // the WebAdmin UI. // // 1 means hide groups section // 2 means hide filesystem section, "users_base_dir" must be set in the config file otherwise this setting is ignored // 4 means hide virtual folders section // 8 means hide profile section // 16 means hide ACLs section // 32 means hide disk and bandwidth quota limits section // 64 means hide advanced settings section // // The settings can be combined HideUserPageSections int `json:"hide_user_page_sections,omitempty"` // Defines the default expiration for newly created users as number of days. 
// 0 means no expiration DefaultUsersExpiration int `json:"default_users_expiration,omitempty"` } // HideGroups returns true if the groups section should be hidden func (p *AdminPreferences) HideGroups() bool { return p.HideUserPageSections&1 != 0 } // HideFilesystem returns true if the filesystem section should be hidden func (p *AdminPreferences) HideFilesystem() bool { return config.UsersBaseDir != "" && p.HideUserPageSections&2 != 0 } // HideVirtualFolders returns true if the virtual folder section should be hidden func (p *AdminPreferences) HideVirtualFolders() bool { return p.HideUserPageSections&4 != 0 } // HideProfile returns true if the profile section should be hidden func (p *AdminPreferences) HideProfile() bool { return p.HideUserPageSections&8 != 0 } // HideACLs returns true if the ACLs section should be hidden func (p *AdminPreferences) HideACLs() bool { return p.HideUserPageSections&16 != 0 } // HideDiskQuotaAndBandwidthLimits returns true if the disk quota and bandwidth limits // section should be hidden func (p *AdminPreferences) HideDiskQuotaAndBandwidthLimits() bool { return p.HideUserPageSections&32 != 0 } // HideAdvancedSettings returns true if the advanced settings section should be hidden func (p *AdminPreferences) HideAdvancedSettings() bool { return p.HideUserPageSections&64 != 0 } // VisibleUserPageSections returns the number of visible sections // in the user page func (p *AdminPreferences) VisibleUserPageSections() int { var result int if !p.HideProfile() { result++ } if !p.HideACLs() { result++ } if !p.HideDiskQuotaAndBandwidthLimits() { result++ } if !p.HideAdvancedSettings() { result++ } return result } // AdminFilters defines additional restrictions for SFTPGo admins // TODO: rename to AdminOptions in v3 type AdminFilters struct { // only clients connecting from these IP/Mask are allowed. 
// IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291 // for example "192.0.2.0/24" or "2001:db8::/32" AllowList []string `json:"allow_list,omitempty"` // API key auth allows to impersonate this administrator with an API key AllowAPIKeyAuth bool `json:"allow_api_key_auth,omitempty"` // A password change is required at the next login RequirePasswordChange bool `json:"require_password_change,omitempty"` // Require two factor authentication RequireTwoFactor bool `json:"require_two_factor"` // Time-based one time passwords configuration TOTPConfig AdminTOTPConfig `json:"totp_config,omitempty"` // Recovery codes to use if the user loses access to their second factor auth device. // Each code can only be used once, you should use these codes to login and disable or // reset 2FA for your account RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"` Preferences AdminPreferences `json:"preferences"` } // AdminGroupMappingOptions defines the options for admin/group mapping type AdminGroupMappingOptions struct { AddToUsersAs int `json:"add_to_users_as,omitempty"` } func (o *AdminGroupMappingOptions) validate() error { if o.AddToUsersAs < GroupAddToUsersAsMembership || o.AddToUsersAs > GroupAddToUsersAsSecondary { return util.NewValidationError(fmt.Sprintf("Invalid mode to add groups to new users: %d", o.AddToUsersAs)) } return nil } // GetUserGroupType returns the type for the matching user group func (o *AdminGroupMappingOptions) GetUserGroupType() int { switch o.AddToUsersAs { case GroupAddToUsersAsPrimary: return sdk.GroupTypePrimary case GroupAddToUsersAsSecondary: return sdk.GroupTypeSecondary default: return sdk.GroupTypeMembership } } // AdminGroupMapping defines the mapping between an SFTPGo admin and a group type AdminGroupMapping struct { Name string `json:"name"` Options AdminGroupMappingOptions `json:"options"` } // Admin defines a SFTPGo admin type Admin struct { // Database unique identifier ID int64 `json:"id"` // 1 enabled, 0 
disabled (login is not allowed) Status int `json:"status"` // Username Username string `json:"username"` Password string `json:"password,omitempty"` Email string `json:"email,omitempty"` Permissions []string `json:"permissions"` Filters AdminFilters `json:"filters,omitempty"` Description string `json:"description,omitempty"` AdditionalInfo string `json:"additional_info,omitempty"` // Groups membership Groups []AdminGroupMapping `json:"groups,omitempty"` // Creation time as unix timestamp in milliseconds. It will be 0 for admins created before v2.2.0 CreatedAt int64 `json:"created_at"` // last update time as unix timestamp in milliseconds UpdatedAt int64 `json:"updated_at"` // Last login as unix timestamp in milliseconds LastLogin int64 `json:"last_login"` // Role name. If set the admin can only administer users with the same role. // Role admins cannot be super administrators Role string `json:"role,omitempty"` } // CountUnusedRecoveryCodes returns the number of unused recovery codes func (a *Admin) CountUnusedRecoveryCodes() int { unused := 0 for _, code := range a.Filters.RecoveryCodes { if !code.Used { unused++ } } return unused } func (a *Admin) hashPassword() error { if a.Password != "" && !util.IsStringPrefixInSlice(a.Password, internalHashPwdPrefixes) { if config.PasswordValidation.Admins.MinEntropy > 0 { if err := passwordvalidator.Validate(a.Password, config.PasswordValidation.Admins.MinEntropy); err != nil { return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) } } if config.PasswordHashing.Algo == HashingAlgoBcrypt { pwd, err := bcrypt.GenerateFromPassword([]byte(a.Password), config.PasswordHashing.BcryptOptions.Cost) if err != nil { return err } a.Password = util.BytesToString(pwd) } else { pwd, err := argon2id.CreateHash(a.Password, argon2Params) if err != nil { return err } a.Password = pwd } } return nil } func (a *Admin) hasRedactedSecret() bool { return a.Filters.TOTPConfig.Secret.IsRedacted() } func (a 
*Admin) validateRecoveryCodes() error { for i := 0; i < len(a.Filters.RecoveryCodes); i++ { code := &a.Filters.RecoveryCodes[i] if code.Secret.IsEmpty() { return util.NewValidationError("mfa: recovery code cannot be empty") } if code.Secret.IsPlain() { code.Secret.SetAdditionalData(a.Username) if err := code.Secret.Encrypt(); err != nil { return util.NewValidationError(fmt.Sprintf("mfa: unable to encrypt recovery code: %v", err)) } } } return nil } func (a *Admin) validatePermissions() error { a.Permissions = util.RemoveDuplicates(a.Permissions, false) if len(a.Permissions) == 0 { return util.NewI18nError( util.NewValidationError("please grant some permissions to this admin"), util.I18nErrorPermissionsRequired, ) } if slices.Contains(a.Permissions, PermAdminAny) { a.Permissions = []string{PermAdminAny} } for _, perm := range a.Permissions { if !slices.Contains(validAdminPerms, perm) { return util.NewValidationError(fmt.Sprintf("invalid permission: %q", perm)) } if a.Role != "" { if slices.Contains(forbiddenPermsForRoleAdmins, perm) { return util.NewI18nError( util.NewValidationError("a role admin cannot be a super admin"), util.I18nErrorRoleAdminPerms, ) } } } return nil } func (a *Admin) validateGroups() error { hasPrimary := false for _, g := range a.Groups { if g.Name == "" { return util.NewValidationError("group name is mandatory") } if err := g.Options.validate(); err != nil { return err } if g.Options.AddToUsersAs == GroupAddToUsersAsPrimary { if hasPrimary { return util.NewI18nError( util.NewValidationError("only one primary group is allowed"), util.I18nErrorPrimaryGroup, ) } hasPrimary = true } } return nil } func (a *Admin) applyNamingRules() { a.Username = config.convertName(a.Username) a.Role = config.convertName(a.Role) for idx := range a.Groups { a.Groups[idx].Name = config.convertName(a.Groups[idx].Name) } } func (a *Admin) validate() error { //nolint:gocyclo a.SetEmptySecretsIfNil() a.applyNamingRules() a.Password = strings.TrimSpace(a.Password) if 
a.Username == "" { return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) } if !util.IsNameValid(a.Username) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if err := checkReservedUsernames(a.Username); err != nil { return util.NewI18nError(err, util.I18nErrorReservedUsername) } if a.Password == "" { return util.NewI18nError(util.NewValidationError("please set a password"), util.I18nErrorPasswordRequired) } if a.hasRedactedSecret() { return util.NewValidationError("cannot save an admin with a redacted secret") } if err := a.Filters.TOTPConfig.validate(a.Username); err != nil { return util.NewI18nError(err, util.I18nError2FAInvalid) } if err := a.validateRecoveryCodes(); err != nil { return util.NewI18nError(err, util.I18nErrorRecoveryCodesInvalid) } if config.NamingRules&1 == 0 && !usernameRegex.MatchString(a.Username) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("username %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", a.Username)), util.I18nErrorInvalidUser, ) } if err := a.hashPassword(); err != nil { return err } if err := a.validatePermissions(); err != nil { return err } if a.Email != "" && !util.IsEmailValid(a.Email) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("email %q is not valid", a.Email)), util.I18nErrorInvalidEmail, ) } a.Filters.AllowList = util.RemoveDuplicates(a.Filters.AllowList, false) for _, IPMask := range a.Filters.AllowList { _, _, err := net.ParseCIDR(IPMask) if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not parse allow list entry %q : %v", IPMask, err)), util.I18nErrorInvalidIPMask, ) } } return a.validateGroups() } // CheckPassword verifies the admin password func (a *Admin) CheckPassword(password string) (bool, error) { if config.PasswordCaching { found, match := cachedAdminPasswords.Check(a.Username, password, a.Password) if found { if !match { return false, 
ErrInvalidCredentials } return match, nil } } if strings.HasPrefix(a.Password, bcryptPwdPrefix) { if err := bcrypt.CompareHashAndPassword([]byte(a.Password), []byte(password)); err != nil { return false, ErrInvalidCredentials } cachedAdminPasswords.Add(a.Username, password, a.Password) return true, nil } match, err := argon2id.ComparePasswordAndHash(password, a.Password) if !match || err != nil { return false, ErrInvalidCredentials } if match { cachedAdminPasswords.Add(a.Username, password, a.Password) } return match, err } // CanLoginFromIP returns true if login from the given IP is allowed func (a *Admin) CanLoginFromIP(ip string) bool { if len(a.Filters.AllowList) == 0 { return true } parsedIP := net.ParseIP(ip) if parsedIP == nil { return len(a.Filters.AllowList) == 0 } for _, ipMask := range a.Filters.AllowList { _, network, err := net.ParseCIDR(ipMask) if err != nil { continue } if network.Contains(parsedIP) { return true } } return false } // CanLogin returns an error if the login is not allowed func (a *Admin) CanLogin(ip string) error { if a.Status != 1 { return fmt.Errorf("admin %q is disabled", a.Username) } if !a.CanLoginFromIP(ip) { return fmt.Errorf("login from IP %v not allowed", ip) } return nil } func (a *Admin) checkUserAndPass(password, ip string) error { if err := a.CanLogin(ip); err != nil { return err } if a.Password == "" || strings.TrimSpace(password) == "" { return errors.New("credentials cannot be null or empty") } match, err := a.CheckPassword(password) if err != nil { return err } if !match { return ErrInvalidCredentials } return nil } // RenderAsJSON implements the renderer interface used within plugins func (a *Admin) RenderAsJSON(reload bool) ([]byte, error) { if reload { admin, err := provider.adminExists(a.Username) if err != nil { providerLog(logger.LevelError, "unable to reload admin before rendering as json: %v", err) return nil, err } admin.HideConfidentialData() return json.Marshal(admin) } a.HideConfidentialData() return 
json.Marshal(a) } // HideConfidentialData hides admin confidential data func (a *Admin) HideConfidentialData() { a.Password = "" if a.Filters.TOTPConfig.Secret != nil { a.Filters.TOTPConfig.Secret.Hide() } for _, code := range a.Filters.RecoveryCodes { if code.Secret != nil { code.Secret.Hide() } } a.SetNilSecretsIfEmpty() } // SetEmptySecretsIfNil sets the secrets to empty if nil func (a *Admin) SetEmptySecretsIfNil() { if a.Filters.TOTPConfig.Secret == nil { a.Filters.TOTPConfig.Secret = kms.NewEmptySecret() } } // SetNilSecretsIfEmpty set the secrets to nil if empty. // This is useful before rendering as JSON so the empty fields // will not be serialized. func (a *Admin) SetNilSecretsIfEmpty() { if a.Filters.TOTPConfig.Secret != nil && a.Filters.TOTPConfig.Secret.IsEmpty() { a.Filters.TOTPConfig.Secret = nil } } // HasPermission returns true if the admin has the specified permission func (a *Admin) HasPermission(perm string) bool { if slices.Contains(a.Permissions, PermAdminAny) { return true } return slices.Contains(a.Permissions, perm) } // HasPermissions returns true if the admin has all the specified permissions func (a *Admin) HasPermissions(perms ...string) bool { for _, perm := range perms { if !a.HasPermission(perm) { return false } } return len(perms) > 0 } // GetAllowedIPAsString returns the allowed IP as comma separated string func (a *Admin) GetAllowedIPAsString() string { return strings.Join(a.Filters.AllowList, ",") } // GetValidPerms returns the allowed admin permissions func (a *Admin) GetValidPerms() []string { return validAdminPerms } // CanManageMFA returns true if the admin can add a multi-factor authentication configuration func (a *Admin) CanManageMFA() bool { return len(mfa.GetAvailableTOTPConfigs()) > 0 } // GetSignature returns a signature for this admin. 
// It will change after an update func (a *Admin) GetSignature() string { return strconv.FormatInt(a.UpdatedAt, 10) } func (a *Admin) getACopy() Admin { a.SetEmptySecretsIfNil() permissions := make([]string, len(a.Permissions)) copy(permissions, a.Permissions) filters := AdminFilters{} filters.AllowList = make([]string, len(a.Filters.AllowList)) filters.AllowAPIKeyAuth = a.Filters.AllowAPIKeyAuth filters.RequirePasswordChange = a.Filters.RequirePasswordChange filters.RequireTwoFactor = a.Filters.RequireTwoFactor filters.TOTPConfig.Enabled = a.Filters.TOTPConfig.Enabled filters.TOTPConfig.ConfigName = a.Filters.TOTPConfig.ConfigName filters.TOTPConfig.Secret = a.Filters.TOTPConfig.Secret.Clone() copy(filters.AllowList, a.Filters.AllowList) filters.RecoveryCodes = make([]RecoveryCode, 0) for _, code := range a.Filters.RecoveryCodes { if code.Secret == nil { code.Secret = kms.NewEmptySecret() } filters.RecoveryCodes = append(filters.RecoveryCodes, RecoveryCode{ Secret: code.Secret.Clone(), Used: code.Used, }) } filters.Preferences = AdminPreferences{ HideUserPageSections: a.Filters.Preferences.HideUserPageSections, DefaultUsersExpiration: a.Filters.Preferences.DefaultUsersExpiration, } groups := make([]AdminGroupMapping, 0, len(a.Groups)) for _, g := range a.Groups { groups = append(groups, AdminGroupMapping{ Name: g.Name, Options: AdminGroupMappingOptions{ AddToUsersAs: g.Options.AddToUsersAs, }, }) } return Admin{ ID: a.ID, Status: a.Status, Username: a.Username, Password: a.Password, Email: a.Email, Permissions: permissions, Groups: groups, Filters: filters, AdditionalInfo: a.AdditionalInfo, Description: a.Description, LastLogin: a.LastLogin, CreatedAt: a.CreatedAt, UpdatedAt: a.UpdatedAt, Role: a.Role, } } func (a *Admin) setFromEnv() error { envUsername := strings.TrimSpace(os.Getenv("SFTPGO_DEFAULT_ADMIN_USERNAME")) envPassword := strings.TrimSpace(os.Getenv("SFTPGO_DEFAULT_ADMIN_PASSWORD")) if envUsername == "" || envPassword == "" { return errors.New(`to 
create the default admin you need to set the env vars "SFTPGO_DEFAULT_ADMIN_USERNAME" and "SFTPGO_DEFAULT_ADMIN_PASSWORD"`) } a.Username = envUsername a.Password = envPassword a.Status = 1 a.Permissions = []string{PermAdminAny} return nil } ================================================ FILE: internal/dataprovider/apikey.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "encoding/json" "fmt" "strings" "time" "github.com/alexedwards/argon2id" "golang.org/x/crypto/bcrypt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // APIKeyScope defines the supported API key scopes type APIKeyScope int // Supported API key scopes const ( // the API key will be used for an admin APIKeyScopeAdmin APIKeyScope = iota + 1 // the API key will be used for a user APIKeyScopeUser ) // APIKey defines a SFTPGo API key. // API keys can be used as authentication alternative to short lived tokens // for REST API type APIKey struct { // Database unique identifier ID int64 `json:"-"` // Unique key identifier, used for key lookups. 
// The generated key is in the format `KeyID.hash(Key)` so we can split // and lookup by KeyID and then verify if the key matches the recorded hash KeyID string `json:"id"` // User friendly key name Name string `json:"name"` // we store the hash of the key, this is just like a password Key string `json:"key,omitempty"` Scope APIKeyScope `json:"scope"` CreatedAt int64 `json:"created_at"` UpdatedAt int64 `json:"updated_at"` // 0 means never used LastUseAt int64 `json:"last_use_at,omitempty"` // 0 means never expire ExpiresAt int64 `json:"expires_at,omitempty"` Description string `json:"description,omitempty"` // Username associated with this API key. // If empty and the scope is APIKeyScopeUser the key is valid for any user User string `json:"user,omitempty"` // Admin username associated with this API key. // If empty and the scope is APIKeyScopeAdmin the key is valid for any admin Admin string `json:"admin,omitempty"` // these fields are for internal use userID int64 adminID int64 plainKey string } func (k *APIKey) getACopy() APIKey { return APIKey{ ID: k.ID, KeyID: k.KeyID, Name: k.Name, Key: k.Key, Scope: k.Scope, CreatedAt: k.CreatedAt, UpdatedAt: k.UpdatedAt, LastUseAt: k.LastUseAt, ExpiresAt: k.ExpiresAt, Description: k.Description, User: k.User, Admin: k.Admin, userID: k.userID, adminID: k.adminID, } } // RenderAsJSON implements the renderer interface used within plugins func (k *APIKey) RenderAsJSON(reload bool) ([]byte, error) { if reload { apiKey, err := provider.apiKeyExists(k.KeyID) if err != nil { providerLog(logger.LevelError, "unable to reload api key before rendering as json: %v", err) return nil, err } apiKey.HideConfidentialData() return json.Marshal(apiKey) } k.HideConfidentialData() return json.Marshal(k) } // HideConfidentialData hides API key confidential data func (k *APIKey) HideConfidentialData() { k.Key = "" } func (k *APIKey) hashKey() error { if k.Key != "" && !util.IsStringPrefixInSlice(k.Key, internalHashPwdPrefixes) { if 
config.PasswordHashing.Algo == HashingAlgoBcrypt { hashed, err := bcrypt.GenerateFromPassword([]byte(k.Key), config.PasswordHashing.BcryptOptions.Cost) if err != nil { return err } k.Key = util.BytesToString(hashed) } else { hashed, err := argon2id.CreateHash(k.Key, argon2Params) if err != nil { return err } k.Key = hashed } } return nil } func (k *APIKey) generateKey() { if k.KeyID != "" || k.Key != "" { return } k.KeyID = util.GenerateUniqueID() k.Key = util.GenerateUniqueID() k.plainKey = k.Key } // DisplayKey returns the key to show to the user func (k *APIKey) DisplayKey() string { return fmt.Sprintf("%v.%v", k.KeyID, k.plainKey) } func (k *APIKey) validate() error { if k.Name == "" { return util.NewValidationError("name is mandatory") } if !util.IsNameValid(k.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if k.Scope != APIKeyScopeAdmin && k.Scope != APIKeyScopeUser { return util.NewValidationError(fmt.Sprintf("invalid scope: %v", k.Scope)) } k.generateKey() if err := k.hashKey(); err != nil { return err } if k.User != "" && k.Admin != "" { return util.NewValidationError("an API key can be related to a user or an admin, not both") } if k.Scope == APIKeyScopeAdmin { k.User = "" } if k.Scope == APIKeyScopeUser { k.Admin = "" } if k.User != "" { _, err := provider.userExists(k.User, "") if err != nil { return util.NewValidationError(fmt.Sprintf("unable to check API key user %v: %v", k.User, err)) } } if k.Admin != "" { _, err := provider.adminExists(k.Admin) if err != nil { return util.NewValidationError(fmt.Sprintf("unable to check API key admin %v: %v", k.Admin, err)) } } return nil } // Authenticate tries to authenticate the provided plain key func (k *APIKey) Authenticate(plainKey string) error { if k.ExpiresAt > 0 && k.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) { return fmt.Errorf("API key %q is expired, expiration timestamp: %v current timestamp: %v", k.KeyID, k.ExpiresAt, util.GetTimeAsMsSinceEpoch(time.Now())) } if 
config.PasswordCaching { found, match := cachedAPIKeys.Check(k.KeyID, plainKey, k.Key) if found { if !match { return ErrInvalidCredentials } return nil } } if strings.HasPrefix(k.Key, bcryptPwdPrefix) { if err := bcrypt.CompareHashAndPassword([]byte(k.Key), []byte(plainKey)); err != nil { return ErrInvalidCredentials } } else if strings.HasPrefix(k.Key, argonPwdPrefix) { match, err := argon2id.ComparePasswordAndHash(plainKey, k.Key) if err != nil || !match { return ErrInvalidCredentials } } cachedAPIKeys.Add(k.KeyID, plainKey, k.Key) return nil } ================================================ FILE: internal/dataprovider/bolt.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !nobolt package dataprovider import ( "bytes" "crypto/x509" "encoding/json" "errors" "fmt" "net/netip" "path/filepath" "slices" "sort" "strconv" "time" bolt "go.etcd.io/bbolt" bolterrors "go.etcd.io/bbolt/errors" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( boltDatabaseVersion = 34 ) var ( usersBucket = []byte("users") groupsBucket = []byte("groups") foldersBucket = []byte("folders") adminsBucket = []byte("admins") apiKeysBucket = []byte("api_keys") sharesBucket = []byte("shares") actionsBucket = []byte("events_actions") rulesBucket = []byte("events_rules") rolesBucket = []byte("roles") ipListsBucket = []byte("ip_lists") configsBucket = []byte("configs") dbVersionBucket = []byte("db_version") dbVersionKey = []byte("version") configsKey = []byte("configs") boltBuckets = [][]byte{usersBucket, groupsBucket, foldersBucket, adminsBucket, apiKeysBucket, sharesBucket, actionsBucket, rulesBucket, rolesBucket, ipListsBucket, configsBucket, dbVersionBucket} ) // BoltProvider defines the auth provider for bolt key/value store type BoltProvider struct { dbHandle *bolt.DB } func init() { version.AddFeature("+bolt") } func initializeBoltProvider(basePath string) error { var err error dbPath := config.Name if !util.IsFileInputValid(dbPath) { return fmt.Errorf("invalid database path: %q", dbPath) } if !filepath.IsAbs(dbPath) { dbPath = filepath.Join(basePath, dbPath) } dbHandle, err := bolt.Open(dbPath, 0600, &bolt.Options{ NoGrowSync: false, FreelistType: bolt.FreelistArrayType, Timeout: 5 * time.Second}) if err == nil { providerLog(logger.LevelDebug, "bolt key store handle created") for _, bucket := range boltBuckets { if err := dbHandle.Update(func(tx *bolt.Tx) error { _, e := tx.CreateBucketIfNotExists(bucket) return e }); err != nil { providerLog(logger.LevelError, "error creating bucket %q: %v", string(bucket), err) } 
} provider = &BoltProvider{dbHandle: dbHandle} } else { providerLog(logger.LevelError, "error creating bolt key/value store handler: %v", err) } return err } func (p *BoltProvider) checkAvailability() error { _, err := getBoltDatabaseVersion(p.dbHandle) return err } func (p *BoltProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { var user User if tlsCert == nil { return user, errors.New("TLS certificate cannot be null or empty") } user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) return user, err } return checkUserAndTLSCertificate(&user, protocol, tlsCert) } func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) return user, err } return checkUserAndPass(&user, password, ip, protocol) } func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { admin, err := p.adminExists(username) if err != nil { providerLog(logger.LevelWarn, "error authenticating admin %q: %v", username, err) return admin, err } err = admin.checkUserAndPass(password, ip) return admin, err } func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte, isSSHCert bool) (User, string, error) { var user User if len(pubKey) == 0 { return user, "", errors.New("credentials cannot be null or empty") } user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) return user, "", err } return checkUserAndPubKey(&user, pubKey, isSSHCert) } func (p *BoltProvider) updateAPIKeyLastUse(keyID string) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getAPIKeysBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(keyID)); u == 
nil { return util.NewRecordNotFoundError(fmt.Sprintf("key %q does not exist, unable to update last use", keyID)) } var apiKey APIKey err = json.Unmarshal(u, &apiKey) if err != nil { return err } apiKey.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(apiKey) if err != nil { return err } err = bucket.Put([]byte(keyID), buf) if err != nil { providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err) return err } providerLog(logger.LevelDebug, "last use updated for key %q", keyID) return nil }) } func (p *BoltProvider) getAdminSignature(username string) (string, error) { var updatedAt int64 err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } u := bucket.Get([]byte(username)) var admin Admin err = json.Unmarshal(u, &admin) if err != nil { return err } updatedAt = admin.UpdatedAt return nil }) if err != nil { return "", err } return strconv.FormatInt(updatedAt, 10), nil } func (p *BoltProvider) getUserSignature(username string) (string, error) { var updatedAt int64 err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } u := bucket.Get([]byte(username)) var user User err = json.Unmarshal(u, &user) if err != nil { return err } updatedAt = user.UpdatedAt return nil }) if err != nil { return "", err } return strconv.FormatInt(updatedAt, 10), nil } func (p *BoltProvider) setUpdatedAt(username string) { p.dbHandle.Update(func(tx *bolt.Tx) error { //nolint:errcheck bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update updated at", username)) } var user User err = json.Unmarshal(u, &user) if err != nil { return err } user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } err = 
bucket.Put([]byte(username), buf) if err == nil { providerLog(logger.LevelDebug, "updated at set for user %q", username) setLastUserUpdate() } else { providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err) } return err }) } func (p *BoltProvider) updateLastLogin(username string) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update last login", username)) } var user User err = json.Unmarshal(u, &user) if err != nil { return err } user.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } err = bucket.Put([]byte(username), buf) if err != nil { providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err) } else { providerLog(logger.LevelDebug, "last login updated for user %q", username) } return err }) } func (p *BoltProvider) updateAdminLastLogin(username string) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } var a []byte if a = bucket.Get([]byte(username)); a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("admin %q does not exist, unable to update last login", username)) } var admin Admin err = json.Unmarshal(a, &admin) if err != nil { return err } admin.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(admin) if err != nil { return err } err = bucket.Put([]byte(username), buf) if err == nil { providerLog(logger.LevelDebug, "last login updated for admin %q", username) return err } providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err) return err }) } func (p *BoltProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { return 
p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update transfer quota", username)) } var user User err = json.Unmarshal(u, &user) if err != nil { return err } if !reset { user.UsedUploadDataTransfer += uploadSize user.UsedDownloadDataTransfer += downloadSize } else { user.UsedUploadDataTransfer = uploadSize user.UsedDownloadDataTransfer = downloadSize } user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } err = bucket.Put([]byte(username), buf) providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %v dl increment: %v is reset? %v", username, uploadSize, downloadSize, reset) return err }) } func (p *BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update quota", username)) } var user User err = json.Unmarshal(u, &user) if err != nil { return err } if reset { user.UsedQuotaSize = sizeAdd user.UsedQuotaFiles = filesAdd } else { user.UsedQuotaSize += sizeAdd user.UsedQuotaFiles += filesAdd } user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } err = bucket.Put([]byte(username), buf) providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %v size increment: %v is reset? 
%v", username, filesAdd, sizeAdd, reset) return err }) } func (p *BoltProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelError, "unable to get quota for user %v error: %v", username, err) return 0, 0, 0, 0, err } return user.UsedQuotaFiles, user.UsedQuotaSize, user.UsedUploadDataTransfer, user.UsedDownloadDataTransfer, err } func (p *BoltProvider) adminExists(username string) (Admin, error) { var admin Admin err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } a := bucket.Get([]byte(username)) if a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", username)) } return json.Unmarshal(a, &admin) }) return admin, err } func (p *BoltProvider) addAdmin(admin *Admin) error { err := admin.validate() if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } groupBucket, err := p.getGroupsBucket(tx) if err != nil { return err } rolesBucket, err := p.getRolesBucket(tx) if err != nil { return err } if a := bucket.Get([]byte(admin.Username)); a != nil { return util.NewI18nError( fmt.Errorf("%w: admin %q already exists", ErrDuplicatedKey, admin.Username), util.I18nErrorDuplicatedUsername, ) } id, err := bucket.NextSequence() if err != nil { return err } admin.ID = int64(id) admin.LastLogin = 0 admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) for idx := range admin.Groups { err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name, groupBucket) if err != nil { return err } } if err = p.addAdminToRole(admin.Username, admin.Role, rolesBucket); err != nil { return err } buf, err := json.Marshal(admin) if err != nil { return err } return bucket.Put([]byte(admin.Username), buf) }) } func (p *BoltProvider) updateAdmin(admin 
*Admin) error { err := admin.validate() if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } groupBucket, err := p.getGroupsBucket(tx) if err != nil { return err } rolesBucket, err := p.getRolesBucket(tx) if err != nil { return err } var a []byte if a = bucket.Get([]byte(admin.Username)); a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", admin.Username)) } var oldAdmin Admin err = json.Unmarshal(a, &oldAdmin) if err != nil { return err } if err = p.removeAdminFromRole(oldAdmin.Username, oldAdmin.Role, rolesBucket); err != nil { return err } for idx := range oldAdmin.Groups { err = p.removeAdminFromGroupMapping(oldAdmin.Username, oldAdmin.Groups[idx].Name, groupBucket) if err != nil { return err } } if err = p.addAdminToRole(admin.Username, admin.Role, rolesBucket); err != nil { return err } for idx := range admin.Groups { err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name, groupBucket) if err != nil { return err } } admin.ID = oldAdmin.ID admin.CreatedAt = oldAdmin.CreatedAt admin.LastLogin = oldAdmin.LastLogin admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(admin) if err != nil { return err } return bucket.Put([]byte(admin.Username), buf) }) } func (p *BoltProvider) deleteAdmin(admin Admin) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } var a []byte if a = bucket.Get([]byte(admin.Username)); a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", admin.Username)) } var oldAdmin Admin err = json.Unmarshal(a, &oldAdmin) if err != nil { return err } if len(oldAdmin.Groups) > 0 { groupBucket, err := p.getGroupsBucket(tx) if err != nil { return err } for idx := range oldAdmin.Groups { err = p.removeAdminFromGroupMapping(oldAdmin.Username, oldAdmin.Groups[idx].Name, groupBucket) if err 
!= nil { return err } } } if oldAdmin.Role != "" { rolesBucket, err := p.getRolesBucket(tx) if err != nil { return err } if err = p.removeAdminFromRole(oldAdmin.Username, oldAdmin.Role, rolesBucket); err != nil { return err } } if err := p.deleteRelatedAPIKey(tx, admin.Username, APIKeyScopeAdmin); err != nil { return err } return bucket.Delete([]byte(admin.Username)) }) } func (p *BoltProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) { admins := make([]Admin, 0, limit) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } cursor := bucket.Cursor() itNum := 0 if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } var admin Admin err = json.Unmarshal(v, &admin) if err != nil { return err } admin.HideConfidentialData() admins = append(admins, admin) if len(admins) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } var admin Admin err = json.Unmarshal(v, &admin) if err != nil { return err } admin.HideConfidentialData() admins = append(admins, admin) if len(admins) >= limit { break } } } return err }) return admins, err } func (p *BoltProvider) dumpAdmins() ([]Admin, error) { admins := make([]Admin, 0, 30) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var admin Admin err = json.Unmarshal(v, &admin) if err != nil { return err } admins = append(admins, admin) } return err }) return admins, err } func (p *BoltProvider) userExists(username, role string) (User, error) { var user User err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } u := bucket.Get([]byte(username)) if u == nil { return 
util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } user, err = p.joinUserAndFolders(u, foldersBucket) if err != nil { return err } if !user.hasRole(role) { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) } return nil }) return user, err } func (p *BoltProvider) addUser(user *User) error { err := ValidateUser(user) if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } groupBucket, err := p.getGroupsBucket(tx) if err != nil { return err } rolesBucket, err := p.getRolesBucket(tx) if err != nil { return err } if u := bucket.Get([]byte(user.Username)); u != nil { return util.NewI18nError( fmt.Errorf("%w: username %v already exists", ErrDuplicatedKey, user.Username), util.I18nErrorDuplicatedUsername, ) } id, err := bucket.NextSequence() if err != nil { return err } user.ID = int64(id) user.LastQuotaUpdate = 0 user.UsedQuotaSize = 0 user.UsedQuotaFiles = 0 user.UsedUploadDataTransfer = 0 user.UsedDownloadDataTransfer = 0 user.LastLogin = 0 user.FirstDownload = 0 user.FirstUpload = 0 user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) if err := p.addUserToRole(user.Username, user.Role, rolesBucket); err != nil { return err } for idx := range user.VirtualFolders { err = p.addRelationToFolderMapping(user.VirtualFolders[idx].Name, user, nil, foldersBucket) if err != nil { return err } } for idx := range user.Groups { err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name, groupBucket) if err != nil { return err } } buf, err := json.Marshal(user) if err != nil { return err } return bucket.Put([]byte(user.Username), buf) }) } func (p *BoltProvider) updateUser(user *User) error { err := 
ValidateUser(user) if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(user.Username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", user.Username)) } var oldUser User err = json.Unmarshal(u, &oldUser) if err != nil { return err } if err = p.updateUserRelations(tx, user, oldUser); err != nil { return err } user.ID = oldUser.ID user.LastQuotaUpdate = oldUser.LastQuotaUpdate user.UsedQuotaSize = oldUser.UsedQuotaSize user.UsedQuotaFiles = oldUser.UsedQuotaFiles user.UsedUploadDataTransfer = oldUser.UsedUploadDataTransfer user.UsedDownloadDataTransfer = oldUser.UsedDownloadDataTransfer user.LastLogin = oldUser.LastLogin user.FirstDownload = oldUser.FirstDownload user.FirstUpload = oldUser.FirstUpload user.CreatedAt = oldUser.CreatedAt user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } err = bucket.Put([]byte(user.Username), buf) if err == nil { setLastUserUpdate() } return err }) } func (p *BoltProvider) deleteUser(user User, _ bool) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } groupBucket, err := p.getGroupsBucket(tx) if err != nil { return err } rolesBucket, err := p.getRolesBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(user.Username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", user.Username)) } var oldUser User err = json.Unmarshal(u, &oldUser) if err != nil { return err } if err := p.removeUserFromRole(oldUser.Username, oldUser.Role, rolesBucket); err != nil { return err } for idx := range oldUser.VirtualFolders { err = p.removeRelationFromFolderMapping(oldUser.VirtualFolders[idx], oldUser.Username, "", 
foldersBucket) if err != nil { return err } } for idx := range oldUser.Groups { err = p.removeUserFromGroupMapping(oldUser.Username, oldUser.Groups[idx].Name, groupBucket) if err != nil { return err } } if err := p.deleteRelatedAPIKey(tx, user.Username, APIKeyScopeUser); err != nil { return err } if err := p.deleteRelatedShares(tx, user.Username); err != nil { return err } return bucket.Delete([]byte(user.Username)) }) } func (p *BoltProvider) updateUserPassword(username, password string) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) } var user User err = json.Unmarshal(u, &user) if err != nil { return err } user.Password = password user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } return bucket.Put([]byte(username), buf) }) } func (p *BoltProvider) dumpUsers() ([]User, error) { users := make([]User, 0, 100) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { user, err := p.joinUserAndFolders(v, foldersBucket) if err != nil { return err } users = append(users, user) } return err }) return users, err } func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { if getLastUserUpdate() < after { return nil, nil } users := make([]User, 0, 10) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } groupsBucket, err := p.getGroupsBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := 
cursor.First(); k != nil; k, v = cursor.Next() { var user User err := json.Unmarshal(v, &user) if err != nil { return err } if user.UpdatedAt < after { continue } if len(user.VirtualFolders) > 0 { var folders []vfs.VirtualFolder for idx := range user.VirtualFolders { folder := &user.VirtualFolders[idx] baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) if err != nil { continue } folder.BaseVirtualFolder = baseFolder folders = append(folders, *folder) } user.VirtualFolders = folders } if len(user.Groups) > 0 { groupMapping := make(map[string]Group) for idx := range user.Groups { group, err := p.groupExistsInternal(user.Groups[idx].Name, groupsBucket) if err != nil { continue } groupMapping[group.Name] = group } user.applyGroupSettings(groupMapping) } user.SetEmptySecretsIfNil() users = append(users, user) } return err }) return users, err } func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { users := make([]User, 0, 10) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } groupsBucket, err := p.getGroupsBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var user User err := json.Unmarshal(v, &user) if err != nil { return err } if needFolders, ok := toFetch[user.Username]; ok { if needFolders && len(user.VirtualFolders) > 0 { var folders []vfs.VirtualFolder for idx := range user.VirtualFolders { folder := &user.VirtualFolders[idx] baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) if err != nil { continue } folder.BaseVirtualFolder = baseFolder folders = append(folders, *folder) } user.VirtualFolders = folders } if len(user.Groups) > 0 { groupMapping := make(map[string]Group) for idx := range user.Groups { group, err := p.groupExistsInternal(user.Groups[idx].Name, groupsBucket) if err != 
nil { continue } groupMapping[group.Name] = group } user.applyGroupSettings(groupMapping) } user.SetEmptySecretsIfNil() user.PrepareForRendering() users = append(users, user) } } return nil }) return users, err } func (p *BoltProvider) getUsers(limit int, offset int, order, role string) ([]User, error) { users := make([]User, 0, limit) var err error if limit <= 0 { return users, err } err = p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } cursor := bucket.Cursor() itNum := 0 if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } user, err := p.joinUserAndFolders(v, foldersBucket) if err != nil { return err } if !user.hasRole(role) { continue } user.PrepareForRendering() users = append(users, user) if len(users) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } user, err := p.joinUserAndFolders(v, foldersBucket) if err != nil { return err } if !user.hasRole(role) { continue } user.PrepareForRendering() users = append(users, user) if len(users) >= limit { break } } } return err }) return users, err } func (p *BoltProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { folders := make([]vfs.BaseVirtualFolder, 0, 50) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var folder vfs.BaseVirtualFolder err = json.Unmarshal(v, &folder) if err != nil { return err } folders = append(folders, folder) } return err }) return folders, err } func (p *BoltProvider) getFolders(limit, offset int, order string, _ bool) ([]vfs.BaseVirtualFolder, error) { folders := make([]vfs.BaseVirtualFolder, 0, limit) var err error if limit <= 0 { return folders, err } 
err = p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } cursor := bucket.Cursor() itNum := 0 if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } var folder vfs.BaseVirtualFolder err = json.Unmarshal(v, &folder) if err != nil { return err } folder.PrepareForRendering() folders = append(folders, folder) if len(folders) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } var folder vfs.BaseVirtualFolder err = json.Unmarshal(v, &folder) if err != nil { return err } folder.PrepareForRendering() folders = append(folders, folder) if len(folders) >= limit { break } } } return err }) return folders, err } func (p *BoltProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) { var folder vfs.BaseVirtualFolder err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } folder, err = p.folderExistsInternal(name, bucket) return err }) return folder, err } func (p *BoltProvider) addFolder(folder *vfs.BaseVirtualFolder) error { err := ValidateFolder(folder) if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } if f := bucket.Get([]byte(folder.Name)); f != nil { return util.NewI18nError( fmt.Errorf("%w: folder %q already exists", ErrDuplicatedKey, folder.Name), util.I18nErrorDuplicatedUsername, ) } folder.Users = nil folder.Groups = nil return p.addFolderInternal(*folder, bucket) }) } func (p *BoltProvider) updateFolder(folder *vfs.BaseVirtualFolder) error { err := ValidateFolder(folder) if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } var f []byte if f = bucket.Get([]byte(folder.Name)); f == nil { return 
util.NewRecordNotFoundError(fmt.Sprintf("folder %v does not exist", folder.Name)) } var oldFolder vfs.BaseVirtualFolder err = json.Unmarshal(f, &oldFolder) if err != nil { return err } folder.ID = oldFolder.ID folder.LastQuotaUpdate = oldFolder.LastQuotaUpdate folder.UsedQuotaFiles = oldFolder.UsedQuotaFiles folder.UsedQuotaSize = oldFolder.UsedQuotaSize folder.Users = oldFolder.Users folder.Groups = oldFolder.Groups buf, err := json.Marshal(folder) if err != nil { return err } return bucket.Put([]byte(folder.Name), buf) }) } func (p *BoltProvider) deleteFolderMappings(folder vfs.BaseVirtualFolder, usersBucket, groupsBucket *bolt.Bucket) error { for _, username := range folder.Users { var u []byte if u = usersBucket.Get([]byte(username)); u == nil { continue } var user User err := json.Unmarshal(u, &user) if err != nil { return err } var folders []vfs.VirtualFolder for _, userFolder := range user.VirtualFolders { if folder.Name != userFolder.Name { folders = append(folders, userFolder) } } user.VirtualFolders = folders buf, err := json.Marshal(user) if err != nil { return err } err = usersBucket.Put([]byte(user.Username), buf) if err != nil { return err } } for _, groupname := range folder.Groups { var u []byte if u = groupsBucket.Get([]byte(groupname)); u == nil { continue } var group Group err := json.Unmarshal(u, &group) if err != nil { return err } var folders []vfs.VirtualFolder for _, groupFolder := range group.VirtualFolders { if folder.Name != groupFolder.Name { folders = append(folders, groupFolder) } } group.VirtualFolders = folders buf, err := json.Marshal(group) if err != nil { return err } err = groupsBucket.Put([]byte(group.Name), buf) if err != nil { return err } } return nil } func (p *BoltProvider) deleteFolder(baseFolder vfs.BaseVirtualFolder) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } usersBucket, err := p.getUsersBucket(tx) if err != nil { return err } 
groupsBucket, err := p.getGroupsBucket(tx) if err != nil { return err } var f []byte if f = bucket.Get([]byte(baseFolder.Name)); f == nil { return util.NewRecordNotFoundError(fmt.Sprintf("folder %v does not exist", baseFolder.Name)) } var folder vfs.BaseVirtualFolder err = json.Unmarshal(f, &folder) if err != nil { return err } if err = p.deleteFolderMappings(folder, usersBucket, groupsBucket); err != nil { return err } return bucket.Delete([]byte(folder.Name)) }) } func (p *BoltProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getFoldersBucket(tx) if err != nil { return err } var f []byte if f = bucket.Get([]byte(name)); f == nil { return util.NewRecordNotFoundError(fmt.Sprintf("folder %q does not exist, unable to update quota", name)) } var folder vfs.BaseVirtualFolder err = json.Unmarshal(f, &folder) if err != nil { return err } if reset { folder.UsedQuotaSize = sizeAdd folder.UsedQuotaFiles = filesAdd } else { folder.UsedQuotaSize += sizeAdd folder.UsedQuotaFiles += filesAdd } folder.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(folder) if err != nil { return err } return bucket.Put([]byte(folder.Name), buf) }) } func (p *BoltProvider) getUsedFolderQuota(name string) (int, int64, error) { folder, err := p.getFolderByName(name) if err != nil { providerLog(logger.LevelError, "unable to get quota for folder %q error: %v", name, err) return 0, 0, err } return folder.UsedQuotaFiles, folder.UsedQuotaSize, err } func (p *BoltProvider) getGroups(limit, offset int, order string, _ bool) ([]Group, error) { groups := make([]Group, 0, limit) var err error if limit <= 0 { return groups, err } err = p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } cursor := bucket.Cursor() itNum := 0 if order == 
OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } var group Group group, err = p.joinGroupAndFolders(v, foldersBucket) if err != nil { return err } group.PrepareForRendering() groups = append(groups, group) if len(groups) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } var group Group group, err = p.joinGroupAndFolders(v, foldersBucket) if err != nil { return err } group.PrepareForRendering() groups = append(groups, group) if len(groups) >= limit { break } } } return err }) return groups, err } func (p *BoltProvider) getGroupsWithNames(names []string) ([]Group, error) { var groups []Group err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } for _, name := range names { g := bucket.Get([]byte(name)) if g == nil { continue } group, err := p.joinGroupAndFolders(g, foldersBucket) if err != nil { return err } groups = append(groups, group) } return nil }) return groups, err } func (p *BoltProvider) getUsersInGroups(names []string) ([]string, error) { var usernames []string err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } for _, name := range names { g := bucket.Get([]byte(name)) if g == nil { continue } var group Group err := json.Unmarshal(g, &group) if err != nil { return err } usernames = append(usernames, group.Users...) 
} return nil }) return usernames, err } func (p *BoltProvider) groupExists(name string) (Group, error) { var group Group err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } g := bucket.Get([]byte(name)) if g == nil { return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name)) } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } group, err = p.joinGroupAndFolders(g, foldersBucket) return err }) return group, err } func (p *BoltProvider) addGroup(group *Group) error { if err := group.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } if u := bucket.Get([]byte(group.Name)); u != nil { return util.NewI18nError( fmt.Errorf("%w: group %q already exists", ErrDuplicatedKey, group.Name), util.I18nErrorDuplicatedUsername, ) } id, err := bucket.NextSequence() if err != nil { return err } group.ID = int64(id) group.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) group.Users = nil group.Admins = nil for idx := range group.VirtualFolders { err = p.addRelationToFolderMapping(group.VirtualFolders[idx].Name, nil, group, foldersBucket) if err != nil { return err } } buf, err := json.Marshal(group) if err != nil { return err } return bucket.Put([]byte(group.Name), buf) }) } func (p *BoltProvider) updateGroup(group *Group) error { if err := group.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } var g []byte if g = bucket.Get([]byte(group.Name)); g == nil { return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", group.Name)) } var oldGroup Group err = 
json.Unmarshal(g, &oldGroup) if err != nil { return err } for idx := range oldGroup.VirtualFolders { err = p.removeRelationFromFolderMapping(oldGroup.VirtualFolders[idx], "", oldGroup.Name, foldersBucket) if err != nil { return err } } for idx := range group.VirtualFolders { err = p.addRelationToFolderMapping(group.VirtualFolders[idx].Name, nil, group, foldersBucket) if err != nil { return err } } group.ID = oldGroup.ID group.CreatedAt = oldGroup.CreatedAt group.Users = oldGroup.Users group.Admins = oldGroup.Admins group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(group) if err != nil { return err } return bucket.Put([]byte(group.Name), buf) }) } func (p *BoltProvider) deleteGroup(group Group) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getGroupsBucket(tx) if err != nil { return err } var g []byte if g = bucket.Get([]byte(group.Name)); g == nil { return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", group.Name)) } var oldGroup Group err = json.Unmarshal(g, &oldGroup) if err != nil { return err } if len(oldGroup.Users) > 0 { return util.NewValidationError(fmt.Sprintf("the group %q is referenced, it cannot be removed", oldGroup.Name)) } if len(oldGroup.VirtualFolders) > 0 { foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } for idx := range oldGroup.VirtualFolders { err = p.removeRelationFromFolderMapping(oldGroup.VirtualFolders[idx], "", oldGroup.Name, foldersBucket) if err != nil { return err } } } if len(oldGroup.Admins) > 0 { adminsBucket, err := p.getAdminsBucket(tx) if err != nil { return err } for idx := range oldGroup.Admins { err = p.removeGroupFromAdminMapping(oldGroup.Name, oldGroup.Admins[idx], adminsBucket) if err != nil { return err } } } return bucket.Delete([]byte(group.Name)) }) } func (p *BoltProvider) dumpGroups() ([]Group, error) { groups := make([]Group, 0, 50) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := 
p.getGroupsBucket(tx) if err != nil { return err } foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { group, err := p.joinGroupAndFolders(v, foldersBucket) if err != nil { return err } groups = append(groups, group) } return err }) return groups, err } func (p *BoltProvider) apiKeyExists(keyID string) (APIKey, error) { var apiKey APIKey err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getAPIKeysBucket(tx) if err != nil { return err } k := bucket.Get([]byte(keyID)) if k == nil { return util.NewRecordNotFoundError(fmt.Sprintf("API key %v does not exist", keyID)) } return json.Unmarshal(k, &apiKey) }) return apiKey, err } func (p *BoltProvider) addAPIKey(apiKey *APIKey) error { err := apiKey.validate() if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getAPIKeysBucket(tx) if err != nil { return err } if a := bucket.Get([]byte(apiKey.KeyID)); a != nil { return fmt.Errorf("API key %v already exists", apiKey.KeyID) } id, err := bucket.NextSequence() if err != nil { return err } apiKey.ID = int64(id) apiKey.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) apiKey.LastUseAt = 0 if apiKey.User != "" { if err := p.userExistsInternal(tx, apiKey.User); err != nil { return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User) } } if apiKey.Admin != "" { if err := p.adminExistsInternal(tx, apiKey.Admin); err != nil { return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin) } } buf, err := json.Marshal(apiKey) if err != nil { return err } return bucket.Put([]byte(apiKey.KeyID), buf) }) } func (p *BoltProvider) updateAPIKey(apiKey *APIKey) error { err := apiKey.validate() if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := 
p.getAPIKeysBucket(tx)
		if err != nil {
			return err
		}
		var a []byte
		if a = bucket.Get([]byte(apiKey.KeyID)); a == nil {
			return util.NewRecordNotFoundError(fmt.Sprintf("API key %v does not exist", apiKey.KeyID))
		}
		var oldAPIKey APIKey
		err = json.Unmarshal(a, &oldAPIKey)
		if err != nil {
			return err
		}
		apiKey.ID = oldAPIKey.ID
		apiKey.KeyID = oldAPIKey.KeyID
		apiKey.Key = oldAPIKey.Key
		apiKey.CreatedAt = oldAPIKey.CreatedAt
		apiKey.LastUseAt = oldAPIKey.LastUseAt
		apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		if apiKey.User != "" {
			if err := p.userExistsInternal(tx, apiKey.User); err != nil {
				return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User)
			}
		}
		if apiKey.Admin != "" {
			if err := p.adminExistsInternal(tx, apiKey.Admin); err != nil {
				return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin)
			}
		}
		buf, err := json.Marshal(apiKey)
		if err != nil {
			return err
		}
		return bucket.Put([]byte(apiKey.KeyID), buf)
	})
}

// deleteAPIKey removes the API key identified by apiKey.KeyID.
// It returns a RecordNotFoundError if no such key exists.
func (p *BoltProvider) deleteAPIKey(apiKey APIKey) error {
	return p.dbHandle.Update(func(tx *bolt.Tx) error {
		bucket, err := p.getAPIKeysBucket(tx)
		if err != nil {
			return err
		}
		if bucket.Get([]byte(apiKey.KeyID)) == nil {
			return util.NewRecordNotFoundError(fmt.Sprintf("API key %v does not exist", apiKey.KeyID))
		}
		return bucket.Delete([]byte(apiKey.KeyID))
	})
}

// getAPIKeys returns at most limit API keys, skipping the first offset ones.
// The bucket is walked with a cursor from First for OrderASC and from Last
// otherwise. Confidential data is hidden in every returned key.
//
// A non-positive limit returns an empty slice, as the other paginated
// getters in this provider do. Without the guard, make() panics on a
// negative capacity. With limit == 0, one key would be returned, because the
// "len >= limit" check only runs after the first append.
func (p *BoltProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) {
	if limit <= 0 {
		return make([]APIKey, 0), nil
	}
	apiKeys := make([]APIKey, 0, limit)
	err := p.dbHandle.View(func(tx *bolt.Tx) error {
		bucket, err := p.getAPIKeysBucket(tx)
		if err != nil {
			return err
		}
		cursor := bucket.Cursor()
		// Choose the walk direction once instead of duplicating the loop body.
		first, next := cursor.First, cursor.Next
		if order != OrderASC {
			first, next = cursor.Last, cursor.Prev
		}
		itNum := 0
		for k, v := first(); k != nil; k, v = next() {
			itNum++
			if itNum <= offset {
				continue
			}
			var apiKey APIKey
			if err := json.Unmarshal(v, &apiKey); err != nil {
				return err
			}
			apiKey.HideConfidentialData()
			apiKeys = append(apiKeys, apiKey)
			if len(apiKeys) >= limit {
				break
			}
		}
		return nil
	})
	return apiKeys, err
}

// dumpAPIKeys returns every stored API key in ascending cursor order.
// Unlike getAPIKeys, confidential data is not hidden.
func (p *BoltProvider) dumpAPIKeys() ([]APIKey, error) {
	apiKeys := make([]APIKey, 0, 30)
	err := p.dbHandle.View(func(tx *bolt.Tx) error {
		bucket, err := p.getAPIKeysBucket(tx)
		if err != nil {
			return err
		}
		cursor := bucket.Cursor()
		for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
			var apiKey APIKey
			err = json.Unmarshal(v, &apiKey)
			if err != nil {
				return err
			}
			apiKeys = append(apiKeys, apiKey)
		}
		return err
	})
	return apiKeys, err
}

// shareExists returns the share identified by shareID.
// When username is not empty, the share must also belong to that user.
// It returns a RecordNotFoundError if the share is missing or owned by a
// different user.
func (p *BoltProvider) shareExists(shareID, username string) (Share, error) {
	var share Share
	err := p.dbHandle.View(func(tx *bolt.Tx) error {
		bucket, err := p.getSharesBucket(tx)
		if err != nil {
			return err
		}
		s := bucket.Get([]byte(shareID))
		if s == nil {
			return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", shareID))
		}
		if err := json.Unmarshal(s, &share); err != nil {
			return err
		}
		if username != "" && share.Username != username {
			return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", shareID))
		}
		return nil
	})
	return share, err
}

// addShare validates share and stores it as a new record with a fresh ID.
// Unless share.IsRestore is set, timestamps and usage counters are reset.
func (p *BoltProvider) addShare(share *Share) error {
	err := share.validate()
	if err != nil {
		return err
	}
	return p.dbHandle.Update(func(tx *bolt.Tx) error {
		bucket, err := p.getSharesBucket(tx)
		if err != nil {
			return err
		}
		if a := bucket.Get([]byte(share.ShareID)); a != nil {
			return fmt.Errorf("share %q already exists", share.ShareID)
		}
		id, err := bucket.NextSequence()
		if err != nil {
			return err
		}
		share.ID = int64(id)
		if !share.IsRestore {
			share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
			share.UpdatedAt = share.CreatedAt
			share.LastUseAt = 0
			share.UsedTokens = 0
		}
		if share.CreatedAt == 0 {
			share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		}
		if share.UpdatedAt == 0 {
share.UpdatedAt = share.CreatedAt } if err := p.userExistsInternal(tx, share.Username); err != nil { return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username)) } buf, err := json.Marshal(share) if err != nil { return err } return bucket.Put([]byte(share.ShareID), buf) }) } func (p *BoltProvider) updateShare(share *Share) error { if err := share.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getSharesBucket(tx) if err != nil { return err } var s []byte if s = bucket.Get([]byte(share.ShareID)); s == nil { return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) } var oldObject Share if err = json.Unmarshal(s, &oldObject); err != nil { return err } if oldObject.Username != share.Username { return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) } share.ID = oldObject.ID share.ShareID = oldObject.ShareID if !share.IsRestore { share.UsedTokens = oldObject.UsedTokens share.CreatedAt = oldObject.CreatedAt share.LastUseAt = oldObject.LastUseAt share.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) } if share.CreatedAt == 0 { share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) } if share.UpdatedAt == 0 { share.UpdatedAt = share.CreatedAt } if err := p.userExistsInternal(tx, share.Username); err != nil { return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username)) } buf, err := json.Marshal(share) if err != nil { return err } return bucket.Put([]byte(share.ShareID), buf) }) } func (p *BoltProvider) deleteShare(share Share) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getSharesBucket(tx) if err != nil { return err } var s []byte if s = bucket.Get([]byte(share.ShareID)); s == nil { return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) } var oldObject Share if err = json.Unmarshal(s, &oldObject); err != nil { 
return err } if oldObject.Username != share.Username { return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) } return bucket.Delete([]byte(share.ShareID)) }) } func (p *BoltProvider) getShares(limit int, offset int, order, username string) ([]Share, error) { shares := make([]Share, 0, limit) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getSharesBucket(tx) if err != nil { return err } cursor := bucket.Cursor() itNum := 0 if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var share Share if err := json.Unmarshal(v, &share); err != nil { return err } if share.Username != username { continue } itNum++ if itNum <= offset { continue } share.HideConfidentialData() shares = append(shares, share) if len(shares) >= limit { break } } return nil } for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { var share Share err = json.Unmarshal(v, &share) if err != nil { return err } if share.Username != username { continue } itNum++ if itNum <= offset { continue } share.HideConfidentialData() shares = append(shares, share) if len(shares) >= limit { break } } return nil }) return shares, err } func (p *BoltProvider) dumpShares() ([]Share, error) { shares := make([]Share, 0, 30) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getSharesBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var share Share err = json.Unmarshal(v, &share) if err != nil { return err } shares = append(shares, share) } return err }) return shares, err } func (p *BoltProvider) updateShareLastUse(shareID string, numTokens int) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getSharesBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(shareID)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("share %q does not exist, unable to update last use", shareID)) } var 
share Share err = json.Unmarshal(u, &share) if err != nil { return err } share.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now()) share.UsedTokens += numTokens buf, err := json.Marshal(share) if err != nil { return err } err = bucket.Put([]byte(shareID), buf) if err != nil { providerLog(logger.LevelWarn, "error updating last use for share %q: %v", shareID, err) return err } providerLog(logger.LevelDebug, "last use updated for share %q", shareID) return nil }) } func (p *BoltProvider) getDefenderHosts(_ int64, _ int) ([]DefenderEntry, error) { return nil, ErrNotImplemented } func (p *BoltProvider) getDefenderHostByIP(_ string, _ int64) (DefenderEntry, error) { return DefenderEntry{}, ErrNotImplemented } func (p *BoltProvider) isDefenderHostBanned(_ string) (DefenderEntry, error) { return DefenderEntry{}, ErrNotImplemented } func (p *BoltProvider) updateDefenderBanTime(_ string, _ int) error { return ErrNotImplemented } func (p *BoltProvider) deleteDefenderHost(_ string) error { return ErrNotImplemented } func (p *BoltProvider) addDefenderEvent(_ string, _ int) error { return ErrNotImplemented } func (p *BoltProvider) setDefenderBanTime(_ string, _ int64) error { return ErrNotImplemented } func (p *BoltProvider) cleanupDefender(_ int64) error { return ErrNotImplemented } func (p *BoltProvider) addActiveTransfer(_ ActiveTransfer) error { return ErrNotImplemented } func (p *BoltProvider) updateActiveTransferSizes(_, _, _ int64, _ string) error { return ErrNotImplemented } func (p *BoltProvider) removeActiveTransfer(_ int64, _ string) error { return ErrNotImplemented } func (p *BoltProvider) cleanupActiveTransfers(_ time.Time) error { return ErrNotImplemented } func (p *BoltProvider) getActiveTransfers(_ time.Time) ([]ActiveTransfer, error) { return nil, ErrNotImplemented } func (p *BoltProvider) addSharedSession(_ Session) error { return ErrNotImplemented } func (p *BoltProvider) deleteSharedSession(_ string, _ SessionType) error { return ErrNotImplemented } func (p 
*BoltProvider) getSharedSession(_ string, _ SessionType) (Session, error) { return Session{}, ErrNotImplemented } func (p *BoltProvider) cleanupSharedSessions(_ SessionType, _ int64) error { return ErrNotImplemented } func (p *BoltProvider) getEventActions(limit, offset int, order string, _ bool) ([]BaseEventAction, error) { if limit <= 0 { return nil, nil } actions := make([]BaseEventAction, 0, limit) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getActionsBucket(tx) if err != nil { return err } itNum := 0 cursor := bucket.Cursor() if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } var action BaseEventAction err = json.Unmarshal(v, &action) if err != nil { return err } action.PrepareForRendering() actions = append(actions, action) if len(actions) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } var action BaseEventAction err = json.Unmarshal(v, &action) if err != nil { return err } action.PrepareForRendering() actions = append(actions, action) if len(actions) >= limit { break } } } return nil }) return actions, err } func (p *BoltProvider) dumpEventActions() ([]BaseEventAction, error) { actions := make([]BaseEventAction, 0, 50) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getActionsBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var action BaseEventAction err = json.Unmarshal(v, &action) if err != nil { return err } actions = append(actions, action) } return nil }) return actions, err } func (p *BoltProvider) eventActionExists(name string) (BaseEventAction, error) { var action BaseEventAction err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getActionsBucket(tx) if err != nil { return err } k := bucket.Get([]byte(name)) if k == nil { return 
util.NewRecordNotFoundError(fmt.Sprintf("action %q does not exist", name)) } return json.Unmarshal(k, &action) }) return action, err } func (p *BoltProvider) addEventAction(action *BaseEventAction) error { err := action.validate() if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getActionsBucket(tx) if err != nil { return err } if a := bucket.Get([]byte(action.Name)); a != nil { return util.NewI18nError( fmt.Errorf("%w: event action %q already exists", ErrDuplicatedKey, action.Name), util.I18nErrorDuplicatedName, ) } id, err := bucket.NextSequence() if err != nil { return err } action.ID = int64(id) action.Rules = nil buf, err := json.Marshal(action) if err != nil { return err } return bucket.Put([]byte(action.Name), buf) }) } func (p *BoltProvider) updateEventAction(action *BaseEventAction) error { err := action.validate() if err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getActionsBucket(tx) if err != nil { return err } var a []byte if a = bucket.Get([]byte(action.Name)); a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("event action %s does not exist", action.Name)) } var oldAction BaseEventAction err = json.Unmarshal(a, &oldAction) if err != nil { return err } action.ID = oldAction.ID action.Name = oldAction.Name action.Rules = nil if len(oldAction.Rules) > 0 { rulesBucket, err := p.getRulesBucket(tx) if err != nil { return err } var relatedRules []string for _, ruleName := range oldAction.Rules { r := rulesBucket.Get([]byte(ruleName)) if r != nil { relatedRules = append(relatedRules, ruleName) var rule EventRule err := json.Unmarshal(r, &rule) if err != nil { return err } rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(rule) if err != nil { return err } if err = rulesBucket.Put([]byte(rule.Name), buf); err != nil { return err } setLastRuleUpdate() } } action.Rules = relatedRules } buf, err := json.Marshal(action) if err != 
nil { return err } return bucket.Put([]byte(action.Name), buf) }) } func (p *BoltProvider) deleteEventAction(action BaseEventAction) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getActionsBucket(tx) if err != nil { return err } var a []byte if a = bucket.Get([]byte(action.Name)); a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("action %s does not exist", action.Name)) } var oldAction BaseEventAction err = json.Unmarshal(a, &oldAction) if err != nil { return err } if len(oldAction.Rules) > 0 { return util.NewValidationError(fmt.Sprintf("action %s is referenced, it cannot be removed", oldAction.Name)) } return bucket.Delete([]byte(action.Name)) }) } func (p *BoltProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { if limit <= 0 { return nil, nil } rules := make([]EventRule, 0, limit) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getRulesBucket(tx) if err != nil { return err } actionsBucket, err := p.getActionsBucket(tx) if err != nil { return err } itNum := 0 cursor := bucket.Cursor() if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } var rule EventRule rule, err = p.joinRuleAndActions(v, actionsBucket) if err != nil { return err } rule.PrepareForRendering() rules = append(rules, rule) if len(rules) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } var rule EventRule rule, err = p.joinRuleAndActions(v, actionsBucket) if err != nil { return err } rule.PrepareForRendering() rules = append(rules, rule) if len(rules) >= limit { break } } } return err }) return rules, err } func (p *BoltProvider) dumpEventRules() ([]EventRule, error) { rules := make([]EventRule, 0, 50) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getRulesBucket(tx) if err != nil { return err } actionsBucket, err := p.getActionsBucket(tx) if err != 
nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { rule, err := p.joinRuleAndActions(v, actionsBucket) if err != nil { return err } rules = append(rules, rule) } return nil }) return rules, err } func (p *BoltProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { if getLastRuleUpdate() < after { return nil, nil } rules := make([]EventRule, 0, 10) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getRulesBucket(tx) if err != nil { return err } actionsBucket, err := p.getActionsBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var rule EventRule err := json.Unmarshal(v, &rule) if err != nil { return err } if rule.UpdatedAt < after { continue } var actions []EventAction for idx := range rule.Actions { action := &rule.Actions[idx] var baseAction BaseEventAction k := actionsBucket.Get([]byte(action.Name)) if k == nil { continue } err = json.Unmarshal(k, &baseAction) if err != nil { continue } baseAction.Options.SetEmptySecretsIfNil() action.BaseEventAction = baseAction actions = append(actions, *action) } rule.Actions = actions rules = append(rules, rule) } return nil }) return rules, err } func (p *BoltProvider) eventRuleExists(name string) (EventRule, error) { var rule EventRule err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getRulesBucket(tx) if err != nil { return err } r := bucket.Get([]byte(name)) if r == nil { return util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", name)) } actionsBucket, err := p.getActionsBucket(tx) if err != nil { return err } rule, err = p.joinRuleAndActions(r, actionsBucket) return err }) return rule, err } func (p *BoltProvider) addEventRule(rule *EventRule) error { if err := rule.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getRulesBucket(tx) if err != nil { return err } 
actionsBucket, err := p.getActionsBucket(tx)
		if err != nil {
			return err
		}
		if r := bucket.Get([]byte(rule.Name)); r != nil {
			return util.NewI18nError(
				fmt.Errorf("%w: event rule %q already exists", ErrDuplicatedKey, rule.Name),
				util.I18nErrorDuplicatedName,
			)
		}
		id, err := bucket.NextSequence()
		if err != nil {
			return err
		}
		rule.ID = int64(id)
		rule.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		rule.UpdatedAt = rule.CreatedAt
		for idx := range rule.Actions {
			if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name, actionsBucket); err != nil {
				return err
			}
		}
		sort.Slice(rule.Actions, func(i, j int) bool {
			return rule.Actions[i].Order < rule.Actions[j].Order
		})
		buf, err := json.Marshal(rule)
		if err != nil {
			return err
		}
		err = bucket.Put([]byte(rule.Name), buf)
		if err == nil {
			setLastRuleUpdate()
		}
		return err
	})
}

// updateEventRule validates rule and replaces the stored rule with the same
// name. It returns a RecordNotFoundError if no such rule exists.
//
// Rule-to-action mappings are rebuilt: the rule is removed from the actions
// of the stored copy and added to the actions of the new one. ID and
// CreatedAt are kept from the stored copy and UpdatedAt is set to now.
// setLastRuleUpdate is called after a successful write.
//
// Actions are sorted by Order before marshaling, as addEventRule does.
// Previously the sort ran after json.Marshal, so the persisted rule kept the
// caller's action order and only the in-memory value was sorted.
func (p *BoltProvider) updateEventRule(rule *EventRule) error {
	if err := rule.validate(); err != nil {
		return err
	}
	return p.dbHandle.Update(func(tx *bolt.Tx) error {
		bucket, err := p.getRulesBucket(tx)
		if err != nil {
			return err
		}
		actionsBucket, err := p.getActionsBucket(tx)
		if err != nil {
			return err
		}
		var r []byte
		if r = bucket.Get([]byte(rule.Name)); r == nil {
			return util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", rule.Name))
		}
		var oldRule EventRule
		if err = json.Unmarshal(r, &oldRule); err != nil {
			return err
		}
		for idx := range oldRule.Actions {
			if err = p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name, actionsBucket); err != nil {
				return err
			}
		}
		for idx := range rule.Actions {
			if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name, actionsBucket); err != nil {
				return err
			}
		}
		rule.ID = oldRule.ID
		rule.CreatedAt = oldRule.CreatedAt
		rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		// sort before marshaling so the stored JSON reflects the execution order
		sort.Slice(rule.Actions, func(i, j int) bool {
			return rule.Actions[i].Order < rule.Actions[j].Order
		})
		buf, err := json.Marshal(rule)
		if err != nil {
			return err
		}
		err = bucket.Put([]byte(rule.Name), buf)
		if err == nil {
			setLastRuleUpdate()
		}
		return err
	})
}

// deleteEventRule removes the rule stored under rule.Name and detaches it
// from the actions it referenced. It returns a RecordNotFoundError if no such
// rule exists. The boolean parameter is ignored by this provider.
func (p *BoltProvider) deleteEventRule(rule EventRule, _ bool) error {
	return p.dbHandle.Update(func(tx *bolt.Tx) error {
		bucket, err := p.getRulesBucket(tx)
		if err != nil {
			return err
		}
		var r []byte
		if r = bucket.Get([]byte(rule.Name)); r == nil {
			return util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", rule.Name))
		}
		var oldRule EventRule
		if err = json.Unmarshal(r, &oldRule); err != nil {
			return err
		}
		if len(oldRule.Actions) > 0 {
			actionsBucket, err := p.getActionsBucket(tx)
			if err != nil {
				return err
			}
			for idx := range oldRule.Actions {
				if err = p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name, actionsBucket); err != nil {
					return err
				}
			}
		}
		return bucket.Delete([]byte(rule.Name))
	})
}

// The task and node methods below are not implemented by the bolt provider;
// they always return ErrNotImplemented.

func (*BoltProvider) getTaskByName(_ string) (Task, error) { return Task{}, ErrNotImplemented }

func (*BoltProvider) addTask(_ string) error { return ErrNotImplemented }

func (*BoltProvider) updateTask(_ string, _ int64) error { return ErrNotImplemented }

func (*BoltProvider) updateTaskTimestamp(_ string) error { return ErrNotImplemented }

func (*BoltProvider) addNode() error { return ErrNotImplemented }

func (*BoltProvider) getNodeByName(_ string) (Node, error) { return Node{}, ErrNotImplemented }

func (*BoltProvider) getNodes() ([]Node, error) { return nil, ErrNotImplemented }

func (*BoltProvider) updateNodeTimestamp() error { return ErrNotImplemented }

func (*BoltProvider) cleanupNodes() error { return ErrNotImplemented }

// roleExists returns the role stored under name, or a RecordNotFoundError if
// no such role exists.
func (p *BoltProvider) roleExists(name string) (Role, error) {
	var role Role
	err := p.dbHandle.View(func(tx *bolt.Tx) error {
		bucket, err := p.getRolesBucket(tx)
		if err != nil {
			return err
		}
		r := bucket.Get([]byte(name))
		if r == nil {
			return util.NewRecordNotFoundError(fmt.Sprintf("role %q does not exist", name))
		}
		return json.Unmarshal(r, &role)
	})
	return role, err
}

// addRole validates role and stores it as a new record with a fresh ID,
// clearing Users and Admins.
func (p *BoltProvider) addRole(role *Role) error {
	if err := role.validate(); err != nil {
		return err
	}
return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getRolesBucket(tx) if err != nil { return err } if r := bucket.Get([]byte(role.Name)); r != nil { return util.NewI18nError( fmt.Errorf("%w: role %q already exists", ErrDuplicatedKey, role.Name), util.I18nErrorDuplicatedName, ) } id, err := bucket.NextSequence() if err != nil { return err } role.ID = int64(id) role.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) role.Users = nil role.Admins = nil buf, err := json.Marshal(role) if err != nil { return err } return bucket.Put([]byte(role.Name), buf) }) } func (p *BoltProvider) updateRole(role *Role) error { if err := role.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getRolesBucket(tx) if err != nil { return err } var r []byte if r = bucket.Get([]byte(role.Name)); r == nil { return fmt.Errorf("role %q does not exist", role.Name) } var oldRole Role err = json.Unmarshal(r, &oldRole) if err != nil { return err } role.ID = oldRole.ID role.CreatedAt = oldRole.CreatedAt role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) role.Users = oldRole.Users role.Admins = oldRole.Admins buf, err := json.Marshal(role) if err != nil { return err } return bucket.Put([]byte(role.Name), buf) }) } func (p *BoltProvider) deleteRole(role Role) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getRolesBucket(tx) if err != nil { return err } var r []byte if r = bucket.Get([]byte(role.Name)); r == nil { return fmt.Errorf("role %q does not exist", role.Name) } var oldRole Role err = json.Unmarshal(r, &oldRole) if err != nil { return err } if len(oldRole.Admins) > 0 { return util.NewValidationError(fmt.Sprintf("the role %q is referenced, it cannot be removed", oldRole.Name)) } if len(oldRole.Users) > 0 { bucket, err := p.getUsersBucket(tx) if err != nil { return err } for _, username := range oldRole.Users { if err := 
p.removeRoleFromUser(username, oldRole.Name, bucket); err != nil { return err } } } return bucket.Delete([]byte(role.Name)) }) } func (p *BoltProvider) getRoles(limit int, offset int, order string, _ bool) ([]Role, error) { roles := make([]Role, 0, limit) if limit <= 0 { return roles, nil } err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getRolesBucket(tx) if err != nil { return err } cursor := bucket.Cursor() itNum := 0 if order == OrderASC { for k, v := cursor.First(); k != nil; k, v = cursor.Next() { itNum++ if itNum <= offset { continue } var role Role err = json.Unmarshal(v, &role) if err != nil { return err } roles = append(roles, role) if len(roles) >= limit { break } } } else { for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { itNum++ if itNum <= offset { continue } var role Role err = json.Unmarshal(v, &role) if err != nil { return err } roles = append(roles, role) if len(roles) >= limit { break } } } return nil }) return roles, err } func (p *BoltProvider) dumpRoles() ([]Role, error) { roles := make([]Role, 0, 10) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getRolesBucket(tx) if err != nil { return err } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var role Role err = json.Unmarshal(v, &role) if err != nil { return err } roles = append(roles, role) } return err }) return roles, err } func (p *BoltProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { entry := IPListEntry{ IPOrNet: ipOrNet, Type: listType, } err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } e := bucket.Get([]byte(entry.getKey())) if e == nil { return util.NewRecordNotFoundError(fmt.Sprintf("entry %q does not exist", entry.IPOrNet)) } err = json.Unmarshal(e, &entry) if err == nil { entry.PrepareForRendering() } return err }) return entry, err } func (p *BoltProvider) addIPListEntry(entry *IPListEntry) 
error { if err := entry.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } if e := bucket.Get([]byte(entry.getKey())); e != nil { return util.NewI18nError( fmt.Errorf("%w: entry %q already exists", ErrDuplicatedKey, entry.IPOrNet), util.I18nErrorDuplicatedIPNet, ) } entry.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(entry) if err != nil { return err } return bucket.Put([]byte(entry.getKey()), buf) }) } func (p *BoltProvider) updateIPListEntry(entry *IPListEntry) error { if err := entry.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } var e []byte if e = bucket.Get([]byte(entry.getKey())); e == nil { return fmt.Errorf("entry %q does not exist", entry.IPOrNet) } var oldEntry IPListEntry err = json.Unmarshal(e, &oldEntry) if err != nil { return err } entry.CreatedAt = oldEntry.CreatedAt entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(entry) if err != nil { return err } return bucket.Put([]byte(entry.getKey()), buf) }) } func (p *BoltProvider) deleteIPListEntry(entry IPListEntry, _ bool) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } if e := bucket.Get([]byte(entry.getKey())); e == nil { return fmt.Errorf("entry %q does not exist", entry.IPOrNet) } return bucket.Delete([]byte(entry.getKey())) }) } func (p *BoltProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { entries := make([]IPListEntry, 0, 15) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } prefix := []byte(fmt.Sprintf("%d_", listType)) acceptKey := func(k []byte) bool { return k != 
nil && bytes.HasPrefix(k, prefix) } cursor := bucket.Cursor() if order == OrderASC { for k, v := cursor.Seek(prefix); acceptKey(k); k, v = cursor.Next() { var entry IPListEntry err = json.Unmarshal(v, &entry) if err != nil { return err } if entry.satisfySearchConstraints(filter, from, order) { entry.PrepareForRendering() entries = append(entries, entry) if limit > 0 && len(entries) >= limit { break } } } } else { for k, v := cursor.Last(); acceptKey(k); k, v = cursor.Prev() { var entry IPListEntry err = json.Unmarshal(v, &entry) if err != nil { return err } if entry.satisfySearchConstraints(filter, from, order) { entry.PrepareForRendering() entries = append(entries, entry) if limit > 0 && len(entries) >= limit { break } } } } return nil }) return entries, err } func (p *BoltProvider) getRecentlyUpdatedIPListEntries(_ int64) ([]IPListEntry, error) { return nil, ErrNotImplemented } func (p *BoltProvider) dumpIPListEntries() ([]IPListEntry, error) { entries := make([]IPListEntry, 0, 10) err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } if count := bucket.Stats().KeyN; count > ipListMemoryLimit { providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count) return nil } cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var entry IPListEntry err = json.Unmarshal(v, &entry) if err != nil { return err } entry.PrepareForRendering() entries = append(entries, entry) } return nil }) return entries, err } func (p *BoltProvider) countIPListEntries(listType IPListType) (int64, error) { var count int64 err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } if listType == 0 { count = int64(bucket.Stats().KeyN) return nil } prefix := []byte(fmt.Sprintf("%d_", listType)) cursor := bucket.Cursor() for k, _ := cursor.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = cursor.Next() { 
count++ } return nil }) return count, err } func (p *BoltProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { entries := make([]IPListEntry, 0, 3) ipAddr, err := netip.ParseAddr(ip) if err != nil { return entries, fmt.Errorf("invalid ip address %s", ip) } var netType int var ipBytes []byte if ipAddr.Is4() || ipAddr.Is4In6() { netType = ipTypeV4 as4 := ipAddr.As4() ipBytes = as4[:] } else { netType = ipTypeV6 as16 := ipAddr.As16() ipBytes = as16[:] } err = p.dbHandle.View(func(tx *bolt.Tx) error { bucket, err := p.getIPListsBucket(tx) if err != nil { return err } prefix := []byte(fmt.Sprintf("%d_", listType)) cursor := bucket.Cursor() for k, v := cursor.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = cursor.Next() { var entry IPListEntry err = json.Unmarshal(v, &entry) if err != nil { return err } if entry.IPType == netType && bytes.Compare(ipBytes, entry.First) >= 0 && bytes.Compare(ipBytes, entry.Last) <= 0 { entry.PrepareForRendering() entries = append(entries, entry) } } return nil }) return entries, err } func (p *BoltProvider) getConfigs() (Configs, error) { var configs Configs err := p.dbHandle.View(func(tx *bolt.Tx) error { bucket := tx.Bucket(configsBucket) if bucket == nil { return fmt.Errorf("unable to find configs bucket") } data := bucket.Get(configsKey) if data != nil { return json.Unmarshal(data, &configs) } return nil }) return configs, err } func (p *BoltProvider) setConfigs(configs *Configs) error { if err := configs.validate(); err != nil { return err } return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket(configsBucket) if bucket == nil { return fmt.Errorf("unable to find configs bucket") } buf, err := json.Marshal(configs) if err != nil { return err } return bucket.Put(configsKey, buf) }) } func (p *BoltProvider) setFirstDownloadTimestamp(username string) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u 
[]byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to set download timestamp", username)) } var user User err = json.Unmarshal(u, &user) if err != nil { return err } if user.FirstDownload > 0 { return util.NewGenericError(fmt.Sprintf("first download already set to %v", util.GetTimeFromMsecSinceEpoch(user.FirstDownload))) } user.FirstDownload = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } return bucket.Put([]byte(username), buf) }) } func (p *BoltProvider) setFirstUploadTimestamp(username string) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } var u []byte if u = bucket.Get([]byte(username)); u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to set upload timestamp", username)) } var user User if err = json.Unmarshal(u, &user); err != nil { return err } if user.FirstUpload > 0 { return util.NewGenericError(fmt.Sprintf("first upload already set to %v", util.GetTimeFromMsecSinceEpoch(user.FirstUpload))) } user.FirstUpload = util.GetTimeAsMsSinceEpoch(time.Now()) buf, err := json.Marshal(user) if err != nil { return err } return bucket.Put([]byte(username), buf) }) } func (p *BoltProvider) close() error { return p.dbHandle.Close() } func (p *BoltProvider) reloadConfig() error { return nil } // initializeDatabase does nothing, no initilization is needed for bolt provider func (p *BoltProvider) initializeDatabase() error { return ErrNoInitRequired } func (p *BoltProvider) migrateDatabase() error { dbVersion, err := getBoltDatabaseVersion(p.dbHandle) if err != nil { return err } switch version := dbVersion.Version; { case version == boltDatabaseVersion: providerLog(logger.LevelDebug, "bolt database is up to date, current version: %d", version) return ErrNoInitRequired case version < 33: err = 
errSchemaVersionTooOld(version) providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err case version == 33: logger.InfoToConsole("updating database schema version: %d -> 34", version) providerLog(logger.LevelInfo, "updating database schema version: %d -> 34", version) return updateBoltDatabaseVersion(p.dbHandle, 34) default: if version > boltDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, boltDatabaseVersion) logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version, boltDatabaseVersion) return nil } return fmt.Errorf("database schema version not handled: %d", version) } } func (p *BoltProvider) revertDatabase(targetVersion int) error { dbVersion, err := getBoltDatabaseVersion(p.dbHandle) if err != nil { return err } if dbVersion.Version == targetVersion { return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { case 34: logger.InfoToConsole("downgrading database schema version: %d -> 33", dbVersion.Version) providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 33", dbVersion.Version) return updateBoltDatabaseVersion(p.dbHandle, 33) default: return fmt.Errorf("database schema version not handled: %v", dbVersion.Version) } } func (p *BoltProvider) resetDatabase() error { return p.dbHandle.Update(func(tx *bolt.Tx) error { for _, bucketName := range boltBuckets { err := tx.DeleteBucket(bucketName) if err != nil && !errors.Is(err, bolterrors.ErrBucketNotFound) { return fmt.Errorf("unable to remove bucket %v: %w", bucketName, err) } } return nil }) } func (p *BoltProvider) joinRuleAndActions(r []byte, actionsBucket *bolt.Bucket) (EventRule, error) { var rule EventRule err := json.Unmarshal(r, &rule) if err != nil { return rule, err } var actions []EventAction for idx := range rule.Actions { action := &rule.Actions[idx] var baseAction BaseEventAction k := 
actionsBucket.Get([]byte(action.Name)) if k == nil { continue } err = json.Unmarshal(k, &baseAction) if err != nil { continue } baseAction.Options.SetEmptySecretsIfNil() action.BaseEventAction = baseAction actions = append(actions, *action) } rule.Actions = actions return rule, nil } func (p *BoltProvider) joinGroupAndFolders(g []byte, foldersBucket *bolt.Bucket) (Group, error) { var group Group err := json.Unmarshal(g, &group) if err != nil { return group, err } if len(group.VirtualFolders) > 0 { var folders []vfs.VirtualFolder for idx := range group.VirtualFolders { folder := &group.VirtualFolders[idx] baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) if err != nil { continue } folder.BaseVirtualFolder = baseFolder folders = append(folders, *folder) } group.VirtualFolders = folders } group.SetEmptySecretsIfNil() return group, err } func (p *BoltProvider) joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) { var user User err := json.Unmarshal(u, &user) if err != nil { return user, err } if len(user.VirtualFolders) > 0 { var folders []vfs.VirtualFolder for idx := range user.VirtualFolders { folder := &user.VirtualFolders[idx] baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) if err != nil { continue } folder.BaseVirtualFolder = baseFolder folders = append(folders, *folder) } user.VirtualFolders = folders } user.SetEmptySecretsIfNil() return user, err } func (p *BoltProvider) groupExistsInternal(name string, bucket *bolt.Bucket) (Group, error) { var group Group g := bucket.Get([]byte(name)) if g == nil { err := util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name)) return group, err } err := json.Unmarshal(g, &group) return group, err } func (p *BoltProvider) folderExistsInternal(name string, bucket *bolt.Bucket) (vfs.BaseVirtualFolder, error) { var folder vfs.BaseVirtualFolder f := bucket.Get([]byte(name)) if f == nil { err := util.NewRecordNotFoundError(fmt.Sprintf("folder %q does not 
exist", name)) return folder, err } err := json.Unmarshal(f, &folder) return folder, err } func (p *BoltProvider) addFolderInternal(folder vfs.BaseVirtualFolder, bucket *bolt.Bucket) error { id, err := bucket.NextSequence() if err != nil { return err } folder.ID = int64(id) buf, err := json.Marshal(folder) if err != nil { return err } return bucket.Put([]byte(folder.Name), buf) } func (p *BoltProvider) removeRoleFromUser(username, role string, bucket *bolt.Bucket) error { u := bucket.Get([]byte(username)) if u == nil { providerLog(logger.LevelWarn, "user %q does not exist, cannot remove role %q", username, role) return nil } var user User err := json.Unmarshal(u, &user) if err != nil { return err } if user.Role == role { user.Role = "" buf, err := json.Marshal(user) if err != nil { return err } return bucket.Put([]byte(user.Username), buf) } providerLog(logger.LevelError, "user %q does not have the expected role %q, actual %q", username, role, user.Role) return nil } func (p *BoltProvider) addAdminToRole(username, roleName string, bucket *bolt.Bucket) error { if roleName == "" { return nil } r := bucket.Get([]byte(roleName)) if r == nil { return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, roleName) } var role Role err := json.Unmarshal(r, &role) if err != nil { return err } if !slices.Contains(role.Admins, username) { role.Admins = append(role.Admins, username) buf, err := json.Marshal(role) if err != nil { return err } return bucket.Put([]byte(role.Name), buf) } return nil } func (p *BoltProvider) removeAdminFromRole(username, roleName string, bucket *bolt.Bucket) error { if roleName == "" { return nil } r := bucket.Get([]byte(roleName)) if r == nil { providerLog(logger.LevelWarn, "role %q does not exist, cannot remove admin %q", roleName, username) return nil } var role Role err := json.Unmarshal(r, &role) if err != nil { return err } if slices.Contains(role.Admins, username) { var admins []string for _, admin := range role.Admins { if admin 
!= username { admins = append(admins, admin) } } role.Admins = util.RemoveDuplicates(admins, false) buf, err := json.Marshal(role) if err != nil { return err } return bucket.Put([]byte(role.Name), buf) } return nil } func (p *BoltProvider) addUserToRole(username, roleName string, bucket *bolt.Bucket) error { if roleName == "" { return nil } r := bucket.Get([]byte(roleName)) if r == nil { return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, roleName) } var role Role err := json.Unmarshal(r, &role) if err != nil { return err } if !slices.Contains(role.Users, username) { role.Users = append(role.Users, username) buf, err := json.Marshal(role) if err != nil { return err } return bucket.Put([]byte(role.Name), buf) } return nil } func (p *BoltProvider) removeUserFromRole(username, roleName string, bucket *bolt.Bucket) error { if roleName == "" { return nil } r := bucket.Get([]byte(roleName)) if r == nil { providerLog(logger.LevelWarn, "role %q does not exist, cannot remove admin %q", roleName, username) return nil } var role Role err := json.Unmarshal(r, &role) if err != nil { return err } if slices.Contains(role.Users, username) { var users []string for _, user := range role.Users { if user != username { users = append(users, user) } } users = util.RemoveDuplicates(users, false) role.Users = users buf, err := json.Marshal(role) if err != nil { return err } return bucket.Put([]byte(role.Name), buf) } return nil } func (p *BoltProvider) addRuleToActionMapping(ruleName, actionName string, bucket *bolt.Bucket) error { a := bucket.Get([]byte(actionName)) if a == nil { return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName)) } var action BaseEventAction err := json.Unmarshal(a, &action) if err != nil { return err } if !slices.Contains(action.Rules, ruleName) { action.Rules = append(action.Rules, ruleName) buf, err := json.Marshal(action) if err != nil { return err } return bucket.Put([]byte(action.Name), buf) } return nil } func (p 
*BoltProvider) removeRuleFromActionMapping(ruleName, actionName string, bucket *bolt.Bucket) error { a := bucket.Get([]byte(actionName)) if a == nil { providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName) return nil } var action BaseEventAction err := json.Unmarshal(a, &action) if err != nil { return err } if slices.Contains(action.Rules, ruleName) { var rules []string for _, r := range action.Rules { if r != ruleName { rules = append(rules, r) } } action.Rules = util.RemoveDuplicates(rules, false) buf, err := json.Marshal(action) if err != nil { return err } return bucket.Put([]byte(action.Name), buf) } return nil } func (p *BoltProvider) addUserToGroupMapping(username, groupname string, bucket *bolt.Bucket) error { g := bucket.Get([]byte(groupname)) if g == nil { return util.NewGenericError(fmt.Sprintf("group %q does not exist", groupname)) } var group Group err := json.Unmarshal(g, &group) if err != nil { return err } if !slices.Contains(group.Users, username) { group.Users = append(group.Users, username) buf, err := json.Marshal(group) if err != nil { return err } return bucket.Put([]byte(group.Name), buf) } return nil } func (p *BoltProvider) removeUserFromGroupMapping(username, groupname string, bucket *bolt.Bucket) error { g := bucket.Get([]byte(groupname)) if g == nil { return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", groupname)) } var group Group err := json.Unmarshal(g, &group) if err != nil { return err } var users []string for _, u := range group.Users { if u != username { users = append(users, u) } } group.Users = util.RemoveDuplicates(users, false) buf, err := json.Marshal(group) if err != nil { return err } return bucket.Put([]byte(group.Name), buf) } func (p *BoltProvider) addAdminToGroupMapping(username, groupname string, bucket *bolt.Bucket) error { g := bucket.Get([]byte(groupname)) if g == nil { return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", 
groupname)) } var group Group err := json.Unmarshal(g, &group) if err != nil { return err } if !slices.Contains(group.Admins, username) { group.Admins = append(group.Admins, username) buf, err := json.Marshal(group) if err != nil { return err } return bucket.Put([]byte(group.Name), buf) } return nil } func (p *BoltProvider) removeAdminFromGroupMapping(username, groupname string, bucket *bolt.Bucket) error { g := bucket.Get([]byte(groupname)) if g == nil { return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", groupname)) } var group Group err := json.Unmarshal(g, &group) if err != nil { return err } var admins []string for _, a := range group.Admins { if a != username { admins = append(admins, a) } } group.Admins = util.RemoveDuplicates(admins, false) buf, err := json.Marshal(group) if err != nil { return err } return bucket.Put([]byte(group.Name), buf) } func (p *BoltProvider) removeGroupFromAdminMapping(groupName, adminName string, bucket *bolt.Bucket) error { var a []byte if a = bucket.Get([]byte(adminName)); a == nil { // the admin does not exist so there is no associated group return nil } var admin Admin err := json.Unmarshal(a, &admin) if err != nil { return err } var newGroups []AdminGroupMapping for _, g := range admin.Groups { if g.Name != groupName { newGroups = append(newGroups, g) } } admin.Groups = newGroups buf, err := json.Marshal(admin) if err != nil { return err } return bucket.Put([]byte(adminName), buf) } func (p *BoltProvider) addRelationToFolderMapping(folderName string, user *User, group *Group, bucket *bolt.Bucket) error { f := bucket.Get([]byte(folderName)) if f == nil { return util.NewGenericError(fmt.Sprintf("folder %q does not exist", folderName)) } var folder vfs.BaseVirtualFolder err := json.Unmarshal(f, &folder) if err != nil { return err } updated := false if user != nil && !slices.Contains(folder.Users, user.Username) { folder.Users = append(folder.Users, user.Username) updated = true } if group != nil && 
!slices.Contains(folder.Groups, group.Name) { folder.Groups = append(folder.Groups, group.Name) updated = true } if !updated { return nil } buf, err := json.Marshal(folder) if err != nil { return err } return bucket.Put([]byte(folder.Name), buf) } func (p *BoltProvider) removeRelationFromFolderMapping(folder vfs.VirtualFolder, username, groupname string, bucket *bolt.Bucket, ) error { var f []byte if f = bucket.Get([]byte(folder.Name)); f == nil { // the folder does not exist so there is no associated user/group return nil } var baseFolder vfs.BaseVirtualFolder err := json.Unmarshal(f, &baseFolder) if err != nil { return err } found := false if username != "" { found = true var newUserMapping []string for _, u := range baseFolder.Users { if u != username { newUserMapping = append(newUserMapping, u) } } baseFolder.Users = newUserMapping } if groupname != "" { found = true var newGroupMapping []string for _, g := range baseFolder.Groups { if g != groupname { newGroupMapping = append(newGroupMapping, g) } } baseFolder.Groups = newGroupMapping } if !found { return nil } buf, err := json.Marshal(baseFolder) if err != nil { return err } return bucket.Put([]byte(folder.Name), buf) } func (p *BoltProvider) updateUserRelations(tx *bolt.Tx, user *User, oldUser User) error { foldersBucket, err := p.getFoldersBucket(tx) if err != nil { return err } groupsBucket, err := p.getGroupsBucket(tx) if err != nil { return err } rolesBucket, err := p.getRolesBucket(tx) if err != nil { return err } for idx := range oldUser.VirtualFolders { err = p.removeRelationFromFolderMapping(oldUser.VirtualFolders[idx], oldUser.Username, "", foldersBucket) if err != nil { return err } } for idx := range oldUser.Groups { err = p.removeUserFromGroupMapping(user.Username, oldUser.Groups[idx].Name, groupsBucket) if err != nil { return err } } if err = p.removeUserFromRole(oldUser.Username, oldUser.Role, rolesBucket); err != nil { return err } for idx := range user.VirtualFolders { err = 
p.addRelationToFolderMapping(user.VirtualFolders[idx].Name, user, nil, foldersBucket) if err != nil { return err } } for idx := range user.Groups { err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name, groupsBucket) if err != nil { return err } } return p.addUserToRole(user.Username, user.Role, rolesBucket) } func (p *BoltProvider) adminExistsInternal(tx *bolt.Tx, username string) error { bucket, err := p.getAdminsBucket(tx) if err != nil { return err } a := bucket.Get([]byte(username)) if a == nil { return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", username)) } return nil } func (p *BoltProvider) userExistsInternal(tx *bolt.Tx, username string) error { bucket, err := p.getUsersBucket(tx) if err != nil { return err } u := bucket.Get([]byte(username)) if u == nil { return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) } return nil } func (p *BoltProvider) deleteRelatedShares(tx *bolt.Tx, username string) error { bucket, err := p.getSharesBucket(tx) if err != nil { return err } var toRemove []string cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var share Share err = json.Unmarshal(v, &share) if err != nil { return err } if share.Username == username { toRemove = append(toRemove, share.ShareID) } } for _, k := range toRemove { if err := bucket.Delete([]byte(k)); err != nil { return err } } return nil } func (p *BoltProvider) deleteRelatedAPIKey(tx *bolt.Tx, username string, scope APIKeyScope) error { bucket, err := p.getAPIKeysBucket(tx) if err != nil { return err } var toRemove []string cursor := bucket.Cursor() for k, v := cursor.First(); k != nil; k, v = cursor.Next() { var apiKey APIKey err = json.Unmarshal(v, &apiKey) if err != nil { return err } if scope == APIKeyScopeUser { if apiKey.User == username { toRemove = append(toRemove, apiKey.KeyID) } } else { if apiKey.Admin == username { toRemove = append(toRemove, apiKey.KeyID) } } } for _, k := range 
toRemove { if err := bucket.Delete([]byte(k)); err != nil { return err } } return nil } func (p *BoltProvider) getSharesBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(sharesBucket) if bucket == nil { err = errors.New("unable to find shares bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getAPIKeysBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(apiKeysBucket) if bucket == nil { err = errors.New("unable to find api keys bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getAdminsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(adminsBucket) if bucket == nil { err = errors.New("unable to find admins bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getUsersBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(usersBucket) if bucket == nil { err = errors.New("unable to find users bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getGroupsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(groupsBucket) if bucket == nil { err = fmt.Errorf("unable to find groups bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getRolesBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(rolesBucket) if bucket == nil { err = fmt.Errorf("unable to find roles bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getIPListsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(rolesBucket) if bucket == nil { err = fmt.Errorf("unable to find IP lists bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getFoldersBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err 
error bucket := tx.Bucket(foldersBucket) if bucket == nil { err = fmt.Errorf("unable to find folders bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getActionsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(actionsBucket) if bucket == nil { err = fmt.Errorf("unable to find event actions bucket, bolt database structure not correcly defined") } return bucket, err } func (p *BoltProvider) getRulesBucket(tx *bolt.Tx) (*bolt.Bucket, error) { var err error bucket := tx.Bucket(rulesBucket) if bucket == nil { err = fmt.Errorf("unable to find event rules bucket, bolt database structure not correcly defined") } return bucket, err } func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) { var dbVersion schemaVersion err := dbHandle.View(func(tx *bolt.Tx) error { bucket := tx.Bucket(dbVersionBucket) if bucket == nil { return fmt.Errorf("unable to find database schema version bucket") } v := bucket.Get(dbVersionKey) if v == nil { dbVersion = schemaVersion{ Version: 33, } return nil } return json.Unmarshal(v, &dbVersion) }) return dbVersion, err } func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error { err := dbHandle.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket(dbVersionBucket) if bucket == nil { return fmt.Errorf("unable to find database schema version bucket") } newDbVersion := schemaVersion{ Version: version, } buf, err := json.Marshal(newDbVersion) if err != nil { return err } return bucket.Put(dbVersionKey, buf) }) return err } ================================================ FILE: internal/dataprovider/bolt_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build nobolt package dataprovider import ( "errors" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-bolt") } func initializeBoltProvider(_ string) error { return errors.New("bolt disabled at build time") } ================================================ FILE: internal/dataprovider/cachedpassword.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider

import (
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// In-memory caches of already verified credentials, one per credential kind.
var (
	cachedUserPasswords  credentialsCache
	cachedAdminPasswords credentialsCache
	cachedAPIKeys        credentialsCache
)

func init() {
	cachedUserPasswords = credentialsCache{
		name:      "users",
		sizeLimit: 500,
		cache:     make(map[string]credentialObject),
	}
	cachedAdminPasswords = credentialsCache{
		name:      "admins",
		sizeLimit: 100,
		cache:     make(map[string]credentialObject),
	}
	cachedAPIKeys = credentialsCache{
		name:      "API keys",
		sizeLimit: 500,
		cache:     make(map[string]credentialObject),
	}
}

// CheckCachedUserPassword is an utility method used only in test cases
func CheckCachedUserPassword(username, password, hash string) (bool, bool) {
	return cachedUserPasswords.Check(username, password, hash)
}

// credentialObject is a cached plain-text secret together with the hash it
// was verified against.
// usedAt holds the last use time in milliseconds since the epoch. It is a
// pointer so that Check can update it through a map value copy while holding
// only the read lock.
type credentialObject struct {
	key      string
	hash     string
	password string
	usedAt   *atomic.Int64
}

// credentialsCache maps a key (for example a username) to its cached
// credentials. The map is guarded by the embedded RWMutex, and sizeLimit is
// the threshold that triggers eviction in cleanup.
type credentialsCache struct {
	name      string
	sizeLimit int
	sync.RWMutex
	cache map[string]credentialObject
}

// Add caches password for username, paired with hash.
// It is a no-op when password caching is disabled or any argument is empty.
func (c *credentialsCache) Add(username, password, hash string) {
	if !config.PasswordCaching || username == "" || password == "" || hash == "" {
		return
	}

	c.Lock()
	defer c.Unlock()

	obj := credentialObject{
		key:      username,
		hash:     hash,
		password: password,
		usedAt:   &atomic.Int64{},
	}
	obj.usedAt.Store(util.GetTimeAsMsSinceEpoch(time.Now()))

	c.cache[username] = obj
}

// Remove drops the cached credentials for username, if any.
// It is a no-op when password caching is disabled.
func (c *credentialsCache) Remove(username string) {
	if !config.PasswordCaching {
		return
	}

	c.Lock()
	defer c.Unlock()

	delete(c.cache, username)
}

// Check reports whether username is cached (first result) and whether
// password matches the cached one (second result).
// If the cached hash differs from hash, the stored credentials are stale: the
// entry is reported as not found and its usedAt is reset to 0, so the next
// cleanup evicts it.
func (c *credentialsCache) Check(username, password, hash string) (bool, bool) {
	if username == "" || password == "" || hash == "" {
		return false, false
	}

	c.RLock()
	defer c.RUnlock()

	creds, ok := c.cache[username]
	if !ok {
		return false, false
	}
	if creds.hash != hash {
		creds.usedAt.Store(0)
		return false, false
	}
	match := creds.password == password
	if match {
		creds.usedAt.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
	}
	return true, match
}

// count returns the number of cached entries.
func (c *credentialsCache) count() int {
	c.RLock()
	defer c.RUnlock()

	return len(c.cache)
}

// cleanup evicts entries once the cache grows beyond sizeLimit.
// It proceeds in two stages:
//  1. It drops entries not used in the last 60 minutes, including entries
//     invalidated by Check, whose usedAt is 0.
//  2. If the cache still holds at least 5*sizeLimit entries, it removes the
//     least recently used entries until only sizeLimit remain.
func (c *credentialsCache) cleanup() {
	if !config.PasswordCaching {
		return
	}
	if c.count() <= c.sizeLimit {
		return
	}

	c.Lock()
	defer c.Unlock()

	// Compute the threshold once. Recomputing it per entry was wasted work and
	// let the cutoff drift while iterating.
	expiredBefore := util.GetTimeAsMsSinceEpoch(time.Now().Add(-60 * time.Minute))
	for k, v := range c.cache {
		if v.usedAt.Load() < expiredBefore {
			delete(c.cache, k)
		}
	}
	providerLog(logger.LevelDebug, "size for credentials %q after cleanup: %d", c.name, len(c.cache))

	if len(c.cache) < c.sizeLimit*5 {
		return
	}
	numToRemove := len(c.cache) - c.sizeLimit
	providerLog(logger.LevelDebug, "additional item to remove from credentials %q: %d", c.name, numToRemove)

	credentials := make([]credentialObject, 0, len(c.cache))
	for _, v := range c.cache {
		credentials = append(credentials, v)
	}
	// oldest first
	sort.Slice(credentials, func(i, j int) bool {
		return credentials[i].usedAt.Load() < credentials[j].usedAt.Load()
	})

	for idx := range credentials {
		if idx >= numToRemove {
			break
		}
		delete(c.cache, credentials[idx].key)
	}
	providerLog(logger.LevelDebug, "size for credentials %q after additional cleanup: %d", c.name, len(c.cache))
}
================================================ FILE: internal/dataprovider/cacheduser.go ================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package dataprovider

import (
	"sync"
	"time"

	"github.com/drakkan/webdav"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

var (
	webDAVUsersCache *usersCache
)

func init() {
	webDAVUsersCache = &usersCache{
		users: map[string]CachedUser{},
	}
}

// InitializeWebDAVUserCache initializes the cache for webdav users
func InitializeWebDAVUserCache(maxSize int) {
	webDAVUsersCache = &usersCache{
		users:   map[string]CachedUser{},
		maxSize: maxSize,
	}
}

// CachedUser adds fields useful for caching to a SFTPGo user
type CachedUser struct {
	User       User
	Expiration time.Time
	Password   string
	LockSystem webdav.LockSystem
}

// IsExpired returns true if the cached user is expired.
// A zero Expiration means the entry never expires.
func (c *CachedUser) IsExpired() bool {
	if c.Expiration.IsZero() {
		return false
	}
	return c.Expiration.Before(time.Now())
}

// usersCache is a map of CachedUser keyed by username, guarded by the
// embedded RWMutex. A positive maxSize bounds the number of entries.
type usersCache struct {
	sync.RWMutex
	users   map[string]CachedUser
	maxSize int
}

// updateLastLogin sets the cached user's LastLogin to the current time, if
// the user is cached.
func (cache *usersCache) updateLastLogin(username string) {
	cache.Lock()
	defer cache.Unlock()

	if cachedUser, ok := cache.users[username]; ok {
		cachedUser.User.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now())
		cache.users[username] = cachedUser
	}
}

// swap updates an existing cached user with the specified one, preserving
// the lock filesystem when possible.
// The cached entry is removed instead when any of these hold:
//   - group settings cannot be loaded
//   - no plain password is given and the stored password changed
//   - the filesystem configuration changed
//
// Users that are not cached are ignored.
// FIXME: this could be racy in rare cases, because group settings are loaded
// before the lock is acquired.
func (cache *usersCache) swap(userRef *User, plainPassword string) {
	user := userRef.getACopy()
	err := user.LoadAndApplyGroupSettings()

	cache.Lock()
	defer cache.Unlock()

	if cachedUser, ok := cache.users[user.Username]; ok {
		if err != nil {
			providerLog(logger.LevelDebug, "unable to load group settings, for user %q, removing from cache, err :%v",
				user.Username, err)
			delete(cache.users, user.Username)
			return
		}
		if plainPassword != "" {
			cachedUser.Password = plainPassword
		} else {
			if cachedUser.User.Password != user.Password {
				providerLog(logger.LevelDebug, "current password different from the cached one for user %q, removing from cache",
					user.Username)
				// the password changed, the cached user is no longer valid
				delete(cache.users, user.Username)
				return
			}
		}
		if cachedUser.User.isFsEqual(&user) {
			// the updated user has the same fs as the cached one, we can preserve the lock filesystem
			providerLog(logger.LevelDebug, "current password and fs unchanged for user %q, swap cached one",
				user.Username)
			cachedUser.User = user
			cache.users[user.Username] = cachedUser
		} else {
			// filesystem changed, the cached user is no longer valid
			providerLog(logger.LevelDebug, "current fs different from the cached one for user %q, removing from cache",
				user.Username)
			delete(cache.users, user.Username)
		}
	}
}

// add stores a copy of cachedUser, keyed by its username.
// Entries with an empty username are ignored. When maxSize is positive and
// storing a new username would exceed it, one existing entry is evicted first
// (see evictOneLocked). Replacing an already cached username never evicts.
func (cache *usersCache) add(cachedUser *CachedUser) {
	username := cachedUser.User.Username
	if username == "" {
		return
	}

	cache.Lock()
	defer cache.Unlock()

	// overwriting an existing key does not grow the cache
	if _, exists := cache.users[username]; !exists && cache.maxSize > 0 && len(cache.users) >= cache.maxSize {
		cache.evictOneLocked()
	}
	cache.users[username] = *cachedUser
}

// evictOneLocked removes one entry to make room for a new one.
// It prefers the entry with the earliest non-zero Expiration. Entries
// without an expiration (zero time, meaning they never expire) are evicted
// only when no entry has an expiration.
// The caller must hold the write lock.
func (cache *usersCache) evictOneLocked() {
	var (
		victim           string
		victimExpiration time.Time
		found            bool
	)
	for username, cached := range cache.users {
		expiration := cached.Expiration
		switch {
		case !found:
			victim, victimExpiration, found = username, expiration, true
		case expiration.IsZero():
			// never-expiring entries are the least preferred candidates
		case victimExpiration.IsZero() || expiration.Before(victimExpiration):
			victim, victimExpiration = username, expiration
		}
	}
	if found {
		delete(cache.users, victim)
	}
}

// remove drops the cached entry for username, if any.
func (cache *usersCache) remove(username string) {
	cache.Lock()
	defer cache.Unlock()

	delete(cache.users, username)
}

// get returns a copy of the cached entry for username and whether it was
// found.
func (cache *usersCache) get(username string) (*CachedUser, bool) {
	cache.RLock()
	defer cache.RUnlock()

	cachedUser, ok := cache.users[username]
	if !ok {
		return nil, false
	}
	return &cachedUser, true
}

// CacheWebDAVUser add a user to the WebDAV cache
func CacheWebDAVUser(cachedUser *CachedUser) {
	webDAVUsersCache.add(cachedUser)
}

// GetCachedWebDAVUser returns a previously cached WebDAV user
func GetCachedWebDAVUser(username string) (*CachedUser, bool) {
	return webDAVUsersCache.get(username)
}

// RemoveCachedWebDAVUser removes a cached WebDAV user
func RemoveCachedWebDAVUser(username string) {
	webDAVUsersCache.remove(username)
}
================================================ FILE: internal/dataprovider/configs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "bytes" "encoding/json" "fmt" "image/png" "net/url" "slices" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // Supported values for host keys, KEXs, ciphers, MACs var ( supportedHostKeyAlgos = []string{ssh.KeyAlgoRSA} supportedPublicKeyAlgos = []string{ssh.KeyAlgoRSA, ssh.InsecureKeyAlgoDSA} //nolint:staticcheck supportedKexAlgos = []string{ ssh.KeyExchangeDH16SHA512, ssh.InsecureKeyExchangeDH14SHA1, ssh.InsecureKeyExchangeDH1SHA1, ssh.InsecureKeyExchangeDHGEXSHA1, } supportedCiphers = []string{ ssh.InsecureCipherAES128CBC, ssh.InsecureCipherTripleDESCBC, } supportedMACs = []string{ ssh.HMACSHA512ETM, ssh.HMACSHA512, ssh.HMACSHA1, ssh.InsecureHMACSHA196, } ) // SFTPDConfigs defines configurations for SFTPD type SFTPDConfigs struct { HostKeyAlgos []string `json:"host_key_algos,omitempty"` PublicKeyAlgos []string `json:"public_key_algos,omitempty"` KexAlgorithms []string `json:"kex_algorithms,omitempty"` Ciphers []string `json:"ciphers,omitempty"` MACs []string `json:"macs,omitempty"` } func (c *SFTPDConfigs) isEmpty() bool { if len(c.HostKeyAlgos) > 0 { return false } if 
len(c.PublicKeyAlgos) > 0 { return false } if len(c.KexAlgorithms) > 0 { return false } if len(c.Ciphers) > 0 { return false } if len(c.MACs) > 0 { return false } return true } // GetSupportedHostKeyAlgos returns the supported legacy host key algos func (*SFTPDConfigs) GetSupportedHostKeyAlgos() []string { return supportedHostKeyAlgos } // GetSupportedPublicKeyAlgos returns the supported legacy public key algos func (*SFTPDConfigs) GetSupportedPublicKeyAlgos() []string { return supportedPublicKeyAlgos } // GetSupportedKEXAlgos returns the supported KEX algos func (*SFTPDConfigs) GetSupportedKEXAlgos() []string { return supportedKexAlgos } // GetSupportedCiphers returns the supported ciphers func (*SFTPDConfigs) GetSupportedCiphers() []string { return supportedCiphers } // GetSupportedMACs returns the supported MACs algos func (*SFTPDConfigs) GetSupportedMACs() []string { return supportedMACs } func (c *SFTPDConfigs) validate() error { var hostKeyAlgos []string for _, algo := range c.HostKeyAlgos { if algo == ssh.CertAlgoRSAv01 { continue } if !slices.Contains(supportedHostKeyAlgos, algo) { return util.NewValidationError(fmt.Sprintf("unsupported host key algorithm %q", algo)) } hostKeyAlgos = append(hostKeyAlgos, algo) } c.HostKeyAlgos = hostKeyAlgos var kexAlgos []string for _, algo := range c.KexAlgorithms { if algo == "diffie-hellman-group18-sha512" || algo == ssh.KeyExchangeDHGEXSHA256 { continue } if !slices.Contains(supportedKexAlgos, algo) { return util.NewValidationError(fmt.Sprintf("unsupported KEX algorithm %q", algo)) } kexAlgos = append(kexAlgos, algo) } c.KexAlgorithms = kexAlgos for _, cipher := range c.Ciphers { if slices.Contains([]string{"aes192-cbc", "aes256-cbc"}, cipher) { continue } if !slices.Contains(supportedCiphers, cipher) { return util.NewValidationError(fmt.Sprintf("unsupported cipher %q", cipher)) } } for _, mac := range c.MACs { if !slices.Contains(supportedMACs, mac) { return util.NewValidationError(fmt.Sprintf("unsupported MAC 
algorithm %q", mac)) } } for _, algo := range c.PublicKeyAlgos { if !slices.Contains(supportedPublicKeyAlgos, algo) { return util.NewValidationError(fmt.Sprintf("unsupported public key algorithm %q", algo)) } } return nil } func (c *SFTPDConfigs) getACopy() *SFTPDConfigs { hostKeys := make([]string, len(c.HostKeyAlgos)) copy(hostKeys, c.HostKeyAlgos) publicKeys := make([]string, len(c.PublicKeyAlgos)) copy(publicKeys, c.PublicKeyAlgos) kexs := make([]string, len(c.KexAlgorithms)) copy(kexs, c.KexAlgorithms) ciphers := make([]string, len(c.Ciphers)) copy(ciphers, c.Ciphers) macs := make([]string, len(c.MACs)) copy(macs, c.MACs) return &SFTPDConfigs{ HostKeyAlgos: hostKeys, PublicKeyAlgos: publicKeys, KexAlgorithms: kexs, Ciphers: ciphers, MACs: macs, } } func validateSMTPSecret(secret *kms.Secret, name string) error { if secret.IsRedacted() { return util.NewValidationError(fmt.Sprintf("cannot save a redacted smtp %s", name)) } if secret.IsEncrypted() && !secret.IsValid() { return util.NewValidationError(fmt.Sprintf("invalid encrypted smtp %s", name)) } if !secret.IsEmpty() && !secret.IsValidInput() { return util.NewValidationError(fmt.Sprintf("invalid smtp %s", name)) } if secret.IsPlain() { secret.SetAdditionalData("smtp") if err := secret.Encrypt(); err != nil { return util.NewValidationError(fmt.Sprintf("could not encrypt smtp %s: %v", name, err)) } } return nil } // SMTPOAuth2 defines the SMTP related OAuth2 configurations type SMTPOAuth2 struct { Provider int `json:"provider,omitempty"` Tenant string `json:"tenant,omitempty"` ClientID string `json:"client_id,omitempty"` ClientSecret *kms.Secret `json:"client_secret,omitempty"` RefreshToken *kms.Secret `json:"refresh_token,omitempty"` } func (c *SMTPOAuth2) validate() error { if c.Provider < 0 || c.Provider > 1 { return util.NewValidationError("smtp oauth2: unsupported provider") } if c.ClientID == "" { return util.NewI18nError( util.NewValidationError("smtp oauth2: client id is required"), 
util.I18nErrorClientIDRequired, ) } if c.ClientSecret == nil || c.ClientSecret.IsEmpty() { return util.NewI18nError( util.NewValidationError("smtp oauth2: client secret is required"), util.I18nErrorClientSecretRequired, ) } if c.RefreshToken == nil || c.RefreshToken.IsEmpty() { return util.NewI18nError( util.NewValidationError("smtp oauth2: refresh token is required"), util.I18nErrorRefreshTokenRequired, ) } if err := validateSMTPSecret(c.ClientSecret, "oauth2 client secret"); err != nil { return err } return validateSMTPSecret(c.RefreshToken, "oauth2 refresh token") } func (c *SMTPOAuth2) getACopy() SMTPOAuth2 { var clientSecret, refreshToken *kms.Secret if c.ClientSecret != nil { clientSecret = c.ClientSecret.Clone() } if c.RefreshToken != nil { refreshToken = c.RefreshToken.Clone() } return SMTPOAuth2{ Provider: c.Provider, Tenant: c.Tenant, ClientID: c.ClientID, ClientSecret: clientSecret, RefreshToken: refreshToken, } } // SMTPConfigs defines configuration for SMTP type SMTPConfigs struct { Host string `json:"host,omitempty"` Port int `json:"port,omitempty"` From string `json:"from,omitempty"` User string `json:"user,omitempty"` Password *kms.Secret `json:"password,omitempty"` AuthType int `json:"auth_type,omitempty"` Encryption int `json:"encryption,omitempty"` Domain string `json:"domain,omitempty"` Debug int `json:"debug,omitempty"` OAuth2 SMTPOAuth2 `json:"oauth2"` } // IsEmpty returns true if the configuration is empty func (c *SMTPConfigs) IsEmpty() bool { return c.Host == "" } func (c *SMTPConfigs) validate() error { if c.IsEmpty() { return nil } if c.Port <= 0 || c.Port > 65535 { return util.NewValidationError(fmt.Sprintf("smtp: invalid port %d", c.Port)) } if c.Password != nil && c.AuthType != 3 { if err := validateSMTPSecret(c.Password, "password"); err != nil { return err } } if c.User == "" && c.From == "" { return util.NewI18nError( util.NewValidationError("smtp: from address and user cannot both be empty"), util.I18nErrorSMTPRequiredFields, ) } 
if c.AuthType < 0 || c.AuthType > 3 { return util.NewValidationError(fmt.Sprintf("smtp: invalid auth type %d", c.AuthType)) } if c.Encryption < 0 || c.Encryption > 2 { return util.NewValidationError(fmt.Sprintf("smtp: invalid encryption %d", c.Encryption)) } if c.AuthType == 3 { c.Password = kms.NewEmptySecret() return c.OAuth2.validate() } c.OAuth2 = SMTPOAuth2{} return nil } // TryDecrypt tries to decrypt the encrypted secrets func (c *SMTPConfigs) TryDecrypt() error { if c.Password == nil { c.Password = kms.NewEmptySecret() } if c.OAuth2.ClientSecret == nil { c.OAuth2.ClientSecret = kms.NewEmptySecret() } if c.OAuth2.RefreshToken == nil { c.OAuth2.RefreshToken = kms.NewEmptySecret() } if err := c.Password.TryDecrypt(); err != nil { return fmt.Errorf("unable to decrypt smtp password: %w", err) } if err := c.OAuth2.ClientSecret.TryDecrypt(); err != nil { return fmt.Errorf("unable to decrypt smtp oauth2 client secret: %w", err) } if err := c.OAuth2.RefreshToken.TryDecrypt(); err != nil { return fmt.Errorf("unable to decrypt smtp oauth2 refresh token: %w", err) } return nil } func (c *SMTPConfigs) prepareForRendering() { if c.Password != nil { c.Password.Hide() if c.Password.IsEmpty() { c.Password = nil } } if c.OAuth2.ClientSecret != nil { c.OAuth2.ClientSecret.Hide() if c.OAuth2.ClientSecret.IsEmpty() { c.OAuth2.ClientSecret = nil } } if c.OAuth2.RefreshToken != nil { c.OAuth2.RefreshToken.Hide() if c.OAuth2.RefreshToken.IsEmpty() { c.OAuth2.RefreshToken = nil } } } func (c *SMTPConfigs) getACopy() *SMTPConfigs { var password *kms.Secret if c.Password != nil { password = c.Password.Clone() } return &SMTPConfigs{ Host: c.Host, Port: c.Port, From: c.From, User: c.User, Password: password, AuthType: c.AuthType, Encryption: c.Encryption, Domain: c.Domain, Debug: c.Debug, OAuth2: c.OAuth2.getACopy(), } } // ACMEHTTP01Challenge defines the configuration for HTTP-01 challenge type type ACMEHTTP01Challenge struct { Port int `json:"port"` } // ACMEConfigs defines ACME 
related configuration type ACMEConfigs struct { Domain string `json:"domain"` Email string `json:"email"` HTTP01Challenge ACMEHTTP01Challenge `json:"http01_challenge"` // apply the certificate for the specified protocols: // // 1 means HTTP // 2 means FTP // 4 means WebDAV // // Protocols can be combined Protocols int `json:"protocols"` } func (c *ACMEConfigs) isEmpty() bool { return c.Domain == "" } func (c *ACMEConfigs) validate() error { if c.Domain == "" { return nil } if c.Email == "" && !util.IsEmailValid(c.Email) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("acme: invalid email %q", c.Email)), util.I18nErrorInvalidEmail, ) } if c.HTTP01Challenge.Port <= 0 || c.HTTP01Challenge.Port > 65535 { return util.NewValidationError(fmt.Sprintf("acme: invalid HTTP-01 challenge port %d", c.HTTP01Challenge.Port)) } return nil } // HasProtocol returns true if the ACME certificate must be used for the specified protocol func (c *ACMEConfigs) HasProtocol(protocol string) bool { switch protocol { case protocolHTTP: return c.Protocols&1 != 0 case protocolFTP: return c.Protocols&2 != 0 case protocolWebDAV: return c.Protocols&4 != 0 default: return false } } func (c *ACMEConfigs) getACopy() *ACMEConfigs { return &ACMEConfigs{ Email: c.Email, Domain: c.Domain, HTTP01Challenge: ACMEHTTP01Challenge{Port: c.HTTP01Challenge.Port}, Protocols: c.Protocols, } } // BrandingConfig defines the branding configuration type BrandingConfig struct { Name string `json:"name"` ShortName string `json:"short_name"` Logo []byte `json:"logo"` Favicon []byte `json:"favicon"` DisclaimerName string `json:"disclaimer_name"` DisclaimerURL string `json:"disclaimer_url"` } func (c *BrandingConfig) isEmpty() bool { if c.Name != "" { return false } if c.ShortName != "" { return false } if len(c.Logo) > 0 { return false } if len(c.Favicon) > 0 { return false } if c.DisclaimerName != "" && c.DisclaimerURL != "" { return false } return true } func (*BrandingConfig) validatePNG(b []byte, 
maxWidth, maxHeight int) error { if len(b) == 0 { return nil } // DecodeConfig is more efficient, but I'm not sure if this would lead to // accepting invalid images in some edge cases and performance does not // matter here. img, err := png.Decode(bytes.NewBuffer(b)) if err != nil { return util.NewI18nError( util.NewValidationError("invalid PNG image"), util.I18nErrorInvalidPNG, ) } bounds := img.Bounds() if bounds.Dx() > maxWidth || bounds.Dy() > maxHeight { return util.NewI18nError( util.NewValidationError("invalid PNG image size"), util.I18nErrorInvalidPNGSize, ) } return nil } func (c *BrandingConfig) validateDisclaimerURL() error { if c.DisclaimerURL == "" { return nil } u, err := url.Parse(c.DisclaimerURL) if err != nil { return util.NewI18nError( util.NewValidationError("invalid disclaimer URL"), util.I18nErrorInvalidDisclaimerURL, ) } if u.Scheme != "http" && u.Scheme != "https" { return util.NewI18nError( util.NewValidationError("invalid disclaimer URL scheme"), util.I18nErrorInvalidDisclaimerURL, ) } return nil } func (c *BrandingConfig) validate() error { if err := c.validateDisclaimerURL(); err != nil { return err } if err := c.validatePNG(c.Logo, 512, 512); err != nil { return err } return c.validatePNG(c.Favicon, 256, 256) } func (c *BrandingConfig) getACopy() BrandingConfig { logo := make([]byte, len(c.Logo)) copy(logo, c.Logo) favicon := make([]byte, len(c.Favicon)) copy(favicon, c.Favicon) return BrandingConfig{ Name: c.Name, ShortName: c.ShortName, Logo: logo, Favicon: favicon, DisclaimerName: c.DisclaimerName, DisclaimerURL: c.DisclaimerURL, } } // BrandingConfigs defines the branding configuration for WebAdmin and WebClient UI type BrandingConfigs struct { WebAdmin BrandingConfig WebClient BrandingConfig } func (c *BrandingConfigs) isEmpty() bool { return c.WebAdmin.isEmpty() && c.WebClient.isEmpty() } func (c *BrandingConfigs) validate() error { if err := c.WebAdmin.validate(); err != nil { return err } return c.WebClient.validate() } func (c 
*BrandingConfigs) getACopy() *BrandingConfigs { return &BrandingConfigs{ WebAdmin: c.WebAdmin.getACopy(), WebClient: c.WebClient.getACopy(), } } // Configs allows to set configuration keys disabled by default without // modifying the config file or setting env vars type Configs struct { SFTPD *SFTPDConfigs `json:"sftpd,omitempty"` SMTP *SMTPConfigs `json:"smtp,omitempty"` ACME *ACMEConfigs `json:"acme,omitempty"` Branding *BrandingConfigs `json:"branding,omitempty"` UpdatedAt int64 `json:"updated_at,omitempty"` } func (c *Configs) validate() error { if c.SFTPD != nil { if err := c.SFTPD.validate(); err != nil { return err } } if c.SMTP != nil { if err := c.SMTP.validate(); err != nil { return err } } if c.ACME != nil { if err := c.ACME.validate(); err != nil { return err } } if c.Branding != nil { if err := c.Branding.validate(); err != nil { return err } } return nil } // PrepareForRendering prepares configs for rendering. // It hides confidential data and set to nil the empty structs/secrets // so they are not serialized func (c *Configs) PrepareForRendering() { if c.SFTPD != nil && c.SFTPD.isEmpty() { c.SFTPD = nil } if c.SMTP != nil && c.SMTP.IsEmpty() { c.SMTP = nil } if c.ACME != nil && c.ACME.isEmpty() { c.ACME = nil } if c.Branding != nil && c.Branding.isEmpty() { c.Branding = nil } if c.SMTP != nil { c.SMTP.prepareForRendering() } } // SetNilsToEmpty sets nil fields to empty func (c *Configs) SetNilsToEmpty() { if c.SFTPD == nil { c.SFTPD = &SFTPDConfigs{} } if c.SMTP == nil { c.SMTP = &SMTPConfigs{} } if c.SMTP.Password == nil { c.SMTP.Password = kms.NewEmptySecret() } if c.SMTP.OAuth2.ClientSecret == nil { c.SMTP.OAuth2.ClientSecret = kms.NewEmptySecret() } if c.SMTP.OAuth2.RefreshToken == nil { c.SMTP.OAuth2.RefreshToken = kms.NewEmptySecret() } if c.ACME == nil { c.ACME = &ACMEConfigs{} } if c.Branding == nil { c.Branding = &BrandingConfigs{} } } // RenderAsJSON implements the renderer interface used within plugins func (c *Configs) RenderAsJSON(reload 
bool) ([]byte, error) { if reload { config, err := provider.getConfigs() if err != nil { providerLog(logger.LevelError, "unable to reload config overrides before rendering as json: %v", err) return nil, err } config.PrepareForRendering() return json.Marshal(config) } c.PrepareForRendering() return json.Marshal(c) } func (c *Configs) getACopy() Configs { var result Configs if c.SFTPD != nil { result.SFTPD = c.SFTPD.getACopy() } if c.SMTP != nil { result.SMTP = c.SMTP.getACopy() } if c.ACME != nil { result.ACME = c.ACME.getACopy() } if c.Branding != nil { result.Branding = c.Branding.getACopy() } result.UpdatedAt = c.UpdatedAt return result } ================================================ FILE: internal/dataprovider/dataprovider.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package dataprovider provides data access. // It abstracts different data providers using a common API. 
package dataprovider

import (
	"bufio"
	"bytes"
	"context"
	"crypto/md5"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"crypto/subtle"
	"crypto/x509"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"hash"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"regexp"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/GehirnInc/crypt"
	"github.com/GehirnInc/crypt/apr1_crypt"
	"github.com/GehirnInc/crypt/md5_crypt"
	"github.com/GehirnInc/crypt/sha256_crypt"
	"github.com/GehirnInc/crypt/sha512_crypt"
	"github.com/alexedwards/argon2id"
	"github.com/go-chi/render"
	"github.com/rs/xid"
	"github.com/sftpgo/sdk"
	passwordvalidator "github.com/wagslane/go-password-validator"
	"golang.org/x/crypto/bcrypt"
	"golang.org/x/crypto/pbkdf2"
	"golang.org/x/crypto/ssh"

	"github.com/drakkan/sftpgo/v2/internal/command"
	"github.com/drakkan/sftpgo/v2/internal/httpclient"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/mfa"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	// SQLiteDataProviderName defines the name for SQLite database provider
	SQLiteDataProviderName = "sqlite"
	// PGSQLDataProviderName defines the name for PostgreSQL database provider
	PGSQLDataProviderName = "postgresql"
	// MySQLDataProviderName defines the name for MySQL database provider
	MySQLDataProviderName = "mysql"
	// BoltDataProviderName defines the name for bbolt key/value store provider
	BoltDataProviderName = "bolt"
	// MemoryDataProviderName defines the name for memory provider
	MemoryDataProviderName = "memory"
	// CockroachDataProviderName defines the name for the CockroachDB provider
	CockroachDataProviderName = "cockroachdb"
	// DumpVersion defines the version for the dump.
	// For restore/load we support the current version and the previous one
	DumpVersion = 17

	// Prefixes identifying the hashing scheme of a stored password hash.
	argonPwdPrefix            = "$argon2id$"
	bcryptPwdPrefix           = "$2a$"
	pbkdf2SHA1Prefix          = "$pbkdf2-sha1$"
	pbkdf2SHA256Prefix        = "$pbkdf2-sha256$"
	pbkdf2SHA512Prefix        = "$pbkdf2-sha512$"
	pbkdf2SHA256B64SaltPrefix = "$pbkdf2-b64salt-sha256$"
	md5cryptPwdPrefix         = "$1$"
	md5cryptApr1PwdPrefix     = "$apr1$"
	sha256cryptPwdPrefix      = "$5$"
	sha512cryptPwdPrefix      = "$6$"
	yescryptPwdPrefix         = "$y$"
	md5DigestPwdPrefix        = "{MD5}"
	sha256DigestPwdPrefix     = "{SHA256}"
	sha512DigestPwdPrefix     = "{SHA512}"
	trackQuotaDisabledError   = "please enable track_quota in your configuration to use this method"
	// Operation names, matching the valid values documented for ObjectsActions.ExecuteOn.
	operationAdd        = "add"
	operationUpdate     = "update"
	operationDelete     = "delete"
	sqlPrefixValidChars = "abcdefghijklmnopqrstuvwxyz_0123456789"
	// NOTE(review): presumably the maximum number of bytes read from a hook response — confirm at the read site.
	maxHookResponseSize = 1048576 // 1MB
)

// Supported algorithms for hashing passwords.
// These algorithms can be used when SFTPGo hashes a plain text password
const (
	HashingAlgoBcrypt   = "bcrypt"
	HashingAlgoArgon2ID = "argon2id"
)

// ordering constants
const (
	OrderASC  = "ASC"
	OrderDESC = "DESC"
)

// Protocol identifiers, used to build ValidProtocols and MFAProtocols.
const (
	protocolSSH    = "SSH"
	protocolFTP    = "FTP"
	protocolWebDAV = "DAV"
	protocolHTTP   = "HTTP"
)

// Dump scopes
const (
	DumpScopeUsers   = "users"
	DumpScopeFolders = "folders"
	DumpScopeGroups  = "groups"
	DumpScopeAdmins  = "admins"
	DumpScopeAPIKeys = "api_keys"
	DumpScopeShares  = "shares"
	DumpScopeActions = "actions"
	DumpScopeRules   = "rules"
	DumpScopeRoles   = "roles"
	DumpScopeIPLists = "ip_lists"
	DumpScopeConfigs = "configs"
)

// NOTE(review): presumably identifiers for the kind of field being validated
// (username, generic name, IP/network) — confirm against their usages.
const (
	fieldUsername = 1
	fieldName     = 2
	fieldIPNet    = 3
)

var (
	// SupportedProviders defines the supported data providers
	SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName,
		BoltDataProviderName, MemoryDataProviderName, CockroachDataProviderName}
	// ValidPerms defines all the valid permissions for a user
	ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermCreateDirs, PermRename,
		PermRenameFiles, PermRenameDirs, PermDelete, PermDeleteFiles, PermDeleteDirs, PermCopy, PermCreateSymlinks,
		PermChmod, PermChown, PermChtimes}
	// ValidLoginMethods defines all the valid login methods
	ValidLoginMethods = []string{SSHLoginMethodPublicKey, LoginMethodPassword, SSHLoginMethodPassword,
		SSHLoginMethodKeyboardInteractive, SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt,
		LoginMethodTLSCertificate, LoginMethodTLSCertificateAndPwd}
	// SSHMultiStepsLoginMethods defines the supported Multi-Step Authentications
	SSHMultiStepsLoginMethods = []string{SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt}
	// ErrNoAuthTried defines the error for connection closed before authentication
	ErrNoAuthTried = errors.New("no auth tried")
	// ErrNotImplemented defines the error for features not supported for a particular data provider
	ErrNotImplemented = errors.New("feature not supported with the configured data provider")
	// ValidProtocols defines all the valid protocols
	ValidProtocols = []string{protocolSSH, protocolFTP, protocolWebDAV, protocolHTTP}
	// MFAProtocols defines the supported protocols for multi-factor authentication
	MFAProtocols = []string{protocolHTTP, protocolSSH, protocolFTP}
	// ErrNoInitRequired defines the error returned by InitProvider if no initialization/update is required
	ErrNoInitRequired = errors.New("the data provider is up to date")
	// ErrInvalidCredentials defines the error to return if the supplied credentials are invalid
	ErrInvalidCredentials = errors.New("invalid credentials")
	// ErrLoginNotAllowedFromIP defines the error to return if login is denied from the current IP
	ErrLoginNotAllowedFromIP = errors.New("login is not allowed from this IP")
	// ErrDuplicatedKey occurs when there is a unique key constraint violation
	ErrDuplicatedKey = errors.New("duplicated key not allowed")
	// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
	ErrForeignKeyViolated = errors.New("violates foreign key constraint")
	errInvalidInput       = util.NewValidationError("Invalid input. Slashes (/ ), colons (:), control characters, and reserved system names are not allowed")
	// tz holds the configured timezone, set via SetTZ; "local" selects local time instead of UTC.
	tz                = ""
	isAdminCreated    atomic.Bool
	validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)}
	// config and provider are the package-wide provider configuration and implementation.
	config                  Config
	provider                Provider
	sqlPlaceholders         []string
	internalHashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix}
	hashPwdPrefixes         = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix,
		pbkdf2SHA256B64SaltPrefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, md5DigestPwdPrefix, sha256DigestPwdPrefix,
		sha512DigestPwdPrefix, sha256cryptPwdPrefix, sha512cryptPwdPrefix, yescryptPwdPrefix}
	pbkdfPwdPrefixes        = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix}
	pbkdfPwdB64SaltPrefixes = []string{pbkdf2SHA256B64SaltPrefix}
	unixPwdPrefixes         = []string{md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha256cryptPwdPrefix, sha512cryptPwdPrefix,
		yescryptPwdPrefix}
	digestPwdPrefixes = []string{md5DigestPwdPrefix, sha256DigestPwdPrefix, sha512DigestPwdPrefix}
	// sharedProviders lists the drivers that can be shared across multiple SFTPGo instances.
	sharedProviders = []string{PGSQLDataProviderName, MySQLDataProviderName, CockroachDataProviderName}
	logSender       = "dataprovider"
	// SQL table names, assigned by initSQLTables.
	sqlTableUsers                string
	sqlTableFolders              string
	sqlTableUsersFoldersMapping  string
	sqlTableAdmins               string
	sqlTableAPIKeys              string
	sqlTableShares               string
	sqlTableSharesGroupsMapping  string
	sqlTableDefenderHosts        string
	sqlTableDefenderEvents       string
	sqlTableActiveTransfers      string
	sqlTableGroups               string
	sqlTableUsersGroupsMapping   string
	sqlTableAdminsGroupsMapping  string
	sqlTableGroupsFoldersMapping string
	sqlTableSharedSessions       string
	sqlTableEventsActions        string
	sqlTableEventsRules          string
	sqlTableRulesActionsMapping  string
	sqlTableTasks                string
	sqlTableNodes                string
	sqlTableRoles                string
	sqlTableIPLists              string
	sqlTableConfigs              string
	sqlTableSchemaVersion        string
	argon2Params                 *argon2id.Params
	lastLoginMinDelay            = 10 * time.Minute
	// usernameRegex accepts only unreserved URI characters: ALPHA / DIGIT / "-" / "." / "_" / "~".
	usernameRegex        = regexp.MustCompile("^[a-zA-Z0-9-_.~]+$")
	tempPath             string
	allowSelfConnections int
	// Event rules callbacks, registered via SetEventRulesCallbacks.
	fnReloadRules                FnReloadRules
	fnRemoveRule                 FnRemoveRule
	fnHandleRuleForProviderEvent FnHandleRuleForProviderEvent
)

// initSQLTables assigns the default (unprefixed) name to every SQL table name variable.
// NOTE(review): presumably SQLTablesPrefix is applied to these names elsewhere — verify.
func initSQLTables() {
	sqlTableUsers = "users"
	sqlTableFolders = "folders"
	sqlTableUsersFoldersMapping = "users_folders_mapping"
	sqlTableAdmins = "admins"
	sqlTableAPIKeys = "api_keys"
	sqlTableShares = "shares"
	sqlTableSharesGroupsMapping = "shares_groups_mapping"
	sqlTableDefenderHosts = "defender_hosts"
	sqlTableDefenderEvents = "defender_events"
	sqlTableActiveTransfers = "active_transfers"
	sqlTableGroups = "groups"
	sqlTableUsersGroupsMapping = "users_groups_mapping"
	sqlTableGroupsFoldersMapping = "groups_folders_mapping"
	sqlTableAdminsGroupsMapping = "admins_groups_mapping"
	sqlTableSharedSessions = "shared_sessions"
	sqlTableEventsActions = "events_actions"
	sqlTableEventsRules = "events_rules"
	sqlTableRulesActionsMapping = "rules_actions_mapping"
	sqlTableTasks = "tasks"
	sqlTableNodes = "nodes"
	sqlTableRoles = "roles"
	sqlTableIPLists = "ip_lists"
	sqlTableConfigs = "configurations"
	sqlTableSchemaVersion = "schema_version"
}

// FnReloadRules defines the callback to reload event rules
type FnReloadRules func()

// FnRemoveRule defines the callback to remove an event rule
type FnRemoveRule func(name string)

// FnHandleRuleForProviderEvent defines the callback to handle event rules for provider events
type FnHandleRuleForProviderEvent func(operation, executor, ip, objectType, objectName, role string, object plugin.Renderer)

// SetEventRulesCallbacks sets the event rules callbacks.
// The callbacks are stored in package-level variables and replace any previously set ones.
func SetEventRulesCallbacks(reload FnReloadRules, remove FnRemoveRule, handle FnHandleRuleForProviderEvent) {
	fnReloadRules = reload
	fnRemoveRule = remove
	fnHandleRuleForProviderEvent = handle
}

// schemaVersion holds a database schema version number.
// NOTE(review): presumably scanned from the schema_version table — confirm.
type schemaVersion struct {
	Version int
}

// BcryptOptions defines the options for bcrypt password hashing
type BcryptOptions struct {
	Cost int `json:"cost" mapstructure:"cost"`
}

// Argon2Options defines the options for argon2 password hashing
type Argon2Options struct {
	Memory      uint32 `json:"memory" mapstructure:"memory"`
	Iterations  uint32 `json:"iterations" mapstructure:"iterations"`
	Parallelism uint8  `json:"parallelism" mapstructure:"parallelism"`
}

// PasswordHashing defines the configuration for password hashing
type PasswordHashing struct {
	BcryptOptions BcryptOptions `json:"bcrypt_options" mapstructure:"bcrypt_options"`
	Argon2Options Argon2Options `json:"argon2_options" mapstructure:"argon2_options"`
	// Algorithm to use for hashing passwords. Available algorithms: argon2id, bcrypt. Default: bcrypt
	Algo string `json:"algo" mapstructure:"algo"`
}

// PasswordValidationRules defines the password validation rules
type PasswordValidationRules struct {
	// MinEntropy defines the minimum password entropy.
	// 0 means disabled, any password will be accepted.
	// Take a look at the following link for more details
	// https://github.com/wagslane/go-password-validator#what-entropy-value-should-i-use
	MinEntropy float64 `json:"min_entropy" mapstructure:"min_entropy"`
}

// PasswordValidation defines the password validation rules for admins and protocol users
type PasswordValidation struct {
	// Password validation rules for SFTPGo admin users
	Admins PasswordValidationRules `json:"admins" mapstructure:"admins"`
	// Password validation rules for SFTPGo protocol users
	Users PasswordValidationRules `json:"users" mapstructure:"users"`
}

// wrappedFolder wraps a virtual folder so that it can be rendered as JSON on demand.
type wrappedFolder struct {
	Folder vfs.BaseVirtualFolder
}

// RenderAsJSON returns the JSON representation of the wrapped folder.
// If reload is true the folder is fetched again, by name, from the data provider
// and the fresh copy is rendered; otherwise the wrapped copy is rendered.
// In both cases PrepareForRendering is called before marshaling.
func (w *wrappedFolder) RenderAsJSON(reload bool) ([]byte, error) {
	if reload {
		folder, err := provider.getFolderByName(w.Folder.Name)
		if err != nil {
			providerLog(logger.LevelError, "unable to reload folder before rendering as json: %v", err)
			return nil, err
		}
		folder.PrepareForRendering()
		return json.Marshal(folder)
	}
	w.Folder.PrepareForRendering()
	return json.Marshal(w.Folder)
}

// ObjectsActions defines the action to execute on user create, update, delete for the specified objects
type ObjectsActions struct {
	// Valid values are add, update, delete. Empty slice to disable
	ExecuteOn []string `json:"execute_on" mapstructure:"execute_on"`
	// Valid values are user, admin, api_key
	ExecuteFor []string `json:"execute_for" mapstructure:"execute_for"`
	// Absolute path to an external program or an HTTP URL
	Hook string `json:"hook" mapstructure:"hook"`
}

// ProviderStatus defines the provider status
type ProviderStatus struct {
	Driver   string `json:"driver"`
	IsActive bool   `json:"is_active"`
	Error    string `json:"error"`
}

// Config defines the provider configuration
type Config struct {
	// Driver name, must be one of the SupportedProviders
	Driver string `json:"driver" mapstructure:"driver"`
	// Database name. For driver sqlite this can be the database name relative to the config dir
	// or the absolute path to the SQLite database.
	Name string `json:"name" mapstructure:"name"`
	// Database host. For postgresql and cockroachdb driver you can specify multiple hosts separated by commas
	Host string `json:"host" mapstructure:"host"`
	// Database port
	Port int `json:"port" mapstructure:"port"`
	// Database username
	Username string `json:"username" mapstructure:"username"`
	// Database password
	Password string `json:"password" mapstructure:"password"`
	// Used for drivers mysql and postgresql.
	// 0 disable SSL/TLS connections.
	// 1 require ssl.
	// 2 set ssl mode to verify-ca for driver postgresql and skip-verify for driver mysql.
	// 3 set ssl mode to verify-full for driver postgresql and preferred for driver mysql.
	SSLMode int `json:"sslmode" mapstructure:"sslmode"`
	// Used for drivers mysql, postgresql and cockroachdb. Set to true to disable SNI
	DisableSNI bool `json:"disable_sni" mapstructure:"disable_sni"`
	// TargetSessionAttrs is a postgresql and cockroachdb specific option.
	// It determines whether the session must have certain properties to be acceptable.
	// It's typically used in combination with multiple host names to select the first
	// acceptable alternative among several hosts
	TargetSessionAttrs string `json:"target_session_attrs" mapstructure:"target_session_attrs"`
	// Path to the root certificate authority used to verify that the server certificate was signed by a trusted CA
	RootCert string `json:"root_cert" mapstructure:"root_cert"`
	// Path to the client certificate for two-way TLS authentication
	ClientCert string `json:"client_cert" mapstructure:"client_cert"`
	// Path to the client key for two-way TLS authentication
	ClientKey string `json:"client_key" mapstructure:"client_key"`
	// Custom database connection string.
	// If not empty this connection string will be used instead of building one using the previous parameters
	ConnectionString string `json:"connection_string" mapstructure:"connection_string"`
	// prefix for SQL tables
	SQLTablesPrefix string `json:"sql_tables_prefix" mapstructure:"sql_tables_prefix"`
	// Set the preferred way to track users quota between the following choices:
	// 0, disable quota tracking. REST API to scan user dir and update quota will do nothing
	// 1, quota is updated each time a user uploads or deletes a file even if the user has no quota restrictions
	// 2, quota is updated each time a user uploads or deletes a file but only for users with quota restrictions
	//    and for virtual folders.
	//    With this configuration the "quota scan" REST API can still be used to periodically update space usage
	//    for users without quota restrictions
	TrackQuota int `json:"track_quota" mapstructure:"track_quota"`
	// Sets the maximum number of open connections for mysql and postgresql driver.
	// Default 0 (unlimited)
	PoolSize int `json:"pool_size" mapstructure:"pool_size"`
	// Users default base directory.
	// If no home dir is defined while adding a new user, and this value is
	// a valid absolute path, then the user home dir will be automatically
	// defined as the path obtained joining the base dir and the username
	UsersBaseDir string `json:"users_base_dir" mapstructure:"users_base_dir"`
	// Actions to execute on objects add, update, delete.
	// The supported objects are user, admin, api_key.
	// Update action will not be fired for internal updates such as the last login or the user quota fields.
	Actions ObjectsActions `json:"actions" mapstructure:"actions"`
	// Absolute path to an external program or an HTTP URL to invoke for users authentication.
	// Leave empty to use builtin authentication.
	// If the authentication succeeds the user will be automatically added/updated inside the defined data provider.
	// Actions defined for user added/updated will not be executed in this case.
	// This method is slower than built-in authentication methods, but it's very flexible as anyone can
	// easily write his own authentication hooks.
	ExternalAuthHook string `json:"external_auth_hook" mapstructure:"external_auth_hook"`
	// ExternalAuthScope defines the scope for the external authentication hook.
	// - 0 means all supported authentication scopes, the external hook will be executed for password,
	//     public key, keyboard interactive authentication and TLS certificates
	// - 1 means passwords only
	// - 2 means public keys only
	// - 4 means keyboard interactive only
	// - 8 means TLS certificates only
	// you can combine the scopes, for example 3 means password and public key, 5 password and keyboard
	// interactive and so on
	ExternalAuthScope int `json:"external_auth_scope" mapstructure:"external_auth_scope"`
	// Absolute path to an external program or an HTTP URL to invoke just before the user login.
	// This program/URL allows to modify or create the user trying to login.
	// It is useful if you have users with dynamic fields to update just before the login.
	// Please note that if you want to create a new user, the pre-login hook response must
	// include all the mandatory user fields.
	//
	// The pre-login hook must finish within 30 seconds.
	//
	// If an error happens while executing the "PreLoginHook" then login will be denied.
	// PreLoginHook and ExternalAuthHook are mutually exclusive.
	// Leave empty to disable.
	PreLoginHook string `json:"pre_login_hook" mapstructure:"pre_login_hook"`
	// Absolute path to an external program or an HTTP URL to invoke after the user login.
	// Based on the configured scope you can choose if notify failed or successful logins
	// or both
	PostLoginHook string `json:"post_login_hook" mapstructure:"post_login_hook"`
	// PostLoginScope defines the scope for the post-login hook.
	// - 0 means notify both failed and successful logins
	// - 1 means notify failed logins
	// - 2 means notify successful logins
	PostLoginScope int `json:"post_login_scope" mapstructure:"post_login_scope"`
	// Absolute path to an external program or an HTTP URL to invoke just before password
	// authentication. This hook allows you to externally check the provided password,
	// its main use case is to allow to easily support things like password+OTP for protocols
	// without keyboard interactive support such as FTP and WebDAV. You can ask your users
	// to login using a string consisting of a fixed password and a One Time Token, you
	// can verify the token inside the hook and ask to SFTPGo to verify the fixed part.
	CheckPasswordHook string `json:"check_password_hook" mapstructure:"check_password_hook"`
	// CheckPasswordScope defines the scope for the check password hook.
	// - 0 means all protocols
	// - 1 means SSH
	// - 2 means FTP
	// - 4 means WebDAV
	// you can combine the scopes, for example 6 means FTP and WebDAV
	CheckPasswordScope int `json:"check_password_scope" mapstructure:"check_password_scope"`
	// Defines how the database will be initialized/updated:
	// - 0 means automatically
	// - 1 means manually using the initprovider sub-command
	UpdateMode int `json:"update_mode" mapstructure:"update_mode"`
	// PasswordHashing defines the configuration for password hashing
	PasswordHashing PasswordHashing `json:"password_hashing" mapstructure:"password_hashing"`
	// PasswordValidation defines the password validation rules
	PasswordValidation PasswordValidation `json:"password_validation" mapstructure:"password_validation"`
	// Verifying argon2 passwords has a high memory and computational cost,
	// by enabling, in memory, password caching you reduce this cost.
	PasswordCaching bool `json:"password_caching" mapstructure:"password_caching"`
	// DelayedQuotaUpdate defines the number of seconds to accumulate quota updates.
	// If there are a lot of close uploads, accumulating quota updates can save you many
	// queries to the data provider.
	// If you want to track quotas, a scheduled quota update is recommended in any case, the stored
	// quota size may be incorrect for several reasons, such as an unexpected shutdown, temporary provider
	// failures, file copied outside of SFTPGo, and so on.
	// 0 means immediate quota update.
	DelayedQuotaUpdate int `json:"delayed_quota_update" mapstructure:"delayed_quota_update"`
	// If enabled, a default admin user with username "admin" and password "password" will be created
	// on first start.
	// You can also create the first admin user by using the web interface or by loading initial data.
	CreateDefaultAdmin bool `json:"create_default_admin" mapstructure:"create_default_admin"`
	// Rules for usernames and folder names:
	// - 0 means no rules
	// - 1 means you can use any UTF-8 character.
	//     The names are used in URIs for REST API and Web admin.
	//     By default only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~".
	// - 2 means names are converted to lowercase before saving/matching and so case
	//     insensitive matching is possible
	// - 4 means trimming trailing and leading white spaces before saving/matching
	// Rules can be combined, for example 3 means both converting to lowercase and allowing any UTF-8 character.
	// Enabling these options for existing installations could be backward incompatible, some users
	// could be unable to login, for example existing users with mixed cases in their usernames.
	// You have to ensure that all existing users respect the defined rules.
	NamingRules int `json:"naming_rules" mapstructure:"naming_rules"`
	// If the data provider is shared across multiple SFTPGo instances, set this parameter to 1.
	// MySQL, PostgreSQL and CockroachDB can be shared, this setting is ignored for other data
	// providers. For shared data providers, SFTPGo periodically reloads the latest updated users,
	// based on the "updated_at" field, and updates its internal caches if users are updated from
	// a different instance. This check, if enabled, is executed every 10 minutes.
	// For shared data providers, active transfers are persisted in the database and thus
	// quota checks between ongoing transfers will work across multiple instances
	IsShared int `json:"is_shared" mapstructure:"is_shared"`
	// Node defines the configuration for this cluster node.
	// Ignored if the provider is not shared/shareable
	Node NodeConfig `json:"node" mapstructure:"node"`
	// Path to the backup directory. This can be an absolute path or a path relative to the config dir
	BackupsPath string `json:"backups_path" mapstructure:"backups_path"`
}

// GetShared returns the provider share mode.
// This method is called before the provider is initialized func (c *Config) GetShared() int { if !slices.Contains(sharedProviders, c.Driver) { return 0 } return c.IsShared } func (c *Config) convertName(name string) string { if c.NamingRules <= 1 { return name } if c.NamingRules&2 != 0 { name = strings.ToLower(name) } if c.NamingRules&4 != 0 { name = strings.TrimSpace(name) } return name } // IsDefenderSupported returns true if the configured provider supports the defender func (c *Config) IsDefenderSupported() bool { switch c.Driver { case MySQLDataProviderName, PGSQLDataProviderName, CockroachDataProviderName: return true default: return false } } func (c *Config) requireCustomTLSForMySQL() bool { if config.DisableSNI { return config.SSLMode != 0 } if config.RootCert != "" && util.IsFileInputValid(config.RootCert) { return config.SSLMode != 0 } if config.ClientCert != "" && config.ClientKey != "" && util.IsFileInputValid(config.ClientCert) && util.IsFileInputValid(config.ClientKey) { return config.SSLMode != 0 } return false } func (c *Config) doBackup() (string, error) { now := time.Now().UTC() outputFile := filepath.Join(c.BackupsPath, fmt.Sprintf("backup_%s_%d.json", now.Weekday(), now.Hour())) providerLog(logger.LevelDebug, "starting backup to file %q", outputFile) err := os.MkdirAll(filepath.Dir(outputFile), 0700) if err != nil { providerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err) return outputFile, fmt.Errorf("unable to create backup dir: %w", err) } backup, err := DumpData(nil) if err != nil { providerLog(logger.LevelError, "unable to execute backup: %v", err) return outputFile, fmt.Errorf("unable to dump backup data: %w", err) } dump, err := json.Marshal(backup) if err != nil { providerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err) return outputFile, fmt.Errorf("unable to marshal backup data as JSON: %w", err) } err = os.WriteFile(outputFile, dump, 0600) if err != nil { providerLog(logger.LevelError, 
"unable to save backup: %v", err) return outputFile, fmt.Errorf("unable to save backup: %w", err) } providerLog(logger.LevelDebug, "backup saved to %q", outputFile) return outputFile, nil } // SetTZ sets the configured timezone. func SetTZ(val string) { tz = val } // UseLocalTime returns true if local time should be used instead of UTC. func UseLocalTime() bool { return tz == "local" } // ExecuteBackup executes a backup func ExecuteBackup() (string, error) { return config.doBackup() } // ConvertName converts the given name based on the configured rules func ConvertName(name string) string { return config.convertName(name) } // ActiveTransfer defines an active protocol transfer type ActiveTransfer struct { ID int64 Type int ConnID string Username string FolderName string IP string TruncatedSize int64 CurrentULSize int64 CurrentDLSize int64 CreatedAt int64 UpdatedAt int64 } // TransferQuota stores the allowed transfer quota fields type TransferQuota struct { ULSize int64 DLSize int64 TotalSize int64 AllowedULSize int64 AllowedDLSize int64 AllowedTotalSize int64 } // HasSizeLimits returns true if any size limit is set func (q *TransferQuota) HasSizeLimits() bool { return q.AllowedDLSize > 0 || q.AllowedULSize > 0 || q.AllowedTotalSize > 0 } // HasUploadSpace returns true if there is transfer upload space available func (q *TransferQuota) HasUploadSpace() bool { if q.TotalSize <= 0 && q.ULSize <= 0 { return true } if q.TotalSize > 0 { return q.AllowedTotalSize > 0 } return q.AllowedULSize > 0 } // HasDownloadSpace returns true if there is transfer download space available func (q *TransferQuota) HasDownloadSpace() bool { if q.TotalSize <= 0 && q.DLSize <= 0 { return true } if q.TotalSize > 0 { return q.AllowedTotalSize > 0 } return q.AllowedDLSize > 0 } // DefenderEntry defines a defender entry type DefenderEntry struct { ID int64 `json:"-"` IP string `json:"ip"` Score int `json:"score,omitempty"` BanTime time.Time `json:"ban_time,omitempty"` } // GetID returns an 
unique ID for a defender entry func (d *DefenderEntry) GetID() string { return hex.EncodeToString([]byte(d.IP)) } // GetBanTime returns the ban time for a defender entry as string func (d *DefenderEntry) GetBanTime() string { if d.BanTime.IsZero() { return "" } return d.BanTime.UTC().Format(time.RFC3339) } // MarshalJSON returns the JSON encoding of a DefenderEntry. func (d *DefenderEntry) MarshalJSON() ([]byte, error) { return json.Marshal(&struct { ID string `json:"id"` IP string `json:"ip"` Score int `json:"score,omitempty"` BanTime string `json:"ban_time,omitempty"` }{ ID: d.GetID(), IP: d.IP, Score: d.Score, BanTime: d.GetBanTime(), }) } // BackupData defines the structure for the backup/restore files type BackupData struct { Users []User `json:"users"` Groups []Group `json:"groups"` Folders []vfs.BaseVirtualFolder `json:"folders"` Admins []Admin `json:"admins"` APIKeys []APIKey `json:"api_keys"` Shares []Share `json:"shares"` EventActions []BaseEventAction `json:"event_actions"` EventRules []EventRule `json:"event_rules"` Roles []Role `json:"roles"` IPLists []IPListEntry `json:"ip_lists"` Configs *Configs `json:"configs"` Version int `json:"version"` } // HasFolder returns true if the folder with the given name is included func (d *BackupData) HasFolder(name string) bool { for _, folder := range d.Folders { if folder.Name == name { return true } } return false } type checkPasswordRequest struct { Username string `json:"username"` IP string `json:"ip"` Password string `json:"password"` Protocol string `json:"protocol"` } type checkPasswordResponse struct { // 0 KO, 1 OK, 2 partial success, -1 not executed Status int `json:"status"` // for status = 2 this is the password to check against the one stored // inside the SFTPGo data provider ToVerify string `json:"to_verify"` } // GetQuotaTracking returns the configured mode for user's quota tracking func GetQuotaTracking() int { return config.TrackQuota } // HasUsersBaseDir returns true if users base dir is set 
func HasUsersBaseDir() bool { return config.UsersBaseDir != "" } // Provider defines the interface that data providers must implement. type Provider interface { validateUserAndPass(username, password, ip, protocol string) (User, error) validateUserAndPubKey(username string, pubKey []byte, isSSHCert bool) (User, string, error) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error getUsedQuota(username string) (int, int64, int64, int64, error) userExists(username, role string) (User, error) addUser(user *User) error updateUser(user *User) error deleteUser(user User, softDelete bool) error updateUserPassword(username, password string) error // used internally when converting passwords from other hash getUsers(limit int, offset int, order, role string) ([]User, error) dumpUsers() ([]User, error) getRecentlyUpdatedUsers(after int64) ([]User, error) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) updateLastLogin(username string) error updateAdminLastLogin(username string) error setUpdatedAt(username string) getAdminSignature(username string) (string, error) getUserSignature(username string) (string, error) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) getFolderByName(name string) (vfs.BaseVirtualFolder, error) addFolder(folder *vfs.BaseVirtualFolder) error updateFolder(folder *vfs.BaseVirtualFolder) error deleteFolder(folder vfs.BaseVirtualFolder) error updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error getUsedFolderQuota(name string) (int, int64, error) dumpFolders() ([]vfs.BaseVirtualFolder, error) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) getGroupsWithNames(names []string) ([]Group, error) getUsersInGroups(names []string) ([]string, error) 
groupExists(name string) (Group, error) addGroup(group *Group) error updateGroup(group *Group) error deleteGroup(group Group) error dumpGroups() ([]Group, error) adminExists(username string) (Admin, error) addAdmin(admin *Admin) error updateAdmin(admin *Admin) error deleteAdmin(admin Admin) error getAdmins(limit int, offset int, order string) ([]Admin, error) dumpAdmins() ([]Admin, error) validateAdminAndPass(username, password, ip string) (Admin, error) apiKeyExists(keyID string) (APIKey, error) addAPIKey(apiKey *APIKey) error updateAPIKey(apiKey *APIKey) error deleteAPIKey(apiKey APIKey) error getAPIKeys(limit int, offset int, order string) ([]APIKey, error) dumpAPIKeys() ([]APIKey, error) updateAPIKeyLastUse(keyID string) error shareExists(shareID, username string) (Share, error) addShare(share *Share) error updateShare(share *Share) error deleteShare(share Share) error getShares(limit int, offset int, order, username string) ([]Share, error) dumpShares() ([]Share, error) updateShareLastUse(shareID string, numTokens int) error getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) isDefenderHostBanned(ip string) (DefenderEntry, error) updateDefenderBanTime(ip string, minutes int) error deleteDefenderHost(ip string) error addDefenderEvent(ip string, score int) error setDefenderBanTime(ip string, banTime int64) error cleanupDefender(from int64) error addActiveTransfer(transfer ActiveTransfer) error updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error removeActiveTransfer(transferID int64, connectionID string) error cleanupActiveTransfers(before time.Time) error getActiveTransfers(from time.Time) ([]ActiveTransfer, error) addSharedSession(session Session) error deleteSharedSession(key string, sessionType SessionType) error getSharedSession(key string, sessionType SessionType) (Session, error) cleanupSharedSessions(sessionType SessionType, before int64) 
error getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) dumpEventActions() ([]BaseEventAction, error) eventActionExists(name string) (BaseEventAction, error) addEventAction(action *BaseEventAction) error updateEventAction(action *BaseEventAction) error deleteEventAction(action BaseEventAction) error getEventRules(limit, offset int, order string) ([]EventRule, error) dumpEventRules() ([]EventRule, error) getRecentlyUpdatedRules(after int64) ([]EventRule, error) eventRuleExists(name string) (EventRule, error) addEventRule(rule *EventRule) error updateEventRule(rule *EventRule) error deleteEventRule(rule EventRule, softDelete bool) error getTaskByName(name string) (Task, error) addTask(name string) error updateTask(name string, version int64) error updateTaskTimestamp(name string) error setFirstDownloadTimestamp(username string) error setFirstUploadTimestamp(username string) error addNode() error getNodeByName(name string) (Node, error) getNodes() ([]Node, error) updateNodeTimestamp() error cleanupNodes() error roleExists(name string) (Role, error) addRole(role *Role) error updateRole(role *Role) error deleteRole(role Role) error getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) dumpRoles() ([]Role, error) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) addIPListEntry(entry *IPListEntry) error updateIPListEntry(entry *IPListEntry) error deleteIPListEntry(entry IPListEntry, softDelete bool) error getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) dumpIPListEntries() ([]IPListEntry, error) countIPListEntries(listType IPListType) (int64, error) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) getConfigs() (Configs, error) setConfigs(configs *Configs) error checkAvailability() error close() error reloadConfig() error initializeDatabase() 
error
	migrateDatabase() error
	revertDatabase(targetVersion int) error
	resetDatabase() error
}

// SetAllowSelfConnections sets the desired behaviour for self connections
func SetAllowSelfConnections(value int) {
	allowSelfConnections = value
}

// SetTempPath sets the path for temporary files
func SetTempPath(fsPath string) {
	tempPath = fsPath
}

// checkSharedMode forces config.IsShared to 0 when the configured driver is
// not listed in sharedProviders.
func checkSharedMode() {
	if !slices.Contains(sharedProviders, config.Driver) {
		config.IsShared = 0
	}
}

// Initialize the data provider.
// An error is returned if the configured driver is invalid or if the data provider cannot be initialized
//
// The steps are, in order: store cnf as the package configuration, resolve the
// backups path relative to basePath, validate the password hashing options and
// the command hooks, create the provider, initialize/migrate the database
// (skipped when UpdateMode is non-zero), record whether at least one admin
// exists, validate the node configuration, start the delayed quota updater and
// finally start the scheduler.
func Initialize(cnf Config, basePath string, checkAdmins bool) error {
	config = cnf
	checkSharedMode()
	config.Actions.ExecuteOn = util.RemoveDuplicates(config.Actions.ExecuteOn, true)
	config.Actions.ExecuteFor = util.RemoveDuplicates(config.Actions.ExecuteFor, true)

	configuredBackupsPath := cnf.BackupsPath
	cnf.BackupsPath = getConfigPath(cnf.BackupsPath, basePath)
	if cnf.BackupsPath == "" {
		// the resolved path is empty here, report the configured value instead
		return fmt.Errorf("required directory is invalid, backup path %q", configuredBackupsPath)
	}
	absoluteBackupPath, err := util.GetAbsolutePath(cnf.BackupsPath)
	if err != nil {
		return fmt.Errorf("unable to get absolute backup path: %w", err)
	}
	config.BackupsPath = absoluteBackupPath

	if err := initializeHashingAlgo(&cnf); err != nil {
		return err
	}
	if err := validateHooks(); err != nil {
		return err
	}
	if err := createProvider(basePath); err != nil {
		return err
	}
	if err := checkDatabase(checkAdmins); err != nil {
		return err
	}
	admins, err := provider.getAdmins(1, 0, OrderASC)
	if err != nil {
		return err
	}
	isAdminCreated.Store(len(admins) > 0)
	if err := config.Node.validate(); err != nil {
		return err
	}
	delayedQuotaUpdater.start()
	if currentNode != nil {
		// each node gets its own backups sub directory
		config.BackupsPath = filepath.Join(config.BackupsPath, currentNode.Name)
	}
	providerLog(logger.LevelDebug, "absolute backup path %q", config.BackupsPath)
	return startScheduler()
}

// checkDatabase initializes and migrates the database schema and, if both
// checkAdmins and config.CreateDefaultAdmin are set, ensures a default admin
// exists. When config.UpdateMode is non-zero (manual mode) nothing is done.
// ErrNoInitRequired from the provider is not treated as a failure.
func checkDatabase(checkAdmins bool) error {
	if config.UpdateMode != 0 {
		providerLog(logger.LevelInfo, "database initialization/migration skipped, manual mode is configured")
		return nil
	}
	err := provider.initializeDatabase()
	if err != nil && err != ErrNoInitRequired {
		logger.WarnToConsole("unable to initialize data provider: %v", err)
		providerLog(logger.LevelError, "unable to initialize data provider: %v", err)
		return err
	}
	if err == nil {
		logger.DebugToConsole("data provider successfully initialized")
		providerLog(logger.LevelInfo, "data provider successfully initialized")
	}
	err = provider.migrateDatabase()
	if err != nil && err != ErrNoInitRequired {
		providerLog(logger.LevelError, "database migration error: %v", err)
		return err
	}
	if checkAdmins && config.CreateDefaultAdmin {
		err = checkDefaultAdmin()
		if err != nil {
			providerLog(logger.LevelError, "error checking the default admin: %v", err)
			return err
		}
	}
	return nil
}

// validateHooks checks that every configured hook that does not start with
// "http" (those are treated as URLs and skipped) is an absolute path to an
// existing file system entry.
func validateHooks() error {
	var hooks []string
	if config.PreLoginHook != "" && !strings.HasPrefix(config.PreLoginHook, "http") {
		hooks = append(hooks, config.PreLoginHook)
	}
	if config.ExternalAuthHook != "" && !strings.HasPrefix(config.ExternalAuthHook, "http") {
		hooks = append(hooks, config.ExternalAuthHook)
	}
	if config.PostLoginHook != "" && !strings.HasPrefix(config.PostLoginHook, "http") {
		hooks = append(hooks, config.PostLoginHook)
	}
	if config.CheckPasswordHook != "" && !strings.HasPrefix(config.CheckPasswordHook, "http") {
		hooks = append(hooks, config.CheckPasswordHook)
	}
	for _, hook := range hooks {
		if !filepath.IsAbs(hook) {
			return fmt.Errorf("invalid hook: %q must be an absolute path", hook)
		}
		_, err := os.Stat(hook)
		if err != nil {
			providerLog(logger.LevelError, "invalid hook: %v", err)
			return err
		}
	}
	return nil
}

// GetBackupsPath returns the normalized backups path
func GetBackupsPath() string {
	return config.BackupsPath
}

// GetProviderFromValue returns the FilesystemProvider matching the specified value.
// If no match is found LocalFilesystemProvider is returned.
func GetProviderFromValue(value string) sdk.FilesystemProvider { val, err := strconv.Atoi(value) if err != nil { return sdk.LocalFilesystemProvider } result := sdk.FilesystemProvider(val) if sdk.IsProviderSupported(result) { return result } return sdk.LocalFilesystemProvider } func initializeHashingAlgo(cnf *Config) error { parallelism := cnf.PasswordHashing.Argon2Options.Parallelism if parallelism == 0 { parallelism = uint8(runtime.NumCPU()) } argon2Params = &argon2id.Params{ Memory: cnf.PasswordHashing.Argon2Options.Memory, Iterations: cnf.PasswordHashing.Argon2Options.Iterations, Parallelism: parallelism, SaltLength: 16, KeyLength: 32, } if config.PasswordHashing.Algo == HashingAlgoBcrypt { if config.PasswordHashing.BcryptOptions.Cost > bcrypt.MaxCost { err := fmt.Errorf("invalid bcrypt cost %v, max allowed %v", config.PasswordHashing.BcryptOptions.Cost, bcrypt.MaxCost) logger.WarnToConsole("Unable to initialize data provider: %v", err) providerLog(logger.LevelError, "Unable to initialize data provider: %v", err) return err } } return nil } func validateSQLTablesPrefix() error { initSQLTables() if config.SQLTablesPrefix != "" { for _, char := range config.SQLTablesPrefix { if !strings.Contains(sqlPrefixValidChars, strings.ToLower(string(char))) { return errors.New("invalid sql_tables_prefix only chars in range 'a..z', 'A..Z', '0-9' and '_' are allowed") } } sqlTableUsers = config.SQLTablesPrefix + sqlTableUsers sqlTableFolders = config.SQLTablesPrefix + sqlTableFolders sqlTableUsersFoldersMapping = config.SQLTablesPrefix + sqlTableUsersFoldersMapping sqlTableAdmins = config.SQLTablesPrefix + sqlTableAdmins sqlTableAPIKeys = config.SQLTablesPrefix + sqlTableAPIKeys sqlTableShares = config.SQLTablesPrefix + sqlTableShares sqlTableSharesGroupsMapping = config.SQLTablesPrefix + sqlTableSharesGroupsMapping sqlTableDefenderEvents = config.SQLTablesPrefix + sqlTableDefenderEvents sqlTableDefenderHosts = config.SQLTablesPrefix + sqlTableDefenderHosts 
sqlTableActiveTransfers = config.SQLTablesPrefix + sqlTableActiveTransfers sqlTableGroups = config.SQLTablesPrefix + sqlTableGroups sqlTableUsersGroupsMapping = config.SQLTablesPrefix + sqlTableUsersGroupsMapping sqlTableAdminsGroupsMapping = config.SQLTablesPrefix + sqlTableAdminsGroupsMapping sqlTableGroupsFoldersMapping = config.SQLTablesPrefix + sqlTableGroupsFoldersMapping sqlTableSharedSessions = config.SQLTablesPrefix + sqlTableSharedSessions sqlTableEventsActions = config.SQLTablesPrefix + sqlTableEventsActions sqlTableEventsRules = config.SQLTablesPrefix + sqlTableEventsRules sqlTableRulesActionsMapping = config.SQLTablesPrefix + sqlTableRulesActionsMapping sqlTableTasks = config.SQLTablesPrefix + sqlTableTasks sqlTableNodes = config.SQLTablesPrefix + sqlTableNodes sqlTableRoles = config.SQLTablesPrefix + sqlTableRoles sqlTableIPLists = config.SQLTablesPrefix + sqlTableIPLists sqlTableConfigs = config.SQLTablesPrefix + sqlTableConfigs sqlTableSchemaVersion = config.SQLTablesPrefix + sqlTableSchemaVersion providerLog(logger.LevelDebug, "sql table for users %q, folders %q users folders mapping %q admins %q "+ "api keys %q shares %q defender hosts %q defender events %q transfers %q groups %q "+ "users groups mapping %q admins groups mapping %q groups folders mapping %q shared sessions %q "+ "schema version %q events actions %q events rules %q rules actions mapping %q tasks %q nodes %q roles %q"+ "ip lists %q share groups mapping %q configs %q", sqlTableUsers, sqlTableFolders, sqlTableUsersFoldersMapping, sqlTableAdmins, sqlTableAPIKeys, sqlTableShares, sqlTableDefenderHosts, sqlTableDefenderEvents, sqlTableActiveTransfers, sqlTableGroups, sqlTableUsersGroupsMapping, sqlTableAdminsGroupsMapping, sqlTableGroupsFoldersMapping, sqlTableSharedSessions, sqlTableSchemaVersion, sqlTableEventsActions, sqlTableEventsRules, sqlTableRulesActionsMapping, sqlTableTasks, sqlTableNodes, sqlTableRoles, sqlTableIPLists, sqlTableSharesGroupsMapping, sqlTableConfigs) } return 
nil } func checkDefaultAdmin() error { admins, err := provider.getAdmins(1, 0, OrderASC) if err != nil { return err } if len(admins) > 0 { return nil } logger.Debug(logSender, "", "no admins found, try to create the default one") // we need to create the default admin admin := &Admin{} if err := admin.setFromEnv(); err != nil { return err } return provider.addAdmin(admin) } // InitializeDatabase creates the initial database structure func InitializeDatabase(cnf Config, basePath string) error { config = cnf if err := initializeHashingAlgo(&cnf); err != nil { return err } err := createProvider(basePath) if err != nil { return err } err = provider.initializeDatabase() if err != nil && err != ErrNoInitRequired { return err } return provider.migrateDatabase() } // RevertDatabase restores schema and/or data to a previous version func RevertDatabase(cnf Config, basePath string, targetVersion int) error { config = cnf err := createProvider(basePath) if err != nil { return err } err = provider.initializeDatabase() if err != nil && err != ErrNoInitRequired { return err } return provider.revertDatabase(targetVersion) } // ResetDatabase restores schema and/or data to a previous version func ResetDatabase(cnf Config, basePath string) error { config = cnf if err := createProvider(basePath); err != nil { return err } return provider.resetDatabase() } // CheckAdminAndPass validates the given admin and password connecting from ip func CheckAdminAndPass(username, password, ip string) (Admin, error) { username = config.convertName(username) return provider.validateAdminAndPass(username, password, ip) } // CheckCachedUserCredentials checks the credentials for a cached user func CheckCachedUserCredentials(user *CachedUser, password, ip, loginMethod, protocol string, tlsCert *x509.Certificate) (*CachedUser, *User, error) { if !user.User.skipExternalAuth() && isExternalAuthConfigured(loginMethod) { u, _, err := CheckCompositeCredentials(user.User.Username, password, ip, loginMethod, 
protocol, tlsCert) if err != nil { return nil, nil, err } webDAVUsersCache.swap(&u, password) cu, _ := webDAVUsersCache.get(u.Username) return cu, &u, nil } if err := user.User.CheckLoginConditions(); err != nil { return user, nil, err } if loginMethod == LoginMethodPassword && user.User.Filters.IsAnonymous { return user, nil, nil } if loginMethod != LoginMethodPassword { _, err := checkUserAndTLSCertificate(&user.User, protocol, tlsCert) if err != nil { return user, nil, err } if loginMethod == LoginMethodTLSCertificate { if !user.User.IsLoginMethodAllowed(LoginMethodTLSCertificate, protocol) { return user, nil, fmt.Errorf("certificate login method is not allowed for user %q", user.User.Username) } return user, nil, nil } } if password == "" { return user, nil, ErrInvalidCredentials } if user.Password != "" { if password == user.Password { return user, nil, nil } } else { if ok, _ := isPasswordOK(&user.User, password); ok { return user, nil, nil } } return user, nil, ErrInvalidCredentials } // CheckCompositeCredentials checks multiple credentials. 
// WebDAV users can send both a password and a TLS certificate within the same request
//
// For LoginMethodPassword only the password is checked. Otherwise the user is
// loaded with CheckUserBeforeTLSAuth and the certificate is verified; for
// LoginMethodTLSCertificateAndPwd the password is then checked too, after
// re-fetching the user from the plugin, external auth hook or pre-login hook
// if one of them is configured for password authentication. The returned
// string is the login method actually used.
func CheckCompositeCredentials(username, password, ip, loginMethod, protocol string, tlsCert *x509.Certificate) (User, string, error) {
	username = config.convertName(username)
	if loginMethod == LoginMethodPassword {
		user, err := CheckUserAndPass(username, password, ip, protocol)
		return user, loginMethod, err
	}
	user, err := CheckUserBeforeTLSAuth(username, ip, protocol, tlsCert)
	if err != nil {
		return user, loginMethod, err
	}
	if !user.IsTLSVerificationEnabled() {
		// for backward compatibility with 2.0.x we only check the password and change the login method here
		// in future updates we have to return an error
		user, err := CheckUserAndPass(username, password, ip, protocol)
		return user, LoginMethodPassword, err
	}
	user, err = checkUserAndTLSCertificate(&user, protocol, tlsCert)
	if err != nil {
		return user, loginMethod, err
	}
	if loginMethod == LoginMethodTLSCertificate && !user.IsLoginMethodAllowed(LoginMethodTLSCertificate, protocol) {
		return user, loginMethod, fmt.Errorf("certificate login method is not allowed for user %q", user.Username)
	}
	if loginMethod == LoginMethodTLSCertificateAndPwd {
		// same precedence as CheckUserAndPass: plugin, then external auth hook
		// (scope 0 or bit 1 set), then pre-login hook
		if plugin.Handler.HasAuthScope(plugin.AuthScopePassword) {
			user, err = doPluginAuth(username, password, nil, ip, protocol, nil, plugin.AuthScopePassword)
		} else if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) {
			user, err = doExternalAuth(username, password, nil, "", ip, protocol, nil)
		} else if config.PreLoginHook != "" {
			user, err = executePreLoginHook(username, LoginMethodPassword, ip, protocol, nil)
		}
		if err != nil {
			return user, loginMethod, err
		}
		user, err = checkUserAndPass(&user, password, ip, protocol)
	}
	return user, loginMethod, err
}

// CheckUserBeforeTLSAuth checks if a user exists before trying mutual TLS.
// The user is obtained, in order of precedence, from the auth plugin (TLS
// certificate scope), the external auth hook (scope 0 or bit 8 set), the
// pre-login hook or the data provider; group settings are then applied.
// The certificate itself is not verified here.
func CheckUserBeforeTLSAuth(username, ip, protocol string, tlsCert *x509.Certificate) (User, error) {
	username = config.convertName(username)
	if plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate) {
		user, err := doPluginAuth(username, "", nil, ip, protocol, tlsCert, plugin.AuthScopeTLSCertificate)
		if err != nil {
			return user, err
		}
		err = user.LoadAndApplyGroupSettings()
		return user, err
	}
	if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&8 != 0) {
		user, err := doExternalAuth(username, "", nil, "", ip, protocol, tlsCert)
		if err != nil {
			return user, err
		}
		err = user.LoadAndApplyGroupSettings()
		return user, err
	}
	if config.PreLoginHook != "" {
		user, err := executePreLoginHook(username, LoginMethodTLSCertificate, ip, protocol, nil)
		if err != nil {
			return user, err
		}
		err = user.LoadAndApplyGroupSettings()
		return user, err
	}
	user, err := UserExists(username, "")
	if err != nil {
		return user, err
	}
	err = user.LoadAndApplyGroupSettings()
	return user, err
}

// CheckUserAndTLSCert returns the SFTPGo user with the given username and checks if the
// given TLS certificate allows authentication without password.
// The precedence is the same as CheckUserBeforeTLSAuth: auth plugin, external
// auth hook (scope 0 or bit 8 set), pre-login hook, then the data provider.
func CheckUserAndTLSCert(username, ip, protocol string, tlsCert *x509.Certificate) (User, error) {
	username = config.convertName(username)
	if plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate) {
		user, err := doPluginAuth(username, "", nil, ip, protocol, tlsCert, plugin.AuthScopeTLSCertificate)
		if err != nil {
			return user, err
		}
		return checkUserAndTLSCertificate(&user, protocol, tlsCert)
	}
	if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&8 != 0) {
		user, err := doExternalAuth(username, "", nil, "", ip, protocol, tlsCert)
		if err != nil {
			return user, err
		}
		return checkUserAndTLSCertificate(&user, protocol, tlsCert)
	}
	if config.PreLoginHook != "" {
		user, err := executePreLoginHook(username, LoginMethodTLSCertificate, ip, protocol, nil)
		if err != nil {
			return user, err
		}
		return checkUserAndTLSCertificate(&user, protocol, tlsCert)
	}
	return provider.validateUserAndTLSCert(username, protocol, tlsCert)
}

// CheckUserAndPass retrieves the SFTPGo user with the given username and password if a match is found or an error.
// The user is obtained, in order of precedence, from the auth plugin (password
// scope), the external auth hook (scope 0 or bit 1 set), the pre-login hook,
// or validated directly by the data provider.
func CheckUserAndPass(username, password, ip, protocol string) (User, error) {
	username = config.convertName(username)
	if plugin.Handler.HasAuthScope(plugin.AuthScopePassword) {
		user, err := doPluginAuth(username, password, nil, ip, protocol, nil, plugin.AuthScopePassword)
		if err != nil {
			return user, err
		}
		return checkUserAndPass(&user, password, ip, protocol)
	}
	if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) {
		user, err := doExternalAuth(username, password, nil, "", ip, protocol, nil)
		if err != nil {
			return user, err
		}
		return checkUserAndPass(&user, password, ip, protocol)
	}
	if config.PreLoginHook != "" {
		user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol, nil)
		if err != nil {
			return user, err
		}
		return checkUserAndPass(&user, password, ip, protocol)
	}
	return provider.validateUserAndPass(username, password, ip, protocol)
}

// CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error.
// The precedence is: auth plugin (public key scope), external auth hook
// (scope 0 or bit 2 set), pre-login hook, then the data provider. The string
// result comes from checkUserAndPubKey / validateUserAndPubKey
// (NOTE(review): presumably a description of the matched key; confirm there).
func CheckUserAndPubKey(username string, pubKey []byte, ip, protocol string, isSSHCert bool) (User, string, error) {
	username = config.convertName(username)
	if plugin.Handler.HasAuthScope(plugin.AuthScopePublicKey) {
		user, err := doPluginAuth(username, "", pubKey, ip, protocol, nil, plugin.AuthScopePublicKey)
		if err != nil {
			return user, "", err
		}
		return checkUserAndPubKey(&user, pubKey, isSSHCert)
	}
	if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&2 != 0) {
		user, err := doExternalAuth(username, "", pubKey, "", ip, protocol, nil)
		if err != nil {
			return user, "", err
		}
		return checkUserAndPubKey(&user, pubKey, isSSHCert)
	}
	if config.PreLoginHook != "" {
		user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol, nil)
		if err != nil {
			return user, "", err
		}
		return checkUserAndPubKey(&user, pubKey, isSSHCert)
	}
	return provider.validateUserAndPubKey(username, pubKey, isSSHCert)
}

// CheckKeyboardInteractiveAuth checks the keyboard interactive authentication and returns
// the authenticated user or an error.
// The user is loaded from the auth plugin (keyboard interactive scope), the
// external auth hook (scope 0 or bit 4 set, called with "1" as fourth
// argument), the pre-login hook or the data provider; the challenge itself is
// handled by doKeyboardInteractiveAuth.
func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.KeyboardInteractiveChallenge,
	ip, protocol string, isPartialAuth bool,
) (User, error) {
	var user User
	var err error
	username = config.convertName(username)
	if plugin.Handler.HasAuthScope(plugin.AuthScopeKeyboardInteractive) {
		user, err = doPluginAuth(username, "", nil, ip, protocol, nil, plugin.AuthScopeKeyboardInteractive)
	} else if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&4 != 0) {
		user, err = doExternalAuth(username, "", nil, "1", ip, protocol, nil)
	} else if config.PreLoginHook != "" {
		user, err = executePreLoginHook(username, SSHLoginMethodKeyboardInteractive, ip, protocol, nil)
	} else {
		user, err = provider.userExists(username, "")
	}
	if err != nil {
		return user, err
	}
	return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol, isPartialAuth)
}

// GetFTPPreAuthUser returns the SFTPGo user with the specified username
// after receiving the FTP "USER" command.
// If a pre-login hook is defined it will be executed so the SFTPGo user
// can be created if it does not exist.
// Group settings are applied to the returned user.
func GetFTPPreAuthUser(username, ip string) (User, error) {
	var user User
	var err error
	if config.PreLoginHook != "" {
		user, err = executePreLoginHook(username, "", ip, protocolFTP, nil)
	} else {
		user, err = UserExists(username, "")
	}
	if err != nil {
		return user, err
	}
	err = user.LoadAndApplyGroupSettings()
	return user, err
}

// GetUserAfterIDPAuth returns the SFTPGo user with the specified username
// after a successful authentication with an external identity provider.
// If a pre-login hook is defined it will be executed so the SFTPGo user // can be created if it does not exist func GetUserAfterIDPAuth(username, ip, protocol string, oidcTokenFields *map[string]any) (User, error) { var user User var err error if config.PreLoginHook != "" { user, err = executePreLoginHook(username, LoginMethodIDP, ip, protocol, oidcTokenFields) user.Filters.RequirePasswordChange = false } else { user, err = UserExists(username, "") } if err != nil { return user, err } err = user.LoadAndApplyGroupSettings() return user, err } // GetDefenderHosts returns hosts that are banned or for which some violations have been detected func GetDefenderHosts(from int64, limit int) ([]DefenderEntry, error) { return provider.getDefenderHosts(from, limit) } // GetDefenderHostByIP returns a defender host by ip, if any func GetDefenderHostByIP(ip string, from int64) (DefenderEntry, error) { return provider.getDefenderHostByIP(ip, from) } // IsDefenderHostBanned returns a defender entry and no error if the specified host is banned func IsDefenderHostBanned(ip string) (DefenderEntry, error) { return provider.isDefenderHostBanned(ip) } // UpdateDefenderBanTime increments ban time for the specified ip func UpdateDefenderBanTime(ip string, minutes int) error { return provider.updateDefenderBanTime(ip, minutes) } // DeleteDefenderHost removes the specified IP from the defender lists func DeleteDefenderHost(ip string) error { return provider.deleteDefenderHost(ip) } // AddDefenderEvent adds an event for the given IP with the given score // and returns the host with the updated score func AddDefenderEvent(ip string, score int, from int64) (DefenderEntry, error) { if err := provider.addDefenderEvent(ip, score); err != nil { return DefenderEntry{}, err } return provider.getDefenderHostByIP(ip, from) } // SetDefenderBanTime sets the ban time for the specified IP func SetDefenderBanTime(ip string, banTime int64) error { return provider.setDefenderBanTime(ip, banTime) } // 
CleanupDefender removes events and hosts older than "from" from the data provider func CleanupDefender(from int64) error { return provider.cleanupDefender(from) } // UpdateShareLastUse updates the LastUseAt and UsedTokens for the given share func UpdateShareLastUse(share *Share, numTokens int) error { return provider.updateShareLastUse(share.ShareID, numTokens) } // UpdateAPIKeyLastUse updates the LastUseAt field for the given API key func UpdateAPIKeyLastUse(apiKey *APIKey) error { lastUse := util.GetTimeFromMsecSinceEpoch(apiKey.LastUseAt) diff := -time.Until(lastUse) if diff < 0 || diff > lastLoginMinDelay { return provider.updateAPIKeyLastUse(apiKey.KeyID) } return nil } // UpdateLastLogin updates the last login field for the given SFTPGo user func UpdateLastLogin(user *User) { delay := lastLoginMinDelay if user.Filters.ExternalAuthCacheTime > 0 { delay = time.Duration(user.Filters.ExternalAuthCacheTime) * time.Second } if user.LastLogin <= user.UpdatedAt || !isLastActivityRecent(user.LastLogin, delay) { err := provider.updateLastLogin(user.Username) if err == nil { webDAVUsersCache.updateLastLogin(user.Username) } } } // UpdateAdminLastLogin updates the last login field for the given SFTPGo admin func UpdateAdminLastLogin(admin *Admin) { if !isLastActivityRecent(admin.LastLogin, lastLoginMinDelay) { provider.updateAdminLastLogin(admin.Username) //nolint:errcheck } } // UpdateUserQuota updates the quota for the given SFTPGo user adding filesAdd and sizeAdd. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. 
func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error { if config.TrackQuota == 0 { return util.NewMethodDisabledError(trackQuotaDisabledError) } else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() { return nil } if filesAdd == 0 && sizeAdd == 0 && !reset { return nil } if config.DelayedQuotaUpdate == 0 || reset { if reset { delayedQuotaUpdater.resetUserQuota(user.Username) } return provider.updateQuota(user.Username, filesAdd, sizeAdd, reset) } delayedQuotaUpdater.updateUserQuota(user.Username, filesAdd, sizeAdd) return nil } // UpdateUserFolderQuota updates the quota for the given user and virtual folder. func UpdateUserFolderQuota(folder *vfs.VirtualFolder, user *User, filesAdd int, sizeAdd int64, reset bool) { if folder.IsIncludedInUserQuota() { UpdateUserQuota(user, filesAdd, sizeAdd, reset) //nolint:errcheck return } UpdateVirtualFolderQuota(&folder.BaseVirtualFolder, filesAdd, sizeAdd, reset) //nolint:errcheck } // UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error { if config.TrackQuota == 0 { return util.NewMethodDisabledError(trackQuotaDisabledError) } if filesAdd == 0 && sizeAdd == 0 && !reset { return nil } if config.DelayedQuotaUpdate == 0 || reset { if reset { delayedQuotaUpdater.resetFolderQuota(vfolder.Name) } return provider.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd, reset) } delayedQuotaUpdater.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd) return nil } // UpdateUserTransferQuota updates the transfer quota for the given SFTPGo user. // If reset is true uploadSize and downloadSize indicates the actual sizes instead of the difference. 
func UpdateUserTransferQuota(user *User, uploadSize, downloadSize int64, reset bool) error { if config.TrackQuota == 0 { return util.NewMethodDisabledError(trackQuotaDisabledError) } else if config.TrackQuota == 2 && !reset && !user.HasTransferQuotaRestrictions() { return nil } if downloadSize == 0 && uploadSize == 0 && !reset { return nil } if config.DelayedQuotaUpdate == 0 || reset { if reset { delayedQuotaUpdater.resetUserTransferQuota(user.Username) } return provider.updateTransferQuota(user.Username, uploadSize, downloadSize, reset) } delayedQuotaUpdater.updateUserTransferQuota(user.Username, uploadSize, downloadSize) return nil } // UpdateUserTransferTimestamps updates the first download/upload fields if unset func UpdateUserTransferTimestamps(username string, isUpload bool) error { if isUpload { err := provider.setFirstUploadTimestamp(username) if err != nil { providerLog(logger.LevelWarn, "unable to set first upload: %v", err) } return err } err := provider.setFirstDownloadTimestamp(username) if err != nil { providerLog(logger.LevelWarn, "unable to set first download: %v", err) } return err } // GetUsedQuota returns the used quota for the given SFTPGo user. func GetUsedQuota(username string) (int, int64, int64, int64, error) { if config.TrackQuota == 0 { return 0, 0, 0, 0, util.NewMethodDisabledError(trackQuotaDisabledError) } files, size, ulTransferSize, dlTransferSize, err := provider.getUsedQuota(username) if err != nil { return files, size, ulTransferSize, dlTransferSize, err } delayedFiles, delayedSize := delayedQuotaUpdater.getUserPendingQuota(username) delayedUlTransferSize, delayedDLTransferSize := delayedQuotaUpdater.getUserPendingTransferQuota(username) return files + delayedFiles, size + delayedSize, ulTransferSize + delayedUlTransferSize, dlTransferSize + delayedDLTransferSize, err } // GetUsedVirtualFolderQuota returns the used quota for the given virtual folder. 
func GetUsedVirtualFolderQuota(name string) (int, int64, error) { if config.TrackQuota == 0 { return 0, 0, util.NewMethodDisabledError(trackQuotaDisabledError) } files, size, err := provider.getUsedFolderQuota(name) if err != nil { return files, size, err } delayedFiles, delayedSize := delayedQuotaUpdater.getFolderPendingQuota(name) return files + delayedFiles, size + delayedSize, err } // GetConfigs returns the configurations func GetConfigs() (Configs, error) { return provider.getConfigs() } // UpdateConfigs updates configurations func UpdateConfigs(configs *Configs, executor, ipAddress, role string) error { if configs == nil { configs = &Configs{} } else { configs.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) } err := provider.setConfigs(configs) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectConfigs, "configs", role, configs) } return err } // AddShare adds a new share func AddShare(share *Share, executor, ipAddress, role string) error { err := provider.addShare(share) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectShare, share.ShareID, role, share) } return err } // UpdateShare updates an existing share func UpdateShare(share *Share, executor, ipAddress, role string) error { err := provider.updateShare(share) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectShare, share.ShareID, role, share) } return err } // DeleteShare deletes an existing share func DeleteShare(shareID string, executor, ipAddress, role string) error { share, err := provider.shareExists(shareID, executor) if err != nil { return err } err = provider.deleteShare(share) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectShare, shareID, role, &share) } return err } // ShareExists returns the share with the given ID if it exists func ShareExists(shareID, username string) (Share, error) { if shareID == "" { return Share{}, util.NewRecordNotFoundError(fmt.Sprintf("Share %q 
does not exist", shareID)) } return provider.shareExists(shareID, username) } // AddIPListEntry adds a new IP list entry func AddIPListEntry(entry *IPListEntry, executor, ipAddress, executorRole string) error { err := provider.addIPListEntry(entry) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectIPListEntry, entry.getName(), executorRole, entry) for _, l := range inMemoryLists { l.addEntry(entry) } } return err } // UpdateIPListEntry updates an existing IP list entry func UpdateIPListEntry(entry *IPListEntry, executor, ipAddress, executorRole string) error { err := provider.updateIPListEntry(entry) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectIPListEntry, entry.getName(), executorRole, entry) for _, l := range inMemoryLists { l.updateEntry(entry) } } return err } // DeleteIPListEntry deletes an existing IP list entry func DeleteIPListEntry(ipOrNet string, listType IPListType, executor, ipAddress, executorRole string) error { entry, err := provider.ipListEntryExists(ipOrNet, listType) if err != nil { return err } err = provider.deleteIPListEntry(entry, config.IsShared == 1) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectIPListEntry, entry.getName(), executorRole, &entry) for _, l := range inMemoryLists { l.removeEntry(&entry) } } return err } // IPListEntryExists returns the IP list entry with the given IP/net and type if it exists func IPListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { return provider.ipListEntryExists(ipOrNet, listType) } // GetIPListEntries returns the IP list entries applying the specified criteria and search limit func GetIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { if !slices.Contains(supportedIPListType, listType) { return nil, util.NewValidationError(fmt.Sprintf("invalid list type %d", listType)) } return provider.getIPListEntries(listType, filter, from, order, 
limit) } // AddRole adds a new role func AddRole(role *Role, executor, ipAddress, executorRole string) error { role.Name = config.convertName(role.Name) err := provider.addRole(role) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectRole, role.Name, executorRole, role) } return err } // UpdateRole updates an existing Role func UpdateRole(role *Role, executor, ipAddress, executorRole string) error { err := provider.updateRole(role) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectRole, role.Name, executorRole, role) } return err } // DeleteRole deletes an existing Role func DeleteRole(name string, executor, ipAddress, executorRole string) error { name = config.convertName(name) role, err := provider.roleExists(name) if err != nil { return err } if len(role.Admins) > 0 { errorString := fmt.Sprintf("the role %q is referenced, it cannot be removed", role.Name) return util.NewValidationError(errorString) } err = provider.deleteRole(role) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectRole, role.Name, executorRole, &role) for _, user := range role.Users { provider.setUpdatedAt(user) u, err := provider.userExists(user, "") if err == nil { webDAVUsersCache.swap(&u, "") executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u) } } } return err } // RoleExists returns the Role with the given name if it exists func RoleExists(name string) (Role, error) { name = config.convertName(name) return provider.roleExists(name) } // AddGroup adds a new group func AddGroup(group *Group, executor, ipAddress, role string) error { group.Name = config.convertName(group.Name) err := provider.addGroup(group) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectGroup, group.Name, role, group) } return err } // UpdateGroup updates an existing Group func UpdateGroup(group *Group, users []string, executor, ipAddress, role string) error { err := 
provider.updateGroup(group) if err == nil { for _, user := range users { provider.setUpdatedAt(user) u, err := provider.userExists(user, "") if err == nil { webDAVUsersCache.swap(&u, "") } else { RemoveCachedWebDAVUser(user) } } executeAction(operationUpdate, executor, ipAddress, actionObjectGroup, group.Name, role, group) } return err } // DeleteGroup deletes an existing Group func DeleteGroup(name string, executor, ipAddress, role string) error { name = config.convertName(name) group, err := provider.groupExists(name) if err != nil { return err } if len(group.Users) > 0 { errorString := fmt.Sprintf("the group %q is referenced, it cannot be removed", group.Name) return util.NewValidationError(errorString) } err = provider.deleteGroup(group) if err == nil { for _, user := range group.Users { provider.setUpdatedAt(user) u, err := provider.userExists(user, "") if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u) } RemoveCachedWebDAVUser(user) } executeAction(operationDelete, executor, ipAddress, actionObjectGroup, group.Name, role, &group) } return err } // GroupExists returns the Group with the given name if it exists func GroupExists(name string) (Group, error) { name = config.convertName(name) return provider.groupExists(name) } // AddAPIKey adds a new API key func AddAPIKey(apiKey *APIKey, executor, ipAddress, role string) error { err := provider.addAPIKey(apiKey) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectAPIKey, apiKey.KeyID, role, apiKey) } return err } // UpdateAPIKey updates an existing API key func UpdateAPIKey(apiKey *APIKey, executor, ipAddress, role string) error { err := provider.updateAPIKey(apiKey) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectAPIKey, apiKey.KeyID, role, apiKey) } return err } // DeleteAPIKey deletes an existing API key func DeleteAPIKey(keyID string, executor, ipAddress, role string) error { apiKey, err := 
provider.apiKeyExists(keyID) if err != nil { return err } err = provider.deleteAPIKey(apiKey) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectAPIKey, apiKey.KeyID, role, &apiKey) cachedAPIKeys.Remove(keyID) } return err } // APIKeyExists returns the API key with the given ID if it exists func APIKeyExists(keyID string) (APIKey, error) { if keyID == "" { return APIKey{}, util.NewRecordNotFoundError(fmt.Sprintf("API key %q does not exist", keyID)) } return provider.apiKeyExists(keyID) } // GetEventActions returns an array of event actions respecting limit and offset func GetEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) { return provider.getEventActions(limit, offset, order, minimal) } // EventActionExists returns the event action with the given name if it exists func EventActionExists(name string) (BaseEventAction, error) { name = config.convertName(name) return provider.eventActionExists(name) } // AddEventAction adds a new event action func AddEventAction(action *BaseEventAction, executor, ipAddress, role string) error { action.Name = config.convertName(action.Name) err := provider.addEventAction(action) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectEventAction, action.Name, role, action) } return err } // UpdateEventAction updates an existing event action func UpdateEventAction(action *BaseEventAction, executor, ipAddress, role string) error { err := provider.updateEventAction(action) if err == nil { if fnReloadRules != nil { fnReloadRules() } executeAction(operationUpdate, executor, ipAddress, actionObjectEventAction, action.Name, role, action) } return err } // DeleteEventAction deletes an existing event action func DeleteEventAction(name string, executor, ipAddress, role string) error { name = config.convertName(name) action, err := provider.eventActionExists(name) if err != nil { return err } if len(action.Rules) > 0 { errorString := fmt.Sprintf("the event 
action %#q is referenced, it cannot be removed", action.Name) return util.NewValidationError(errorString) } err = provider.deleteEventAction(action) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectEventAction, action.Name, role, &action) } return err } // GetEventRules returns an array of event rules respecting limit and offset func GetEventRules(limit, offset int, order string) ([]EventRule, error) { return provider.getEventRules(limit, offset, order) } // GetRecentlyUpdatedRules returns the event rules updated after the specified time func GetRecentlyUpdatedRules(after int64) ([]EventRule, error) { return provider.getRecentlyUpdatedRules(after) } // EventRuleExists returns the event rule with the given name if it exists func EventRuleExists(name string) (EventRule, error) { name = config.convertName(name) return provider.eventRuleExists(name) } // AddEventRule adds a new event rule func AddEventRule(rule *EventRule, executor, ipAddress, role string) error { rule.Name = config.convertName(rule.Name) err := provider.addEventRule(rule) if err == nil { if fnReloadRules != nil { fnReloadRules() } executeAction(operationAdd, executor, ipAddress, actionObjectEventRule, rule.Name, role, rule) } return err } // UpdateEventRule updates an existing event rule func UpdateEventRule(rule *EventRule, executor, ipAddress, role string) error { err := provider.updateEventRule(rule) if err == nil { if fnReloadRules != nil { fnReloadRules() } executeAction(operationUpdate, executor, ipAddress, actionObjectEventRule, rule.Name, role, rule) } return err } // DeleteEventRule deletes an existing event rule func DeleteEventRule(name string, executor, ipAddress, role string) error { name = config.convertName(name) rule, err := provider.eventRuleExists(name) if err != nil { return err } err = provider.deleteEventRule(rule, config.IsShared == 1) if err == nil { if fnRemoveRule != nil { fnRemoveRule(rule.Name) } executeAction(operationDelete, executor, 
ipAddress, actionObjectEventRule, rule.Name, role, &rule) } return err } // RemoveEventRule delets an existing event rule without marking it as deleted func RemoveEventRule(rule EventRule) error { return provider.deleteEventRule(rule, false) } // GetTaskByName returns the task with the specified name func GetTaskByName(name string) (Task, error) { return provider.getTaskByName(name) } // AddTask add a task with the specified name func AddTask(name string) error { return provider.addTask(name) } // UpdateTask updates the task with the specified name and version func UpdateTask(name string, version int64) error { return provider.updateTask(name, version) } // UpdateTaskTimestamp updates the timestamp for the task with the specified name func UpdateTaskTimestamp(name string) error { return provider.updateTaskTimestamp(name) } // GetNodes returns the other cluster nodes func GetNodes() ([]Node, error) { if currentNode == nil { return nil, nil } nodes, err := provider.getNodes() if err != nil { providerLog(logger.LevelError, "unable to get other cluster nodes %v", err) } return nodes, err } // GetNodeByName returns a node, different from the current one, by name func GetNodeByName(name string) (Node, error) { if currentNode == nil { return Node{}, util.NewRecordNotFoundError(errNoClusterNodes.Error()) } if name == currentNode.Name { return Node{}, util.NewValidationError(fmt.Sprintf("%s is the current node, it must refer to other nodes", name)) } return provider.getNodeByName(name) } // HasAdmin returns true if the first admin has been created // and so SFTPGo is ready to be used func HasAdmin() bool { return isAdminCreated.Load() } // AddAdmin adds a new SFTPGo admin func AddAdmin(admin *Admin, executor, ipAddress, role string) error { admin.Filters.RecoveryCodes = nil admin.Filters.TOTPConfig = AdminTOTPConfig{ Enabled: false, } admin.Username = config.convertName(admin.Username) err := provider.addAdmin(admin) if err == nil { isAdminCreated.Store(true) 
executeAction(operationAdd, executor, ipAddress, actionObjectAdmin, admin.Username, role, admin) } return err } // UpdateAdmin updates an existing SFTPGo admin func UpdateAdmin(admin *Admin, executor, ipAddress, role string) error { err := provider.updateAdmin(admin) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectAdmin, admin.Username, role, admin) } return err } // DeleteAdmin deletes an existing SFTPGo admin func DeleteAdmin(username, executor, ipAddress, role string) error { username = config.convertName(username) admin, err := provider.adminExists(username) if err != nil { return err } err = provider.deleteAdmin(admin) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectAdmin, admin.Username, role, &admin) cachedAdminPasswords.Remove(username) } return err } // AdminExists returns the admin with the given username if it exists func AdminExists(username string) (Admin, error) { username = config.convertName(username) return provider.adminExists(username) } // UserExists checks if the given SFTPGo username exists, returns an error if no match is found func UserExists(username, role string) (User, error) { username = config.convertName(username) return provider.userExists(username, role) } // GetAdminSignature returns the signature for the admin with the specified // username. func GetAdminSignature(username string) (string, error) { username = config.convertName(username) return provider.getAdminSignature(username) } // GetUserSignature returns the signature for the user with the specified // username. 
func GetUserSignature(username string) (string, error) { username = config.convertName(username) return provider.getUserSignature(username) } // GetUserWithGroupSettings tries to return the user with the specified username // loading also the group settings func GetUserWithGroupSettings(username, role string) (User, error) { username = config.convertName(username) user, err := provider.userExists(username, role) if err != nil { return user, err } err = user.LoadAndApplyGroupSettings() return user, err } // GetUserVariants tries to return the user with the specified username with and without // group settings applied func GetUserVariants(username, role string) (User, User, error) { username = config.convertName(username) user, err := provider.userExists(username, role) if err != nil { return user, User{}, err } userWithGroupSettings := user.getACopy() err = userWithGroupSettings.LoadAndApplyGroupSettings() return user, userWithGroupSettings, err } // AddUser adds a new SFTPGo user. func AddUser(user *User, executor, ipAddress, role string) error { user.Username = config.convertName(user.Username) err := provider.addUser(user) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectUser, user.Username, role, user) } return err } // UpdateUserPassword updates the user password func UpdateUserPassword(username, plainPwd, executor, ipAddress, role string) error { user, err := provider.userExists(username, role) if err != nil { return err } userCopy := user.getACopy() userCopy.Password = plainPwd if err := createUserPasswordHash(&userCopy); err != nil { return err } user.LastPasswordChange = userCopy.LastPasswordChange user.Password = userCopy.Password user.Filters.RequirePasswordChange = false // the last password change is set when validating the user if err := provider.updateUser(&user); err != nil { return err } webDAVUsersCache.swap(&user, plainPwd) executeAction(operationUpdate, executor, ipAddress, actionObjectUser, username, role, &user) 
return nil } // UpdateUser updates an existing SFTPGo user. func UpdateUser(user *User, executor, ipAddress, role string) error { if user.groupSettingsApplied { return errors.New("cannot save a user with group settings applied") } err := provider.updateUser(user) if err == nil { webDAVUsersCache.swap(user, "") executeAction(operationUpdate, executor, ipAddress, actionObjectUser, user.Username, role, user) } return err } // DeleteUser deletes an existing SFTPGo user. func DeleteUser(username, executor, ipAddress, role string) error { username = config.convertName(username) user, err := provider.userExists(username, role) if err != nil { return err } err = provider.deleteUser(user, config.IsShared == 1) if err == nil { RemoveCachedWebDAVUser(user.Username) delayedQuotaUpdater.resetUserQuota(user.Username) cachedUserPasswords.Remove(username) executeAction(operationDelete, executor, ipAddress, actionObjectUser, user.Username, role, &user) } return err } // AddActiveTransfer stores the specified transfer func AddActiveTransfer(transfer ActiveTransfer) { if err := provider.addActiveTransfer(transfer); err != nil { providerLog(logger.LevelError, "unable to add transfer id %v, connection id %v: %v", transfer.ID, transfer.ConnID, err) } } // UpdateActiveTransferSizes updates the current upload and download sizes for the specified transfer func UpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) { if err := provider.updateActiveTransferSizes(ulSize, dlSize, transferID, connectionID); err != nil { providerLog(logger.LevelError, "unable to update sizes for transfer id %v, connection id %v: %v", transferID, connectionID, err) } } // RemoveActiveTransfer removes the specified transfer func RemoveActiveTransfer(transferID int64, connectionID string) { if err := provider.removeActiveTransfer(transferID, connectionID); err != nil { providerLog(logger.LevelError, "unable to delete transfer id %v, connection id %v: %v", transferID, connectionID, err) } } 
// CleanupActiveTransfers removes the transfer before the specified time func CleanupActiveTransfers(before time.Time) error { err := provider.cleanupActiveTransfers(before) if err == nil { providerLog(logger.LevelDebug, "deleted active transfers updated before: %v", before) } else { providerLog(logger.LevelError, "error deleting active transfers updated before %v: %v", before, err) } return err } // GetActiveTransfers retrieves the active transfers with an update time after the specified value func GetActiveTransfers(from time.Time) ([]ActiveTransfer, error) { return provider.getActiveTransfers(from) } // AddSharedSession stores a new session within the data provider func AddSharedSession(session Session) error { err := provider.addSharedSession(session) if err != nil { providerLog(logger.LevelError, "unable to add shared session, key %q, type: %v, err: %v", session.Key, session.Type, err) } return err } // DeleteSharedSession deletes the session with the specified key func DeleteSharedSession(key string, sessionType SessionType) error { err := provider.deleteSharedSession(key, sessionType) if err != nil { providerLog(logger.LevelError, "unable to add shared session, key %q, err: %v", key, err) } return err } // GetSharedSession retrieves the session with the specified key func GetSharedSession(key string, sessionType SessionType) (Session, error) { return provider.getSharedSession(key, sessionType) } // CleanupSharedSessions removes the shared session with the specified type and // before the specified time func CleanupSharedSessions(sessionType SessionType, before time.Time) error { err := provider.cleanupSharedSessions(sessionType, util.GetTimeAsMsSinceEpoch(before)) if err == nil { providerLog(logger.LevelDebug, "deleted shared sessions before: %v, type: %v", before, sessionType) } else { providerLog(logger.LevelError, "error deleting shared session before %v, type %v: %v", before, sessionType, err) } return err } // ReloadConfig reloads provider 
configuration. // Currently only implemented for memory provider, allows to reload the users // from the configured file, if defined func ReloadConfig() error { return provider.reloadConfig() } // GetShares returns an array of shares respecting limit and offset func GetShares(limit, offset int, order, username string) ([]Share, error) { return provider.getShares(limit, offset, order, username) } // GetAPIKeys returns an array of API keys respecting limit and offset func GetAPIKeys(limit, offset int, order string) ([]APIKey, error) { return provider.getAPIKeys(limit, offset, order) } // GetAdmins returns an array of admins respecting limit and offset func GetAdmins(limit, offset int, order string) ([]Admin, error) { return provider.getAdmins(limit, offset, order) } // GetRoles returns an array of roles respecting limit and offset func GetRoles(limit, offset int, order string, minimal bool) ([]Role, error) { return provider.getRoles(limit, offset, order, minimal) } // GetGroups returns an array of groups respecting limit and offset func GetGroups(limit, offset int, order string, minimal bool) ([]Group, error) { return provider.getGroups(limit, offset, order, minimal) } // GetUsers returns an array of users respecting limit and offset func GetUsers(limit, offset int, order, role string) ([]User, error) { return provider.getUsers(limit, offset, order, role) } // GetUsersForQuotaCheck returns the users with the fields required for a quota check func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { return provider.getUsersForQuotaCheck(toFetch) } // AddFolder adds a new virtual folder. 
func AddFolder(folder *vfs.BaseVirtualFolder, executor, ipAddress, role string) error { folder.Name = config.convertName(folder.Name) err := provider.addFolder(folder) if err == nil { executeAction(operationAdd, executor, ipAddress, actionObjectFolder, folder.Name, role, &wrappedFolder{Folder: *folder}) } return err } // UpdateFolder updates the specified virtual folder func UpdateFolder(folder *vfs.BaseVirtualFolder, users []string, groups []string, executor, ipAddress, role string) error { err := provider.updateFolder(folder) if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectFolder, folder.Name, role, &wrappedFolder{Folder: *folder}) usersInGroups, errGrp := provider.getUsersInGroups(groups) if errGrp == nil { users = append(users, usersInGroups...) users = util.RemoveDuplicates(users, false) } else { providerLog(logger.LevelWarn, "unable to get users in groups %+v: %v", groups, errGrp) } for _, user := range users { provider.setUpdatedAt(user) u, err := provider.userExists(user, "") if err == nil { webDAVUsersCache.swap(&u, "") executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u) } else { RemoveCachedWebDAVUser(user) } } } return err } // DeleteFolder deletes an existing folder. func DeleteFolder(folderName, executor, ipAddress, role string) error { folderName = config.convertName(folderName) folder, err := provider.getFolderByName(folderName) if err != nil { return err } err = provider.deleteFolder(folder) if err == nil { executeAction(operationDelete, executor, ipAddress, actionObjectFolder, folder.Name, role, &wrappedFolder{Folder: folder}) users := folder.Users usersInGroups, errGrp := provider.getUsersInGroups(folder.Groups) if errGrp == nil { users = append(users, usersInGroups...) 
users = util.RemoveDuplicates(users, false) } else { providerLog(logger.LevelWarn, "unable to get users in groups %+v: %v", folder.Groups, errGrp) } for _, user := range users { provider.setUpdatedAt(user) u, err := provider.userExists(user, "") if err == nil { executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u) } RemoveCachedWebDAVUser(user) } delayedQuotaUpdater.resetFolderQuota(folderName) } return err } // GetFolderByName returns the folder with the specified name if any func GetFolderByName(name string) (vfs.BaseVirtualFolder, error) { name = config.convertName(name) return provider.getFolderByName(name) } // GetFolders returns an array of folders respecting limit and offset func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) { return provider.getFolders(limit, offset, order, minimal) } func dumpUsers(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeUsers) { users, err := provider.dumpUsers() if err != nil { return err } data.Users = users } return nil } func dumpFolders(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeFolders) { folders, err := provider.dumpFolders() if err != nil { return err } data.Folders = folders } return nil } func dumpGroups(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeGroups) { groups, err := provider.dumpGroups() if err != nil { return err } data.Groups = groups } return nil } func dumpAdmins(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAdmins) { admins, err := provider.dumpAdmins() if err != nil { return err } data.Admins = admins } return nil } func dumpAPIKeys(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAPIKeys) { apiKeys, err := provider.dumpAPIKeys() if err != nil { return err } 
data.APIKeys = apiKeys } return nil } func dumpShares(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeShares) { shares, err := provider.dumpShares() if err != nil { return err } data.Shares = shares } return nil } func dumpActions(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeActions) { actions, err := provider.dumpEventActions() if err != nil { return err } data.EventActions = actions } return nil } func dumpRules(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRules) { rules, err := provider.dumpEventRules() if err != nil { return err } data.EventRules = rules } return nil } func dumpRoles(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRoles) { roles, err := provider.dumpRoles() if err != nil { return err } data.Roles = roles } return nil } func dumpIPLists(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeIPLists) { ipLists, err := provider.dumpIPListEntries() if err != nil { return err } data.IPLists = ipLists } return nil } func dumpConfigs(data *BackupData, scopes []string) error { if len(scopes) == 0 || slices.Contains(scopes, DumpScopeConfigs) { configs, err := provider.getConfigs() if err != nil { return err } data.Configs = &configs } return nil } // DumpData returns a dump containing the requested scopes. 
// Empty scopes means all func DumpData(scopes []string) (BackupData, error) { data := BackupData{ Version: DumpVersion, } if err := dumpGroups(&data, scopes); err != nil { return data, err } if err := dumpUsers(&data, scopes); err != nil { return data, err } if err := dumpFolders(&data, scopes); err != nil { return data, err } if err := dumpAdmins(&data, scopes); err != nil { return data, err } if err := dumpAPIKeys(&data, scopes); err != nil { return data, err } if err := dumpShares(&data, scopes); err != nil { return data, err } if err := dumpActions(&data, scopes); err != nil { return data, err } if err := dumpRules(&data, scopes); err != nil { return data, err } if err := dumpRoles(&data, scopes); err != nil { return data, err } if err := dumpIPLists(&data, scopes); err != nil { return data, err } if err := dumpConfigs(&data, scopes); err != nil { return data, err } return data, nil } // ParseDumpData tries to parse data as BackupData func ParseDumpData(data []byte) (BackupData, error) { var dump BackupData err := json.Unmarshal(data, &dump) if err != nil { return dump, err } if dump.Version < 17 { providerLog(logger.LevelInfo, "updating placeholders for actions restored from dump version %d", dump.Version) eventActions, err := updateEventActionPlaceholders(dump.EventActions) if err != nil { return dump, fmt.Errorf("unable to update event action placeholders for dump version %d: %w", dump.Version, err) } dump.EventActions = eventActions } return dump, err } // GetProviderConfig returns the current provider configuration func GetProviderConfig() Config { return config } // GetProviderStatus returns an error if the provider is not available func GetProviderStatus() ProviderStatus { err := provider.checkAvailability() status := ProviderStatus{ Driver: config.Driver, } if err == nil { status.IsActive = true } else { status.IsActive = false status.Error = err.Error() } return status } // Close releases all provider resources. // This method is used in test cases. 
// Closing an uninitialized provider is not supported func Close() error { stopScheduler() return provider.close() } func createProvider(basePath string) error { sqlPlaceholders = getSQLPlaceholders() if err := validateSQLTablesPrefix(); err != nil { return err } logSender = fmt.Sprintf("dataprovider_%v", config.Driver) switch config.Driver { case SQLiteDataProviderName: return initializeSQLiteProvider(basePath) case PGSQLDataProviderName, CockroachDataProviderName: return initializePGSQLProvider() case MySQLDataProviderName: return initializeMySQLProvider() case BoltDataProviderName: return initializeBoltProvider(basePath) case MemoryDataProviderName: if err := initializeMemoryProvider(basePath); err != nil { logger.Warn(logSender, "", "provider initialized but data loading failed: %v", err) logger.WarnToConsole("provider initialized but data loading failed: %v", err) } return nil default: return fmt.Errorf("unsupported data provider: %v", config.Driver) } } func copyBaseUserFilters(in sdk.BaseUserFilters) sdk.BaseUserFilters { filters := sdk.BaseUserFilters{} filters.MaxUploadFileSize = in.MaxUploadFileSize filters.TLSUsername = in.TLSUsername filters.UserType = in.UserType filters.AllowedIP = make([]string, len(in.AllowedIP)) copy(filters.AllowedIP, in.AllowedIP) filters.DeniedIP = make([]string, len(in.DeniedIP)) copy(filters.DeniedIP, in.DeniedIP) filters.DeniedLoginMethods = make([]string, len(in.DeniedLoginMethods)) copy(filters.DeniedLoginMethods, in.DeniedLoginMethods) filters.FilePatterns = make([]sdk.PatternsFilter, len(in.FilePatterns)) copy(filters.FilePatterns, in.FilePatterns) filters.DeniedProtocols = make([]string, len(in.DeniedProtocols)) copy(filters.DeniedProtocols, in.DeniedProtocols) filters.TwoFactorAuthProtocols = make([]string, len(in.TwoFactorAuthProtocols)) copy(filters.TwoFactorAuthProtocols, in.TwoFactorAuthProtocols) filters.Hooks.ExternalAuthDisabled = in.Hooks.ExternalAuthDisabled filters.Hooks.PreLoginDisabled = 
in.Hooks.PreLoginDisabled filters.Hooks.CheckPasswordDisabled = in.Hooks.CheckPasswordDisabled filters.DisableFsChecks = in.DisableFsChecks filters.StartDirectory = in.StartDirectory filters.FTPSecurity = in.FTPSecurity filters.IsAnonymous = in.IsAnonymous filters.AllowAPIKeyAuth = in.AllowAPIKeyAuth filters.ExternalAuthCacheTime = in.ExternalAuthCacheTime filters.DefaultSharesExpiration = in.DefaultSharesExpiration filters.MaxSharesExpiration = in.MaxSharesExpiration filters.PasswordExpiration = in.PasswordExpiration filters.PasswordStrength = in.PasswordStrength filters.WebClient = make([]string, len(in.WebClient)) copy(filters.WebClient, in.WebClient) filters.TLSCerts = make([]string, len(in.TLSCerts)) copy(filters.TLSCerts, in.TLSCerts) filters.BandwidthLimits = make([]sdk.BandwidthLimit, 0, len(in.BandwidthLimits)) for _, limit := range in.BandwidthLimits { bwLimit := sdk.BandwidthLimit{ UploadBandwidth: limit.UploadBandwidth, DownloadBandwidth: limit.DownloadBandwidth, Sources: make([]string, 0, len(limit.Sources)), } bwLimit.Sources = make([]string, len(limit.Sources)) copy(bwLimit.Sources, limit.Sources) filters.BandwidthLimits = append(filters.BandwidthLimits, bwLimit) } filters.AccessTime = make([]sdk.TimePeriod, 0, len(in.AccessTime)) for _, period := range in.AccessTime { filters.AccessTime = append(filters.AccessTime, sdk.TimePeriod{ DayOfWeek: period.DayOfWeek, From: period.From, To: period.To, }) } return filters } func buildUserHomeDir(user *User) { if user.HomeDir == "" { if config.UsersBaseDir != "" { user.HomeDir = filepath.Join(config.UsersBaseDir, user.Username) return } switch user.FsConfig.Provider { case sdk.SFTPFilesystemProvider, sdk.S3FilesystemProvider, sdk.AzureBlobFilesystemProvider, sdk.GCSFilesystemProvider, sdk.HTTPFilesystemProvider: if tempPath != "" { user.HomeDir = filepath.Join(tempPath, user.Username) } else { user.HomeDir = filepath.Join(os.TempDir(), user.Username) } } } else { user.HomeDir = filepath.Clean(user.HomeDir) } } 
func validateFolderQuotaLimits(folder vfs.VirtualFolder) error { if folder.QuotaSize < -1 { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid quota_size: %v folder path %q", folder.QuotaSize, folder.MappedPath)), util.I18nErrorFolderQuotaSizeInvalid, ) } if folder.QuotaFiles < -1 { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid quota_file: %v folder path %q", folder.QuotaFiles, folder.MappedPath)), util.I18nErrorFolderQuotaFileInvalid, ) } if (folder.QuotaSize == -1 && folder.QuotaFiles != -1) || (folder.QuotaFiles == -1 && folder.QuotaSize != -1) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("virtual folder quota_size and quota_files must be both -1 or >= 0, quota_size: %v quota_files: %v", folder.QuotaFiles, folder.QuotaSize)), util.I18nErrorFolderQuotaInvalid, ) } return nil } func validateUserGroups(user *User) error { if len(user.Groups) == 0 { return nil } hasPrimary := false groupNames := make(map[string]bool) for _, g := range user.Groups { if g.Type < sdk.GroupTypePrimary || g.Type > sdk.GroupTypeMembership { return util.NewValidationError(fmt.Sprintf("invalid group type: %v", g.Type)) } if g.Type == sdk.GroupTypePrimary { if hasPrimary { return util.NewI18nError( util.NewValidationError("only one primary group is allowed"), util.I18nErrorPrimaryGroup, ) } hasPrimary = true } if groupNames[g.Name] { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("the group %q is duplicated", g.Name)), util.I18nErrorDuplicateGroup, ) } groupNames[g.Name] = true } return nil } func validateAssociatedVirtualFolders(vfolders []vfs.VirtualFolder) ([]vfs.VirtualFolder, error) { if len(vfolders) == 0 { return []vfs.VirtualFolder{}, nil } var virtualFolders []vfs.VirtualFolder folderNames := make(map[string]bool) for _, v := range vfolders { v.Name = config.convertName(v.Name) if v.VirtualPath == "" { return nil, util.NewI18nError( util.NewValidationError("mount/virtual path is mandatory"), 
util.I18nErrorFolderMountPathRequired, ) } cleanedVPath := util.CleanPath(v.VirtualPath) if err := validateFolderQuotaLimits(v); err != nil { return nil, err } if v.Name == "" { return nil, util.NewI18nError(util.NewValidationError("folder name is mandatory"), util.I18nErrorFolderNameRequired) } if folderNames[v.Name] { return nil, util.NewI18nError( util.NewValidationError(fmt.Sprintf("the folder %q is duplicated", v.Name)), util.I18nErrorDuplicatedFolders, ) } for _, vFolder := range virtualFolders { if util.IsDirOverlapped(vFolder.VirtualPath, cleanedVPath, false, "/") { return nil, util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid virtual folder %q, it overlaps with virtual folder %q", v.VirtualPath, vFolder.VirtualPath)), util.I18nErrorOverlappedFolders, ) } } virtualFolders = append(virtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: v.Name, }, VirtualPath: cleanedVPath, QuotaSize: v.QuotaSize, QuotaFiles: v.QuotaFiles, }) folderNames[v.Name] = true } return virtualFolders, nil } func validateUserTOTPConfig(c *UserTOTPConfig, username string) error { if !c.Enabled { c.ConfigName = "" c.Secret = kms.NewEmptySecret() c.Protocols = nil return nil } if c.ConfigName == "" { return util.NewValidationError("totp: config name is mandatory") } if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) { return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName)) } if c.Secret.IsEmpty() { return util.NewValidationError("totp: secret is mandatory") } if c.Secret.IsPlain() { c.Secret.SetAdditionalData(username) if err := c.Secret.Encrypt(); err != nil { return util.NewValidationError(fmt.Sprintf("totp: unable to encrypt secret: %v", err)) } } if len(c.Protocols) == 0 { return util.NewValidationError("totp: specify at least one protocol") } for _, protocol := range c.Protocols { if !slices.Contains(MFAProtocols, protocol) { return util.NewValidationError(fmt.Sprintf("totp: invalid 
protocol %q", protocol)) } } return nil } func validateUserRecoveryCodes(user *User) error { for i := 0; i < len(user.Filters.RecoveryCodes); i++ { code := &user.Filters.RecoveryCodes[i] if code.Secret.IsEmpty() { return util.NewValidationError("mfa: recovery code cannot be empty") } if code.Secret.IsPlain() { code.Secret.SetAdditionalData(user.Username) if err := code.Secret.Encrypt(); err != nil { return util.NewValidationError(fmt.Sprintf("mfa: unable to encrypt recovery code: %v", err)) } } } return nil } func validateUserPermissions(permsToCheck map[string][]string) (map[string][]string, error) { permissions := make(map[string][]string) for dir, perms := range permsToCheck { if len(perms) == 0 && dir == "/" { return permissions, util.NewValidationError(fmt.Sprintf("no permissions granted for the directory: %q", dir)) } if len(perms) > len(ValidPerms) { return permissions, util.NewValidationError("invalid permissions") } for _, p := range perms { if !slices.Contains(ValidPerms, p) { return permissions, util.NewValidationError(fmt.Sprintf("invalid permission: %q", p)) } } cleanedDir := filepath.ToSlash(path.Clean(dir)) if cleanedDir != "/" { cleanedDir = strings.TrimSuffix(cleanedDir, "/") } if !path.IsAbs(cleanedDir) { return permissions, util.NewValidationError(fmt.Sprintf("cannot set permissions for non absolute path: %q", dir)) } if dir != cleanedDir && cleanedDir == "/" { return permissions, util.NewValidationError(fmt.Sprintf("cannot set permissions for invalid subdirectory: %q is an alias for \"/\"", dir)) } if slices.Contains(perms, PermAny) { permissions[cleanedDir] = []string{PermAny} } else { permissions[cleanedDir] = util.RemoveDuplicates(perms, false) } } return permissions, nil } func validatePermissions(user *User) error { if len(user.Permissions) == 0 { return util.NewI18nError(util.NewValidationError("please grant some permissions to this user"), util.I18nErrorNoPermission) } if _, ok := user.Permissions["/"]; !ok { return 
util.NewI18nError(util.NewValidationError("permissions for the root dir \"/\" must be set"), util.I18nErrorNoRootPermission) } permissions, err := validateUserPermissions(user.Permissions) if err != nil { return util.NewI18nError(err, util.I18nErrorGenericPermission) } user.Permissions = permissions return nil } func validatePublicKeys(user *User) error { if len(user.PublicKeys) == 0 { user.PublicKeys = []string{} } var validatedKeys []string for idx, key := range user.PublicKeys { if key == "" { continue } out, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key)) if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("error parsing public key at position %d: %v", idx, err)), util.I18nErrorPubKeyInvalid, ) } if out.Type() == ssh.InsecureKeyAlgoDSA { //nolint:staticcheck providerLog(logger.LevelError, "dsa public key not accepted, position: %d", idx) return util.NewI18nError( util.NewValidationError(fmt.Sprintf("DSA key format is insecure and it is not allowed for key at position %d", idx)), util.I18nErrorKeyInsecure, ) } if k, ok := out.(ssh.CryptoPublicKey); ok { cryptoKey := k.CryptoPublicKey() if rsaKey, ok := cryptoKey.(*rsa.PublicKey); ok { if size := rsaKey.N.BitLen(); size < 2048 { providerLog(logger.LevelError, "rsa key with size %d at position %d not accepted, minimum 2048", size, idx) return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid size %d for rsa key at position %d, minimum 2048", size, idx)), util.I18nErrorKeySizeInvalid, ) } } } validatedKeys = append(validatedKeys, key) } user.PublicKeys = util.RemoveDuplicates(validatedKeys, false) return nil } func validateFiltersPatternExtensions(baseFilters *sdk.BaseUserFilters) error { if len(baseFilters.FilePatterns) == 0 { baseFilters.FilePatterns = []sdk.PatternsFilter{} return nil } filteredPaths := []string{} var filters []sdk.PatternsFilter for _, f := range baseFilters.FilePatterns { cleanedPath := filepath.ToSlash(path.Clean(f.Path)) if !path.IsAbs(cleanedPath) 
{ return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid path %q for file patterns filter", f.Path)), util.I18nErrorFilePatternPathInvalid, ) } if slices.Contains(filteredPaths, cleanedPath) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("duplicate file patterns filter for path %q", f.Path)), util.I18nErrorFilePatternDuplicated, ) } if len(f.AllowedPatterns) == 0 && len(f.DeniedPatterns) == 0 { return util.NewValidationError(fmt.Sprintf("empty file patterns filter for path %q", f.Path)) } if f.DenyPolicy < sdk.DenyPolicyDefault || f.DenyPolicy > sdk.DenyPolicyHide { return util.NewValidationError(fmt.Sprintf("invalid deny policy %v for path %q", f.DenyPolicy, f.Path)) } f.Path = cleanedPath allowed := make([]string, 0, len(f.AllowedPatterns)) denied := make([]string, 0, len(f.DeniedPatterns)) for _, pattern := range f.AllowedPatterns { _, err := path.Match(pattern, "abc") if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid file pattern filter %q", pattern)), util.I18nErrorFilePatternInvalid, ) } allowed = append(allowed, strings.ToLower(pattern)) } for _, pattern := range f.DeniedPatterns { _, err := path.Match(pattern, "abc") if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid file pattern filter %q", pattern)), util.I18nErrorFilePatternInvalid, ) } denied = append(denied, strings.ToLower(pattern)) } f.AllowedPatterns = util.RemoveDuplicates(allowed, false) f.DeniedPatterns = util.RemoveDuplicates(denied, false) filters = append(filters, f) filteredPaths = append(filteredPaths, cleanedPath) } baseFilters.FilePatterns = filters return nil } func checkEmptyFiltersStruct(filters *sdk.BaseUserFilters) { if len(filters.AllowedIP) == 0 { filters.AllowedIP = []string{} } if len(filters.DeniedIP) == 0 { filters.DeniedIP = []string{} } if len(filters.DeniedLoginMethods) == 0 { filters.DeniedLoginMethods = []string{} } if len(filters.DeniedProtocols) == 0 { 
filters.DeniedProtocols = []string{} } } func validateIPFilters(filters *sdk.BaseUserFilters) error { filters.DeniedIP = util.RemoveDuplicates(filters.DeniedIP, false) for _, IPMask := range filters.DeniedIP { _, _, err := net.ParseCIDR(IPMask) if err != nil { return util.NewValidationError(fmt.Sprintf("could not parse denied IP/Mask %q: %v", IPMask, err)) } } filters.AllowedIP = util.RemoveDuplicates(filters.AllowedIP, false) for _, IPMask := range filters.AllowedIP { _, _, err := net.ParseCIDR(IPMask) if err != nil { return util.NewValidationError(fmt.Sprintf("could not parse allowed IP/Mask %q: %v", IPMask, err)) } } return nil } func validateBandwidthLimit(bl sdk.BandwidthLimit) error { if len(bl.Sources) == 0 { return util.NewValidationError("no bandwidth limit source specified") } for _, source := range bl.Sources { _, _, err := net.ParseCIDR(source) if err != nil { return util.NewValidationError(fmt.Sprintf("could not parse bandwidth limit source %q: %v", source, err)) } } return nil } func validateBandwidthLimitsFilter(filters *sdk.BaseUserFilters) error { for idx, bandwidthLimit := range filters.BandwidthLimits { if err := validateBandwidthLimit(bandwidthLimit); err != nil { return err } if bandwidthLimit.DownloadBandwidth < 0 { filters.BandwidthLimits[idx].DownloadBandwidth = 0 } if bandwidthLimit.UploadBandwidth < 0 { filters.BandwidthLimits[idx].UploadBandwidth = 0 } } return nil } func updateFiltersValues(filters *sdk.BaseUserFilters) { if filters.StartDirectory != "" { filters.StartDirectory = util.CleanPath(filters.StartDirectory) if filters.StartDirectory == "/" { filters.StartDirectory = "" } } } func validateFilterProtocols(filters *sdk.BaseUserFilters) error { if len(filters.DeniedProtocols) >= len(ValidProtocols) { return util.NewValidationError("invalid denied_protocols") } for _, p := range filters.DeniedProtocols { if !slices.Contains(ValidProtocols, p) { return util.NewValidationError(fmt.Sprintf("invalid denied protocol %q", p)) } } for _, 
p := range filters.TwoFactorAuthProtocols { if !slices.Contains(MFAProtocols, p) { return util.NewValidationError(fmt.Sprintf("invalid two factor protocol %q", p)) } } return nil } func validateTLSCerts(certs []string) ([]string, error) { var validateCerts []string for idx, cert := range certs { if cert == "" { continue } derBlock, _ := pem.Decode([]byte(cert)) if derBlock == nil { return nil, util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid TLS certificate %d", idx)), util.I18nErrorInvalidTLSCert, ) } crt, err := x509.ParseCertificate(derBlock.Bytes) if err != nil { return nil, util.NewI18nError( util.NewValidationError(fmt.Sprintf("error parsing TLS certificate %d", idx)), util.I18nErrorInvalidTLSCert, ) } if crt.PublicKeyAlgorithm == x509.RSA { if rsaCert, ok := crt.PublicKey.(*rsa.PublicKey); ok { if size := rsaCert.N.BitLen(); size < 2048 { providerLog(logger.LevelError, "rsa cert with size %d not accepted, minimum 2048", size) return nil, util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid size %d for rsa cert at position %d, minimum 2048", size, idx)), util.I18nErrorKeySizeInvalid, ) } } } validateCerts = append(validateCerts, cert) } return validateCerts, nil } func validateBaseFilters(filters *sdk.BaseUserFilters) error { checkEmptyFiltersStruct(filters) if err := validateIPFilters(filters); err != nil { return util.NewI18nError(err, util.I18nErrorIPFiltersInvalid) } if err := validateBandwidthLimitsFilter(filters); err != nil { return util.NewI18nError(err, util.I18nErrorSourceBWLimitInvalid) } if len(filters.DeniedLoginMethods) >= len(ValidLoginMethods) { return util.NewValidationError("invalid denied_login_methods") } for _, loginMethod := range filters.DeniedLoginMethods { if !slices.Contains(ValidLoginMethods, loginMethod) { return util.NewValidationError(fmt.Sprintf("invalid login method: %q", loginMethod)) } } if err := validateFilterProtocols(filters); err != nil { return err } if filters.TLSUsername != "" { if 
!slices.Contains(validTLSUsernames, string(filters.TLSUsername)) { return util.NewValidationError(fmt.Sprintf("invalid TLS username: %q", filters.TLSUsername)) } } certs, err := validateTLSCerts(filters.TLSCerts) if err != nil { return err } filters.TLSCerts = certs for _, opts := range filters.WebClient { if !slices.Contains(sdk.WebClientOptions, opts) { return util.NewValidationError(fmt.Sprintf("invalid web client options %q", opts)) } } if filters.MaxSharesExpiration > 0 && filters.MaxSharesExpiration < filters.DefaultSharesExpiration { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("default shares expiration: %d must be less than or equal to max shares expiration: %d", filters.DefaultSharesExpiration, filters.MaxSharesExpiration)), util.I18nErrorShareExpirationInvalid, ) } updateFiltersValues(filters) if err := validateAccessTimeFilters(filters); err != nil { return err } return validateFiltersPatternExtensions(filters) } func isTimeOfDayValid(value string) bool { if len(value) != 5 { return false } parts := strings.Split(value, ":") if len(parts) != 2 { return false } hour, err := strconv.Atoi(parts[0]) if err != nil { return false } if hour < 0 || hour > 23 { return false } minute, err := strconv.Atoi(parts[1]) if err != nil { return false } if minute < 0 || minute > 59 { return false } return true } func validateAccessTimeFilters(filters *sdk.BaseUserFilters) error { for _, period := range filters.AccessTime { if period.DayOfWeek < int(time.Sunday) || period.DayOfWeek > int(time.Saturday) { return util.NewValidationError(fmt.Sprintf("invalid day of week: %d", period.DayOfWeek)) } if !isTimeOfDayValid(period.From) || !isTimeOfDayValid(period.To) { return util.NewI18nError( util.NewValidationError("invalid time of day. Supported format: HH:MM"), util.I18nErrorTimeOfDayInvalid, ) } if period.To <= period.From { return util.NewI18nError( util.NewValidationError("invalid time of day. 
The end time cannot be earlier than the start time"), util.I18nErrorTimeOfDayConflict, ) } } return nil } func validateCombinedUserFilters(user *User) error { if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) { return util.NewI18nError( util.NewValidationError("two-factor authentication cannot be disabled for a user with an active configuration"), util.I18nErrorDisableActive2FA, ) } if user.Filters.RequirePasswordChange && slices.Contains(user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) { return util.NewI18nError( util.NewValidationError("you cannot require password change and at the same time disallow it"), util.I18nErrorPwdChangeConflict, ) } if len(user.Filters.TwoFactorAuthProtocols) > 0 && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) { return util.NewI18nError( util.NewValidationError("you cannot require two-factor authentication and at the same time disallow it"), util.I18nError2FAConflict, ) } return nil } func validateEmails(user *User) error { if user.Email != "" && !util.IsEmailValid(user.Email) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("email %q is not valid", user.Email)), util.I18nErrorInvalidEmail, ) } for _, email := range user.Filters.AdditionalEmails { if !util.IsEmailValid(email) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("email %q is not valid", email)), util.I18nErrorInvalidEmail, ) } } return nil } func validateBaseParams(user *User) error { if user.Username == "" { return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) } if !util.IsNameValid(user.Username) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if err := checkReservedUsernames(user.Username); err != nil { return util.NewI18nError(err, util.I18nErrorReservedUsername) } if err := validateEmails(user); err != nil { return err } if config.NamingRules&1 == 0 && 
!usernameRegex.MatchString(user.Username) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("username %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", user.Username)), util.I18nErrorInvalidUser, ) } if user.hasRedactedSecret() { return util.NewValidationError("cannot save a user with a redacted secret") } if user.HomeDir == "" { return util.NewI18nError(util.NewValidationError("home_dir is mandatory"), util.I18nErrorHomeRequired) } // we can have users with no passwords and public keys, they can authenticate via SSH user certs or OIDC /*if user.Password == "" && len(user.PublicKeys) == 0 { return util.NewValidationError("please set a password or at least a public_key") }*/ if !filepath.IsAbs(user.HomeDir) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("home_dir must be an absolute path, actual value: %v", user.HomeDir)), util.I18nErrorHomeInvalid, ) } if user.DownloadBandwidth < 0 { user.DownloadBandwidth = 0 } if user.UploadBandwidth < 0 { user.UploadBandwidth = 0 } if user.TotalDataTransfer > 0 { // if a total data transfer is defined we reset the separate upload and download limits user.UploadDataTransfer = 0 user.DownloadDataTransfer = 0 } if user.Filters.IsAnonymous { user.setAnonymousSettings() } err := user.FsConfig.Validate(user.GetEncryptionAdditionalData()) if err != nil { return err } return nil } func hashPlainPassword(plainPwd string) (string, error) { if config.PasswordHashing.Algo == HashingAlgoBcrypt { pwd, err := bcrypt.GenerateFromPassword([]byte(plainPwd), config.PasswordHashing.BcryptOptions.Cost) if err != nil { return "", fmt.Errorf("bcrypt hashing error: %w", err) } return util.BytesToString(pwd), nil } pwd, err := argon2id.CreateHash(plainPwd, argon2Params) if err != nil { return "", fmt.Errorf("argon2ID hashing error: %w", err) } return pwd, nil } func createUserPasswordHash(user *User) error { if user.Password != "" && !user.IsPasswordHashed() { for _, g := range user.Groups { if g.Type 
== sdk.GroupTypePrimary { group, err := GroupExists(g.Name) if err != nil { return errors.New("unable to load group password policies") } if minEntropy := group.UserSettings.Filters.PasswordStrength; minEntropy > 0 { if err := passwordvalidator.Validate(user.Password, float64(minEntropy)); err != nil { return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) } } } } if minEntropy := user.getMinPasswordEntropy(); minEntropy > 0 { if err := passwordvalidator.Validate(user.Password, minEntropy); err != nil { return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) } } hashedPwd, err := hashPlainPassword(user.Password) if err != nil { return err } user.Password = hashedPwd user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now()) } return nil } // ValidateFolder returns an error if the folder is not valid // FIXME: this should be defined as Folder struct method func ValidateFolder(folder *vfs.BaseVirtualFolder) error { folder.FsConfig.SetEmptySecretsIfNil() if folder.Name == "" { return util.NewI18nError(util.NewValidationError("folder name is mandatory"), util.I18nErrorNameRequired) } if !util.IsNameValid(folder.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if config.NamingRules&1 == 0 && !usernameRegex.MatchString(folder.Name) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("folder name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", folder.Name)), util.I18nErrorInvalidName, ) } if folder.FsConfig.Provider == sdk.LocalFilesystemProvider || folder.FsConfig.Provider == sdk.CryptedFilesystemProvider || folder.MappedPath != "" { cleanedMPath := filepath.Clean(folder.MappedPath) if !filepath.IsAbs(cleanedMPath) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("invalid folder mapped path %q", folder.MappedPath)), util.I18nErrorInvalidHomeDir, ) } folder.MappedPath = cleanedMPath } if 
folder.HasRedactedSecret() { return errors.New("cannot save a folder with a redacted secret") } return folder.FsConfig.Validate(folder.GetEncryptionAdditionalData()) } // ValidateUser returns an error if the user is not valid // FIXME: this should be defined as User struct method func ValidateUser(user *User) error { user.OIDCCustomFields = nil user.HasPassword = false user.SetEmptySecretsIfNil() user.applyNamingRules() buildUserHomeDir(user) if err := validateBaseParams(user); err != nil { return err } if err := validateUserGroups(user); err != nil { return err } if err := validatePermissions(user); err != nil { return err } if err := validateUserTOTPConfig(&user.Filters.TOTPConfig, user.Username); err != nil { return util.NewI18nError(err, util.I18nError2FAInvalid) } if err := validateUserRecoveryCodes(user); err != nil { return util.NewI18nError(err, util.I18nErrorRecoveryCodesInvalid) } vfolders, err := validateAssociatedVirtualFolders(user.VirtualFolders) if err != nil { return err } user.VirtualFolders = vfolders if user.Status < 0 || user.Status > 1 { return util.NewValidationError(fmt.Sprintf("invalid user status: %v", user.Status)) } if err := createUserPasswordHash(user); err != nil { return err } if err := validatePublicKeys(user); err != nil { return err } if err := validateBaseFilters(&user.Filters.BaseUserFilters); err != nil { return err } if !user.HasExternalAuth() { user.Filters.ExternalAuthCacheTime = 0 } return validateCombinedUserFilters(user) } func isPasswordOK(user *User, password string) (bool, error) { if config.PasswordCaching { found, match := cachedUserPasswords.Check(user.Username, password, user.Password) if found { return match, nil } } match := false updatePwd := true var err error switch { case strings.HasPrefix(user.Password, bcryptPwdPrefix): if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { return match, ErrInvalidCredentials } match = true updatePwd = config.PasswordHashing.Algo != 
HashingAlgoBcrypt case strings.HasPrefix(user.Password, argonPwdPrefix): match, err = argon2id.ComparePasswordAndHash(password, user.Password) if err != nil { providerLog(logger.LevelError, "error comparing password with argon hash: %v", err) return match, err } updatePwd = config.PasswordHashing.Algo != HashingAlgoArgon2ID case util.IsStringPrefixInSlice(user.Password, unixPwdPrefixes): match, err = compareUnixPasswordAndHash(user, password) if err != nil { return match, err } case util.IsStringPrefixInSlice(user.Password, pbkdfPwdPrefixes): match, err = comparePbkdf2PasswordAndHash(password, user.Password) if err != nil { return match, err } case util.IsStringPrefixInSlice(user.Password, digestPwdPrefixes): match = compareDigestPasswordAndHash(user, password) } if err == nil && match { cachedUserPasswords.Add(user.Username, password, user.Password) if updatePwd { convertUserPassword(user.Username, password) } } return match, err } func convertUserPassword(username, plainPwd string) { hashedPwd, err := hashPlainPassword(plainPwd) if err == nil { err = provider.updateUserPassword(username, hashedPwd) } if err != nil { providerLog(logger.LevelWarn, "unable to convert password for user %s: %v", username, err) } else { providerLog(logger.LevelDebug, "password converted for user %s", username) } } func checkUserAndTLSCertificate(user *User, protocol string, tlsCert *x509.Certificate) (User, error) { err := user.LoadAndApplyGroupSettings() if err != nil { return *user, err } err = user.CheckLoginConditions() if err != nil { return *user, err } switch protocol { case protocolFTP, protocolWebDAV: for _, cert := range user.Filters.TLSCerts { derBlock, _ := pem.Decode(util.StringToBytes(cert)) if derBlock != nil && bytes.Equal(derBlock.Bytes, tlsCert.Raw) { return *user, nil } } if user.Filters.TLSUsername == sdk.TLSUsernameCN { if user.Username == tlsCert.Subject.CommonName { return *user, nil } return *user, fmt.Errorf("CN %q does not match username %q", 
tlsCert.Subject.CommonName, user.Username)
		}
		return *user, errors.New("TLS certificate is not valid")
	default:
		return *user, fmt.Errorf("certificate authentication is not supported for protocol %v", protocol)
	}
}

// checkUserAndPass validates a password login for user and returns a copy of
// *user together with the outcome.
//
// Steps, in order: apply group settings, check login conditions, reject users
// that must change their password (HTTP is exempt), accept anonymous users
// without further checks, strip and validate an FTP TOTP passcode (see
// checkUserPasscode), consult the check password hook unless it is disabled
// for the user, and finally verify the password with isPasswordOK. A failed
// password match is reported as ErrInvalidCredentials.
func checkUserAndPass(user *User, password, ip, protocol string) (User, error) {
	err := user.LoadAndApplyGroupSettings()
	if err != nil {
		return *user, err
	}
	err = user.CheckLoginConditions()
	if err != nil {
		return *user, err
	}
	if protocol != protocolHTTP && user.MustChangePassword() {
		return *user, errors.New("login not allowed, password change required")
	}
	if user.Filters.IsAnonymous {
		user.setAnonymousSettings()
		return *user, nil
	}
	password, err = checkUserPasscode(user, password, protocol)
	if err != nil {
		return *user, ErrInvalidCredentials
	}
	if user.Password == "" || strings.TrimSpace(password) == "" {
		return *user, errors.New("credentials cannot be null or empty")
	}
	if !user.Filters.Hooks.CheckPasswordDisabled {
		hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol)
		if err != nil {
			providerLog(logger.LevelDebug, "error executing check password hook for user %q, ip %v, protocol %v: %v",
				user.Username, ip, protocol, err)
			return *user, errors.New("unable to check credentials")
		}
		switch hookResponse.Status {
		case -1:
			// no hook configured
		case 1:
			providerLog(logger.LevelDebug, "password accepted by check password hook for user %q, ip %v, protocol %v",
				user.Username, ip, protocol)
			return *user, nil
		case 2:
			providerLog(logger.LevelDebug, "partial success from check password hook for user %q, ip %v, protocol %v",
				user.Username, ip, protocol)
			// partial success: the hook-provided ToVerify value is verified locally below
			password = hookResponse.ToVerify
		default:
			providerLog(logger.LevelDebug, "password rejected by check password hook for user %q, ip %v, protocol %v, status: %v",
				user.Username, ip, protocol, hookResponse.Status)
			return *user, ErrInvalidCredentials
		}
	}

	match, err := isPasswordOK(user, password)
	if !match {
		err = ErrInvalidCredentials
	}
	return *user, err
}

// checkUserPasscode handles passcode-suffixed FTP passwords.
//
// When TOTP is enabled for the user and FTP is one of its TOTP protocols, the
// last six characters of password are validated as a TOTP passcode and the
// remaining prefix is returned as the real password. In every other case
// password is returned unchanged.
func checkUserPasscode(user *User, password, protocol string) (string, error) {
	if user.Filters.TOTPConfig.Enabled {
		switch protocol {
		case protocolFTP:
			if slices.Contains(user.Filters.TOTPConfig.Protocols, protocol) {
				// the TOTP passcode has six digits
				pwdLen := len(password)
				if pwdLen < 7 {
					providerLog(logger.LevelDebug, "password len %v is too short to contain a passcode, user %q, protocol %v",
						pwdLen, user.Username, protocol)
					return "", util.NewValidationError("password too short, cannot contain the passcode")
				}
				err := user.Filters.TOTPConfig.Secret.TryDecrypt()
				if err != nil {
					providerLog(logger.LevelError, "unable to decrypt TOTP secret for user %q, protocol %v, err: %v",
						user.Username, protocol, err)
					return "", err
				}
				pwd := password[0:(pwdLen - 6)]
				passcode := password[(pwdLen - 6):]
				match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, passcode,
					user.Filters.TOTPConfig.Secret.GetPayload())
				if !match || err != nil {
					providerLog(logger.LevelWarn, "invalid passcode for user %q, protocol %v, err: %v",
						user.Username, protocol, err)
					return "", util.NewValidationError("invalid passcode")
				}
				return pwd, nil
			}
		}
	}
	return password, nil
}

// checkUserAndPubKey validates a public key login for user.
//
// SSH certificates are accepted here without comparing keys. Otherwise pubKey
// (wire format) must equal one of the user's stored authorized keys. On
// success the second return value is "<SHA256 fingerprint>:<key comment>".
func checkUserAndPubKey(user *User, pubKey []byte, isSSHCert bool) (User, string, error) {
	err := user.LoadAndApplyGroupSettings()
	if err != nil {
		return *user, "", err
	}
	err = user.CheckLoginConditions()
	if err != nil {
		return *user, "", err
	}
	if isSSHCert {
		return *user, "", nil
	}
	if len(user.PublicKeys) == 0 {
		return *user, "", ErrInvalidCredentials
	}
	for idx, key := range user.PublicKeys {
		storedKey, comment, _, _, err := ssh.ParseAuthorizedKey(util.StringToBytes(key))
		if err != nil {
			providerLog(logger.LevelError, "error parsing stored public key %d for user %s: %v", idx, user.Username, err)
			return *user, "", err
		}
		if bytes.Equal(storedKey.Marshal(), pubKey) {
			return *user, fmt.Sprintf("%s:%s", ssh.FingerprintSHA256(storedKey), comment), nil
		}
	}
	return *user, "", ErrInvalidCredentials
}

// compareDigestPasswordAndHash reports whether password matches user.Password
// when the stored value is an unsalted hex digest prefixed with
// md5DigestPwdPrefix, sha256DigestPwdPrefix or sha512DigestPwdPrefix. Any other
// prefix yields false.
//
// The final comparison is constant time, so the response time does not reveal
// how much of the stored digest was matched.
func compareDigestPasswordAndHash(user *User, password string) bool {
	var computed string
	switch {
	case strings.HasPrefix(user.Password, md5DigestPwdPrefix):
		sum := md5.Sum([]byte(password))
		computed = fmt.Sprintf("%s%x", md5DigestPwdPrefix, sum[:])
	case strings.HasPrefix(user.Password, sha256DigestPwdPrefix):
		sum := sha256.Sum256([]byte(password))
		computed = fmt.Sprintf("%s%x", sha256DigestPwdPrefix, sum[:])
	case strings.HasPrefix(user.Password, sha512DigestPwdPrefix):
		sum := sha512.Sum512([]byte(password))
		computed = fmt.Sprintf("%s%x", sha512DigestPwdPrefix, sum[:])
	default:
		return false
	}
	return subtle.ConstantTimeCompare([]byte(computed), []byte(user.Password)) == 1
}

// compareUnixPasswordAndHash verifies password against a crypt(3)-style hash
// stored in user.Password.
//
// Supported formats: yescrypt, sha512-crypt, sha256-crypt, md5-crypt and
// Apache apr1. Any other format returns an error.
func compareUnixPasswordAndHash(user *User, password string) (bool, error) {
	if strings.HasPrefix(user.Password, yescryptPwdPrefix) {
		return compareYescryptPassword(user.Password, password)
	}
	var crypter crypt.Crypter
	if strings.HasPrefix(user.Password, sha512cryptPwdPrefix) {
		crypter = sha512_crypt.New()
	} else if strings.HasPrefix(user.Password, sha256cryptPwdPrefix) {
		crypter = sha256_crypt.New()
	} else if strings.HasPrefix(user.Password, md5cryptPwdPrefix) {
		crypter = md5_crypt.New()
	} else if strings.HasPrefix(user.Password, md5cryptApr1PwdPrefix) {
		crypter = apr1_crypt.New()
	} else {
		return false, errors.New("unix crypt: invalid or unsupported hash format")
	}
	if err := crypter.Verify(user.Password, []byte(password)); err != nil {
		return false, err
	}
	return true, nil
}

// comparePbkdf2PasswordAndHash verifies password against a PBKDF2 hash made of
// five "$"-separated fields.
//
// Fields used: [1] algorithm name (only used in the error message), [2]
// iteration count, [3] salt (base64 for prefixes in pbkdfPwdB64SaltPrefixes,
// raw otherwise), [4] base64 derived key. The hash function is chosen from the
// hashedPassword prefix (SHA-256, SHA-512 or SHA-1).
func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) {
	vals := strings.Split(hashedPassword, "$")
	if len(vals) != 5 {
		return false, fmt.Errorf("pbkdf2: hash is not in the correct format")
	}
	iterations, err := strconv.Atoi(vals[2])
	if err != nil {
		return false, err
	}
	if iterations < 1 {
		return false, fmt.Errorf("pbkdf2: invalid iterations count %d", iterations)
	}
	expected, err := base64.StdEncoding.DecodeString(vals[4])
	if err != nil {
		return false, err
	}
	// an empty derived key would make the constant time comparison below
	// succeed for any password, so reject it explicitly
	if len(expected) == 0 {
		return false, errors.New("pbkdf2: empty derived key")
	}
	var salt []byte
	if util.IsStringPrefixInSlice(hashedPassword, pbkdfPwdB64SaltPrefixes) {
		salt, err = base64.StdEncoding.DecodeString(vals[3])
		if err != nil {
			return false, err
		}
	} else {
		salt = []byte(vals[3])
	}
	var hashFunc func() hash.Hash
	if strings.HasPrefix(hashedPassword, pbkdf2SHA256Prefix) || strings.HasPrefix(hashedPassword, pbkdf2SHA256B64SaltPrefix) {
		hashFunc = sha256.New
	} else if strings.HasPrefix(hashedPassword, pbkdf2SHA512Prefix) {
		hashFunc = sha512.New
	} else if strings.HasPrefix(hashedPassword, pbkdf2SHA1Prefix) {
		hashFunc = sha1.New
	} else {
		return false, fmt.Errorf("pbkdf2: invalid or unsupported hash format %v", vals[1])
	}

	df := pbkdf2.Key([]byte(password), salt, iterations, len(expected), hashFunc)
	return subtle.ConstantTimeCompare(df, expected) == 1, nil
}

// getSSLMode maps the numeric config.SSLMode to the driver-specific TLS
// setting string.
//
// PostgreSQL/CockroachDB use sslmode names. MySQL uses tls values, and returns
// "custom" when config.requireCustomTLSForMySQL() is true. Unknown
// combinations return "".
func getSSLMode() string {
	switch config.Driver {
	case PGSQLDataProviderName, CockroachDataProviderName:
		switch config.SSLMode {
		case 0:
			return "disable"
		case 1:
			return "require"
		case 2:
			return "verify-ca"
		case 3:
			return "verify-full"
		case 4:
			return "prefer"
		case 5:
			return "allow"
		}
	case MySQLDataProviderName:
		if config.requireCustomTLSForMySQL() {
			return "custom"
		}
		switch config.SSLMode {
		case 0:
			return "false"
		case 1:
			return "true"
		case 2:
			return "skip-verify"
		case 3:
			return "preferred"
		}
	}
	return ""
}

// terminateInteractiveAuthProgram kills the keyboard interactive auth program
// unless isFinished is true; a kill failure is only logged.
func terminateInteractiveAuthProgram(cmd *exec.Cmd, isFinished bool) {
	if isFinished {
		return
	}
	providerLog(logger.LevelInfo, "kill interactive auth program after an unexpected error")
	err := cmd.Process.Kill()
	if err != nil {
		providerLog(logger.LevelDebug, "error killing interactive auth program: %v", err)
	}
}

// sendKeyboardAuthHTTPReq POSTs request as JSON to url and decodes the JSON
// response body. Any status other than 200 is an error.
func sendKeyboardAuthHTTPReq(url string, request *plugin.KeyboardAuthRequest) (*plugin.KeyboardAuthResponse, error) {
	reqAsJSON, err := json.Marshal(request)
	if err != nil {
		providerLog(logger.LevelError, "error serializing keyboard interactive auth request: %v", err)
		return nil, err
	}
	resp, err := httpclient.Post(url, "application/json", bytes.NewBuffer(reqAsJSON))
	if err != nil {
		providerLog(logger.LevelError, "error getting keyboard interactive auth hook HTTP response: %v", err)
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("wrong keyboard interactive auth http status code: %v, expected 200", resp.StatusCode)
	}
	var response plugin.KeyboardAuthResponse
	err = render.DecodeJSON(resp.Body, &response)
	return &response, err
}

// doBuiltinKeyboardInteractiveAuth implements keyboard interactive auth
// without external hooks.
//
// The password prompt is skipped only when isPartialAuth is true and the user
// has TOTP enabled for SSH. The TOTP second factor, if configured, is always
// checked. It returns 1 on success.
func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractiveChallenge,
	ip, protocol string, isPartialAuth bool,
) (int, error) {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		return 0, err
	}
	hasSecondFactor := user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH)
	if !isPartialAuth || !hasSecondFactor {
		answers, err := client("", "", []string{"Password: "}, []bool{false})
		if err != nil {
			return 0, err
		}
		if len(answers) != 1 {
			return 0, fmt.Errorf("unexpected number of answers: %d", len(answers))
		}
		_, err = checkUserAndPass(user, answers[0], ip, protocol)
		if err != nil {
			return 0, err
		}
	}
	return checkKeyboardInteractiveSecondFactor(user, client, protocol)
}

// checkKeyboardInteractiveSecondFactor prompts for and validates a TOTP
// passcode when TOTP is enabled for SSH. It returns 1 when the passcode is
// valid or no second factor is configured.
func checkKeyboardInteractiveSecondFactor(user *User, client ssh.KeyboardInteractiveChallenge, protocol string) (int, error) {
	if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
		return 1, nil
	}
	err := user.Filters.TOTPConfig.Secret.TryDecrypt()
	if err != nil {
		providerLog(logger.LevelError, "unable to decrypt TOTP secret for user %q, protocol %v, err: %v",
			user.Username, protocol, err)
		return 0, err
	}
	answers, err := client("", "", []string{"Authentication code: "}, []bool{false})
	if err != nil {
		return 0, err
	}
	if len(answers) != 1 {
		return 0, fmt.Errorf("unexpected number of answers: %v", len(answers))
	}
	match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, answers[0],
		user.Filters.TOTPConfig.Secret.GetPayload())
	if !match || err != nil {
		providerLog(logger.LevelWarn, "invalid passcode for user %q, protocol %v, err: %v",
			user.Username, protocol, err)
		return 0, util.NewValidationError("invalid passcode")
	}
	return 1, nil
}

// executeKeyboardInteractivePlugin drives a keyboard interactive exchange
// through the plugin handler until it reports a non-zero AuthResult.
func executeKeyboardInteractivePlugin(user *User, client
ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
	authResult := 0
	requestID := xid.New().String()
	authStep := 1
	req := &plugin.KeyboardAuthRequest{
		Username:  user.Username,
		IP:        ip,
		Password:  user.Password,
		RequestID: requestID,
		Step:      authStep,
	}
	var response *plugin.KeyboardAuthResponse
	var err error
	for {
		response, err = plugin.Handler.ExecuteKeyboardInteractiveStep(req)
		if err != nil {
			return authResult, err
		}
		if response.AuthResult != 0 {
			return response.AuthResult, err
		}
		if err = response.Validate(); err != nil {
			providerLog(logger.LevelInfo, "invalid response from keyboard interactive plugin: %v", err)
			return authResult, err
		}
		answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol)
		if err != nil {
			return authResult, err
		}
		authStep++
		// NOTE(review): follow-up requests do not carry IP, unlike the first one — confirm this is intended
		req = &plugin.KeyboardAuthRequest{
			RequestID: requestID,
			Step:      authStep,
			Username:  user.Username,
			Password:  user.Password,
			Answers:   answers,
			Questions: response.Questions,
		}
	}
}

// executeKeyboardInteractiveHTTPHook drives a keyboard interactive exchange
// with the HTTP hook at authHook, POSTing one request per step until the hook
// reports a non-zero AuthResult.
func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
	authResult := 0
	requestID := xid.New().String()
	authStep := 1
	req := &plugin.KeyboardAuthRequest{
		Username:  user.Username,
		IP:        ip,
		Password:  user.Password,
		RequestID: requestID,
		Step:      authStep,
	}
	var response *plugin.KeyboardAuthResponse
	var err error
	for {
		response, err = sendKeyboardAuthHTTPReq(authHook, req)
		if err != nil {
			return authResult, err
		}
		if response.AuthResult != 0 {
			return response.AuthResult, err
		}
		if err = response.Validate(); err != nil {
			providerLog(logger.LevelInfo, "invalid response from keyboard interactive http hook: %v", err)
			return authResult, err
		}
		answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol)
		if err != nil {
			return authResult, err
		}
		authStep++
		req = &plugin.KeyboardAuthRequest{
			RequestID: requestID,
			Step:      authStep,
			Username:  user.Username,
			Password:  user.Password,
			Answers:   answers,
			Questions: response.Questions,
		}
	}
}

// getKeyboardInteractiveAnswers sends response's questions to the SSH client
// and returns its answers.
//
// When exactly one answer is expected and response.CheckPwd > 0, the answer
// is validated here: as a TOTP passcode when CheckPwd == 2, otherwise as the
// user's password via checkUserAndPass. The answer is then replaced with "OK"
// before being returned.
func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response *plugin.KeyboardAuthResponse,
	user *User, ip, protocol string,
) ([]string, error) {
	questions := response.Questions
	answers, err := client("", response.Instruction, questions, response.Echos)
	if err != nil {
		providerLog(logger.LevelInfo, "error getting interactive auth client response: %v", err)
		return answers, err
	}
	if len(answers) != len(questions) {
		// report only the counts: the answers may contain passwords or passcodes
		// and this error is written to the log
		err = fmt.Errorf("client answers do not match questions, expected: %d answers, actual: %d",
			len(questions), len(answers))
		providerLog(logger.LevelInfo, "keyboard interactive auth error: %v", err)
		return answers, err
	}
	if len(answers) == 1 && response.CheckPwd > 0 {
		if response.CheckPwd == 2 {
			if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) {
				providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to check TOTP passcode, TOTP is not enabled for user %q",
					user.Username)
				return answers, errors.New("TOTP not enabled for SSH protocol")
			}
			err := user.Filters.TOTPConfig.Secret.TryDecrypt()
			if err != nil {
				providerLog(logger.LevelError, "unable to decrypt TOTP secret for user %q, protocol %v, err: %v",
					user.Username, protocol, err)
				return answers, fmt.Errorf("unable to decrypt TOTP secret: %w", err)
			}
			match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, answers[0],
				user.Filters.TOTPConfig.Secret.GetPayload())
			if !match || err != nil {
				providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to validate passcode for user %q, match? %v, err: %v",
					user.Username, match, err)
				return answers, errors.New("unable to validate TOTP passcode")
			}
		} else {
			_, err = checkUserAndPass(user, answers[0], ip, protocol)
			providerLog(logger.LevelInfo, "interactive auth hook requested password validation for user %q, validation error: %v",
				user.Username, err)
			if err != nil {
				return answers, err
			}
		}
		answers[0] = "OK"
	}
	return answers, err
}

// handleProgramInteractiveQuestions collects the client's answers for response
// and writes them, one per line, to the auth program's stdin. Lines end with
// "\r\n" on Windows and "\n" elsewhere.
func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response *plugin.KeyboardAuthResponse,
	user *User, stdin io.WriteCloser, ip, protocol string,
) error {
	answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol)
	if err != nil {
		return err
	}
	for _, answer := range answers {
		if runtime.GOOS == "windows" {
			answer += "\r"
		}
		answer += "\n"
		_, err = stdin.Write([]byte(answer))
		if err != nil {
			providerLog(logger.LevelError, "unable to write client answer to keyboard interactive program: %v", err)
			return err
		}
	}
	return nil
}

// executeKeyboardInteractiveProgram runs authHook as a child process.
//
// It reads JSON responses line by line from the child's stdout and writes
// the client's answers to its stdin, until the program reports a non-zero
// AuthResult or the exchange fails.
func executeKeyboardInteractiveProgram(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
	authResult := 0
	timeout, env, args := command.GetConfig(authHook, command.HookKeyboardInteractive)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, authHook, args...)
cmd.Env = append(env,
		fmt.Sprintf("SFTPGO_AUTHD_USERNAME=%s", user.Username),
		fmt.Sprintf("SFTPGO_AUTHD_IP=%s", ip),
		fmt.Sprintf("SFTPGO_AUTHD_PASSWORD=%s", user.Password))
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return authResult, err
	}
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return authResult, err
	}
	err = cmd.Start()
	if err != nil {
		return authResult, err
	}

	// once guarantees terminateInteractiveAuthProgram runs at most once, whether
	// triggered by this loop or by one of the answer goroutines below
	var once sync.Once
	scanner := bufio.NewScanner(stdout)
	for scanner.Scan() {
		var response plugin.KeyboardAuthResponse
		err = json.Unmarshal(scanner.Bytes(), &response)
		if err != nil {
			providerLog(logger.LevelInfo, "interactive auth error parsing response: %v", err)
			once.Do(func() { terminateInteractiveAuthProgram(cmd, false) })
			break
		}
		if response.AuthResult != 0 {
			authResult = response.AuthResult
			break
		}
		if err = response.Validate(); err != nil {
			providerLog(logger.LevelInfo, "invalid response from keyboard interactive program: %v", err)
			once.Do(func() { terminateInteractiveAuthProgram(cmd, false) })
			break
		}
		// each set of questions is answered asynchronously so this loop keeps
		// reading the program's output
		go func() {
			err := handleProgramInteractiveQuestions(client, &response, user, stdin, ip, protocol)
			if err != nil {
				once.Do(func() { terminateInteractiveAuthProgram(cmd, false) })
			}
		}()
	}
	// NOTE(review): answer goroutines may still be writing when stdin is closed here — confirm this is benign
	stdin.Close()
	once.Do(func() { terminateInteractiveAuthProgram(cmd, true) })
	// reap the child process without blocking the caller
	go func() {
		_, err := cmd.Process.Wait()
		if err != nil {
			providerLog(logger.LevelWarn, "error waiting for %q process to exit: %v", authHook, err)
		}
	}()

	return authResult, err
}

// doKeyboardInteractiveAuth authenticates user via keyboard interactive auth.
//
// The first available mechanism is used: the keyboard interactive plugin
// (followed by the TOTP second factor check), the HTTP hook (authHook starts
// with "http"), the auth program (any other non-empty authHook), or the
// built-in flow. The built-in flow is also used when external auth is disabled
// for the user. Login conditions are checked after a successful result (1).
func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge,
	ip, protocol string, isPartialAuth bool,
) (User, error) {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		return *user, err
	}
	var authResult int
	var err error
	if !user.Filters.Hooks.ExternalAuthDisabled {
		if plugin.Handler.HasAuthScope(plugin.AuthScopeKeyboardInteractive) {
			authResult, err = executeKeyboardInteractivePlugin(user, client, ip, protocol)
			if authResult == 1 && err == nil {
				authResult, err = checkKeyboardInteractiveSecondFactor(user, client, protocol)
			}
		} else if authHook != "" {
			if strings.HasPrefix(authHook, "http") {
				authResult, err = executeKeyboardInteractiveHTTPHook(user, authHook, client, ip, protocol)
			} else {
				authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol)
			}
		} else {
			authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol, isPartialAuth)
		}
	} else {
		authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol, isPartialAuth)
	}
	if err != nil {
		return *user, err
	}
	if authResult != 1 {
		return *user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult)
	}
	err = user.CheckLoginConditions()
	if err != nil {
		return *user, err
	}
	return *user, nil
}

// isCheckPasswordHookDefined reports whether the check password hook applies
// to protocol.
//
// A zero CheckPasswordScope means all protocols. Otherwise the bit masks are:
// 1 for SSH, 2 for FTP and 4 for WebDAV.
func isCheckPasswordHookDefined(protocol string) bool {
	if config.CheckPasswordHook == "" {
		return false
	}
	if config.CheckPasswordScope == 0 {
		return true
	}
	switch protocol {
	case protocolSSH:
		return config.CheckPasswordScope&1 != 0
	case protocolFTP:
		return config.CheckPasswordScope&2 != 0
	case protocolWebDAV:
		return config.CheckPasswordScope&4 != 0
	default:
		return false
	}
}

// getPasswordHookResponse runs the check password hook and returns its raw
// output.
//
// HTTP hooks receive a JSON checkPasswordRequest, must answer 200, and their
// body is read up to maxHookResponseSize. Any other hook is executed as a
// program with SFTPGO_AUTHD_* environment variables.
func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, error) {
	if strings.HasPrefix(config.CheckPasswordHook, "http") {
		var result []byte
		req := checkPasswordRequest{
			Username: username,
			Password: password,
			IP:       ip,
			Protocol: protocol,
		}
		reqAsJSON, err := json.Marshal(req)
		if err != nil {
			return result, err
		}
		resp, err := httpclient.Post(config.CheckPasswordHook, "application/json", bytes.NewBuffer(reqAsJSON))
		if err != nil {
			providerLog(logger.LevelError, "error getting check password hook response: %v", err)
			return result, err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return result, fmt.Errorf("wrong http status code from check password hook: %v, expected 200", resp.StatusCode)
		}
		return io.ReadAll(io.LimitReader(resp.Body, maxHookResponseSize))
	}

	timeout, env, args := command.GetConfig(config.CheckPasswordHook, command.HookCheckPassword)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, config.CheckPasswordHook, args...)
	cmd.Env = append(env,
		fmt.Sprintf("SFTPGO_AUTHD_USERNAME=%s", username),
		fmt.Sprintf("SFTPGO_AUTHD_PASSWORD=%s", password),
		fmt.Sprintf("SFTPGO_AUTHD_IP=%s", ip),
		fmt.Sprintf("SFTPGO_AUTHD_PROTOCOL=%s", protocol),
	)
	return getCmdOutput(cmd, "check_password_hook")
}

// executeCheckPasswordHook runs the check password hook for protocol and
// decodes its JSON output.
//
// When no hook applies to protocol it returns a response with Status -1 and
// no error.
func executeCheckPasswordHook(username, password, ip, protocol string) (checkPasswordResponse, error) {
	var response checkPasswordResponse

	if !isCheckPasswordHookDefined(protocol) {
		response.Status = -1
		return response, nil
	}

	startTime := time.Now()
	out, err := getPasswordHookResponse(username, password, ip, protocol)
	providerLog(logger.LevelDebug, "check password hook executed, error: %v, elapsed: %v", err, time.Since(startTime))
	if err != nil {
		return response, err
	}
	err = json.Unmarshal(out, &response)
	return response, err
}

// getPreLoginHookResponse runs the pre-login hook and returns its raw output.
//
// HTTP hooks receive login_method, ip and protocol as query parameters and
// userAsJSON as the body. A 204 response means "no changes" and yields an
// empty result. Any other hook is executed as a program with SFTPGO_LOGIND_*
// environment variables.
func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte) ([]byte, error) {
	if strings.HasPrefix(config.PreLoginHook, "http") {
		var result []byte
		// use the package-level url.Parse: shadowing "url" with a nil *url.URL
		// would turn this into a method call on a nil receiver
		hookURL, err := url.Parse(config.PreLoginHook)
		if err != nil {
			providerLog(logger.LevelError, "invalid url for pre-login hook %q, error: %v", config.PreLoginHook, err)
			return result, err
		}
		q := hookURL.Query()
		q.Add("login_method", loginMethod)
		q.Add("ip", ip)
		q.Add("protocol", protocol)
		hookURL.RawQuery = q.Encode()

		resp, err := httpclient.Post(hookURL.String(), "application/json", bytes.NewBuffer(userAsJSON))
		if err != nil {
			providerLog(logger.LevelWarn, "error getting pre-login hook response: %v", err)
			return result, err
		}
		defer resp.Body.Close()
		if resp.StatusCode == http.StatusNoContent {
			return result, nil
		}
		if resp.StatusCode != http.StatusOK {
			return result, fmt.Errorf("wrong pre-login hook http status code: %v, expected 200", resp.StatusCode)
		}
		return io.ReadAll(io.LimitReader(resp.Body, maxHookResponseSize))
	}
	timeout, env, args := command.GetConfig(config.PreLoginHook, command.HookPreLogin)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, config.PreLoginHook, args...)
	cmd.Env = append(env,
		fmt.Sprintf("SFTPGO_LOGIND_USER=%s", userAsJSON),
		fmt.Sprintf("SFTPGO_LOGIND_METHOD=%s", loginMethod),
		fmt.Sprintf("SFTPGO_LOGIND_IP=%s", ip),
		fmt.Sprintf("SFTPGO_LOGIND_PROTOCOL=%s", protocol),
	)
	return getCmdOutput(cmd, "pre_login_hook")
}

// executePreLoginHook runs the pre-login hook for username and applies its
// answer.
//
// An empty answer leaves an existing user untouched and yields a not-found
// error for unknown users. A non-empty answer is decoded as a User and
// updated (existing users keep their usage counters, timestamps, TOTP config
// and recovery codes) or added. The stored user is then reloaded and returned.
func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFields *map[string]any) (User, error) {
	var user User

	u, mergedUser, userAsJSON, err := getUserAndJSONForHook(username, oidcTokenFields)
	if err != nil {
		return u, err
	}
	if mergedUser.Filters.Hooks.PreLoginDisabled {
		return u, nil
	}
	startTime := time.Now()
	out, err := getPreLoginHookResponse(loginMethod, ip, protocol, userAsJSON)
	if err != nil {
		return u, fmt.Errorf("pre-login hook error: %v, username %q, ip %v, protocol %v elapsed %v",
			err, username, ip, protocol, time.Since(startTime))
	}
	providerLog(logger.LevelDebug, "pre-login hook completed, elapsed: %s", time.Since(startTime))
	if util.IsByteArrayEmpty(out) {
		providerLog(logger.LevelDebug, "empty response from pre-login hook, no modification requested for user %q id: %d",
			username, u.ID)
		if u.ID == 0 {
			return u, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
		}
		return u, nil
	}
	err = json.Unmarshal(out, &user)
	if err != nil {
		return u, fmt.Errorf("invalid pre-login hook response %q, error: %v", out, err)
	}
	if u.ID > 0 {
		user.ID = u.ID
		user.UsedQuotaSize = u.UsedQuotaSize
		user.UsedQuotaFiles = u.UsedQuotaFiles
		user.UsedUploadDataTransfer = u.UsedUploadDataTransfer
		user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer
		user.LastQuotaUpdate = u.LastQuotaUpdate
		user.LastLogin = u.LastLogin
		user.LastPasswordChange = u.LastPasswordChange
		user.FirstDownload = u.FirstDownload
		user.FirstUpload = u.FirstUpload
		// preserve TOTP config and recovery codes
		user.Filters.TOTPConfig = u.Filters.TOTPConfig
		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
		if err := provider.updateUser(&user); err != nil {
			return u, err
		}
	} else {
		if err := provider.addUser(&user); err != nil {
			return u, err
		}
	}
	user, err = provider.userExists(user.Username, "")
	if err != nil {
		return u, err
	}
	// log the stored ID: u.ID is still 0 when the user was just added
	providerLog(logger.LevelDebug, "user %q added/updated from pre-login hook response, id: %d", username, user.ID)
	if u.ID > 0 {
		webDAVUsersCache.swap(&user, "")
	}
	return user, nil
}

// ExecutePostLoginHook executes the post login hook if defined.
//
// PostLoginScope 1 runs it only for failed logins and 2 only for successful
// ones. The hook runs asynchronously, bounded by actionsConcurrencyGuard.
func ExecutePostLoginHook(user *User, loginMethod, ip, protocol string, err error) {
	if config.PostLoginHook == "" {
		return
	}
	if config.PostLoginScope == 1 && err == nil {
		return
	}
	if config.PostLoginScope == 2 && err != nil {
		return
	}

	go func() {
		actionsConcurrencyGuard <- struct{}{}
		defer func() {
			<-actionsConcurrencyGuard
		}()

		status := "0"
		if err == nil {
			status = "1"
		}

		user.PrepareForRendering()
		userAsJSON, err := json.Marshal(user)
		if err != nil {
			providerLog(logger.LevelError, "error serializing user in post login hook: %v", err)
			return
		}
		if strings.HasPrefix(config.PostLoginHook, "http") {
			// package-level url.Parse, see getPreLoginHookResponse
			hookURL, err := url.Parse(config.PostLoginHook)
			if err != nil {
				providerLog(logger.LevelDebug, "Invalid post-login hook %q", config.PostLoginHook)
				return
			}
			q := hookURL.Query()
			q.Add("login_method", loginMethod)
			q.Add("ip", ip)
			q.Add("protocol", protocol)
			q.Add("status", status)
			hookURL.RawQuery = q.Encode()

			startTime := time.Now()
			respCode := 0
			resp, err := httpclient.RetryablePost(hookURL.String(), "application/json", bytes.NewBuffer(userAsJSON))
			if err == nil {
				respCode = resp.StatusCode
				resp.Body.Close()
			}
			providerLog(logger.LevelDebug, "post login hook executed for user %q, ip %v, protocol %v, response code: %v, elapsed: %v err: %v",
				user.Username, ip, protocol, respCode, time.Since(startTime), err)
			return
		}
		timeout, env, args := command.GetConfig(config.PostLoginHook, command.HookPostLogin)
		ctx, cancel :=
context.WithTimeout(context.Background(), timeout)
		defer cancel()

		// program hook: login details are passed via SFTPGO_LOGIND_* environment variables
		cmd := exec.CommandContext(ctx, config.PostLoginHook, args...)
		cmd.Env = append(env,
			fmt.Sprintf("SFTPGO_LOGIND_USER=%s", userAsJSON),
			fmt.Sprintf("SFTPGO_LOGIND_IP=%s", ip),
			fmt.Sprintf("SFTPGO_LOGIND_METHOD=%s", loginMethod),
			fmt.Sprintf("SFTPGO_LOGIND_STATUS=%s", status),
			fmt.Sprintf("SFTPGO_LOGIND_PROTOCOL=%s", protocol))
		startTime := time.Now()
		err = cmd.Run()
		providerLog(logger.LevelDebug, "post login hook executed for user %q, ip %v, protocol %v, elapsed %v err: %v",
			user.Username, ip, protocol, time.Since(startTime), err)
	}()
}

// getExternalAuthResponse runs the external auth hook and returns its raw output.
//
// A non-nil cert is PEM-encoded and forwarded. The existing user is included
// only when user.ID > 0. HTTP hooks receive a JSON object and must answer 200;
// their body is read up to maxHookResponseSize. Any other hook is executed as
// a program with SFTPGO_AUTHD_* environment variables, and newlines in the PEM
// certificate are escaped as "\n".
func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, protocol string, cert *x509.Certificate,
	user User,
) ([]byte, error) {
	var tlsCert string
	if cert != nil {
		var err error
		tlsCert, err = util.EncodeTLSCertToPem(cert)
		if err != nil {
			return nil, err
		}
	}
	if strings.HasPrefix(config.ExternalAuthHook, "http") {
		var result []byte
		authRequest := make(map[string]any)
		authRequest["username"] = username
		authRequest["ip"] = ip
		authRequest["password"] = password
		authRequest["public_key"] = pkey
		authRequest["protocol"] = protocol
		authRequest["keyboard_interactive"] = keyboardInteractive
		authRequest["tls_cert"] = tlsCert
		if user.ID > 0 {
			authRequest["user"] = user
		}
		authRequestAsJSON, err := json.Marshal(authRequest)
		if err != nil {
			providerLog(logger.LevelError, "error serializing external auth request: %v", err)
			return result, err
		}
		resp, err := httpclient.Post(config.ExternalAuthHook, "application/json", bytes.NewBuffer(authRequestAsJSON))
		if err != nil {
			providerLog(logger.LevelWarn, "error getting external auth hook HTTP response: %v", err)
			return result, err
		}
		defer resp.Body.Close()
		providerLog(logger.LevelDebug, "external auth hook executed, response code: %v", resp.StatusCode)
		if resp.StatusCode != http.StatusOK {
			return result, fmt.Errorf("wrong external auth http status code: %v, expected 200", resp.StatusCode)
		}

		return io.ReadAll(io.LimitReader(resp.Body, maxHookResponseSize))
	}

	// for unknown users (ID 0) SFTPGO_AUTHD_USER is set to an empty value
	var userAsJSON []byte
	var err error
	if user.ID > 0 {
		userAsJSON, err = json.Marshal(user)
		if err != nil {
			return nil, fmt.Errorf("unable to serialize user as JSON: %w", err)
		}
	}
	timeout, env, args := command.GetConfig(config.ExternalAuthHook, command.HookExternalAuth)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, config.ExternalAuthHook, args...)
	cmd.Env = append(env,
		fmt.Sprintf("SFTPGO_AUTHD_USERNAME=%s", username),
		fmt.Sprintf("SFTPGO_AUTHD_USER=%s", userAsJSON),
		fmt.Sprintf("SFTPGO_AUTHD_IP=%s", ip),
		fmt.Sprintf("SFTPGO_AUTHD_PASSWORD=%s", password),
		fmt.Sprintf("SFTPGO_AUTHD_PUBLIC_KEY=%s", pkey),
		fmt.Sprintf("SFTPGO_AUTHD_PROTOCOL=%s", protocol),
		fmt.Sprintf("SFTPGO_AUTHD_TLS_CERT=%s", strings.ReplaceAll(tlsCert, "\n", "\\n")),
		fmt.Sprintf("SFTPGO_AUTHD_KEYBOARD_INTERACTIVE=%v", keyboardInteractive))

	return getCmdOutput(cmd, "external_auth_hook")
}

// updateUserFromExtAuthResponse merges the login credentials into a user
// returned by an external auth hook or plugin.
//
// A non-empty password overwrites user.Password. A non-empty pkey is appended
// unless util.IsStringPrefixInSlice reports it as already present.
// LastPasswordChange is always reset to 0.
// NOTE(review): password is the login value; presumably it is hashed later by
// the provider on add/update — confirm.
func updateUserFromExtAuthResponse(user *User, password, pkey string) {
	if password != "" {
		user.Password = password
	}
	if pkey != "" && !util.IsStringPrefixInSlice(pkey, user.PublicKeys) {
		user.PublicKeys = append(user.PublicKeys, pkey)
	}
	user.LastPasswordChange = 0
}

// checkPasswordAfterEmptyExtAuthResponse is called when an external hook or
// plugin returned an empty response for an existing user.
//
// If plainPwd is non-empty and does not match the stored password, plainPwd
// is hashed and stored. A failure to persist it is only logged. The in-memory
// user, the password cache and (for non-WebDAV protocols) the WebDAV user
// cache are then updated.
func checkPasswordAfterEmptyExtAuthResponse(user *User, plainPwd, protocol string) error {
	if plainPwd == "" {
		return nil
	}
	match, err := isPasswordOK(user, plainPwd)
	if match && err == nil {
		return nil
	}

	hashedPwd, err := hashPlainPassword(plainPwd)
	if err != nil {
		providerLog(logger.LevelError, "unable to hash password for user %q after empty external response: %v",
			user.Username, err)
		return err
	}
	err = provider.updateUserPassword(user.Username, hashedPwd)
	if err != nil {
		providerLog(logger.LevelError, "unable to update password for user %q after empty external response: %v",
			user.Username, err)
	}
	user.Password = hashedPwd
	cachedUserPasswords.Add(user.Username, plainPwd, user.Password)
	if protocol != protocolWebDAV {
webDAVUsersCache.swap(user, plainPwd)
	}
	providerLog(logger.LevelDebug, "updated password for user %q after empty external auth response", user.Username)
	return nil
}

// doExternalAuth authenticates username through the external auth hook.
//
// Users whose merged settings skip external auth are returned as stored. An
// empty hook response keeps the existing user (not-found error if unknown). A
// user with an empty Username means authentication failure. Otherwise the
// returned user is updated (usage counters, timestamps, TOTP config and
// recovery codes are preserved from the stored user) or added, then reloaded
// from the provider.
func doExternalAuth(username, password string, pubKey []byte, keyboardInteractive, ip, protocol string,
	tlsCert *x509.Certificate,
) (User, error) {
	var user User

	u, mergedUser, err := getUserForHook(username, nil)
	if err != nil {
		return user, err
	}

	if mergedUser.skipExternalAuth() {
		return u, nil
	}

	pkey, err := util.GetSSHPublicKeyAsString(pubKey)
	if err != nil {
		return user, err
	}

	startTime := time.Now()
	out, err := getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, protocol, tlsCert, u)
	if err != nil {
		return user, fmt.Errorf("external auth error for user %q, elapsed: %s: %w", username, time.Since(startTime), err)
	}
	providerLog(logger.LevelDebug, "external auth completed for user %q, elapsed: %s", username, time.Since(startTime))
	if util.IsByteArrayEmpty(out) {
		providerLog(logger.LevelDebug, "empty response from external hook, no modification requested for user %q, id: %d",
			username, u.ID)
		if u.ID == 0 {
			return u, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
		}
		err = checkPasswordAfterEmptyExtAuthResponse(&u, password, protocol)
		return u, err
	}
	err = json.Unmarshal(out, &user)
	if err != nil {
		return user, fmt.Errorf("invalid external auth response: %v", err)
	}
	// an empty username means authentication failure
	if user.Username == "" {
		return user, ErrInvalidCredentials
	}
	updateUserFromExtAuthResponse(&user, password, pkey)
	// some users want to map multiple login usernames with a single SFTPGo account
	// for example an SFTP user logins using "user1" or "user2" and the external auth
	// returns "user" in both cases, so we use the username returned from
	// external auth and not the one used to login
	if user.Username != username {
		// NOTE(review): any lookup error (not only "not found") falls through to addUser below — confirm intended
		u, err = provider.userExists(user.Username, "")
	}
	if u.ID > 0 && err == nil {
		user.ID = u.ID
		user.UsedQuotaSize = u.UsedQuotaSize
		user.UsedQuotaFiles = u.UsedQuotaFiles
		user.UsedUploadDataTransfer = u.UsedUploadDataTransfer
		user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer
		user.LastQuotaUpdate = u.LastQuotaUpdate
		user.LastLogin = u.LastLogin
		user.LastPasswordChange = u.LastPasswordChange
		user.FirstDownload = u.FirstDownload
		user.FirstUpload = u.FirstUpload
		user.CreatedAt = u.CreatedAt
		user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		// preserve TOTP config and recovery codes
		user.Filters.TOTPConfig = u.Filters.TOTPConfig
		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
		user, err = updateUserAfterExternalAuth(&user)
		if err == nil {
			if protocol != protocolWebDAV {
				webDAVUsersCache.swap(&user, password)
			}
			cachedUserPasswords.Add(user.Username, password, user.Password)
		}
		return user, err
	}
	err = provider.addUser(&user)
	if err != nil {
		return user, err
	}
	return provider.userExists(user.Username, "")
}

// doPluginAuth authenticates username through the auth plugin for authScope.
//
// It follows the same flow as doExternalAuth: skip, empty response, then
// update or add.
// NOTE(review): unlike doExternalAuth, an empty Username in the plugin response
// is not rejected, the login username is not remapped, and CreatedAt/UpdatedAt
// are not copied — confirm these differences are intended.
func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
	tlsCert *x509.Certificate, authScope int,
) (User, error) {
	var user User

	u, mergedUser, userAsJSON, err := getUserAndJSONForHook(username, nil)
	if err != nil {
		return user, err
	}

	if mergedUser.skipExternalAuth() {
		return u, nil
	}

	pkey, err := util.GetSSHPublicKeyAsString(pubKey)
	if err != nil {
		return user, err
	}

	startTime := time.Now()

	out, err := plugin.Handler.Authenticate(username, password, ip, protocol, pkey, tlsCert, authScope, userAsJSON)
	if err != nil {
		return user, fmt.Errorf("plugin auth error for user %q: %v, elapsed: %v, auth scope: %d",
			username, err, time.Since(startTime), authScope)
	}
	providerLog(logger.LevelDebug, "plugin auth completed for user %q, elapsed: %v, auth scope: %d",
		username, time.Since(startTime), authScope)
	if util.IsByteArrayEmpty(out) {
		providerLog(logger.LevelDebug, "empty response from plugin auth, no modification requested for user %q id: %d, auth scope: %d",
			username, u.ID, authScope)
		if u.ID == 0 {
			return u, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
		}
		err = checkPasswordAfterEmptyExtAuthResponse(&u, password, protocol)
		return u, err
	}
	err = json.Unmarshal(out, &user)
	if err != nil {
		return user, fmt.Errorf("invalid plugin auth response: %v", err)
	}
	updateUserFromExtAuthResponse(&user, password, pkey)
	if u.ID > 0 {
		user.ID = u.ID
		user.UsedQuotaSize = u.UsedQuotaSize
		user.UsedQuotaFiles = u.UsedQuotaFiles
		user.UsedUploadDataTransfer = u.UsedUploadDataTransfer
		user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer
		user.LastQuotaUpdate = u.LastQuotaUpdate
		user.LastLogin = u.LastLogin
		user.LastPasswordChange = u.LastPasswordChange
		user.FirstDownload = u.FirstDownload
		user.FirstUpload = u.FirstUpload
		// preserve TOTP config and recovery codes
		user.Filters.TOTPConfig = u.Filters.TOTPConfig
		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
		user, err = updateUserAfterExternalAuth(&user)
		if err == nil {
			if protocol != protocolWebDAV {
				webDAVUsersCache.swap(&user, password)
			}
			cachedUserPasswords.Add(user.Username, password, user.Password)
		}
		return user, err
	}
	err = provider.addUser(&user)
	if err != nil {
		return user, err
	}
	return provider.userExists(user.Username, "")
}

// updateUserAfterExternalAuth persists user and returns the reloaded record.
func updateUserAfterExternalAuth(user *User) (User, error) {
	if err := provider.updateUser(user); err != nil {
		return *user, err
	}
	return provider.userExists(user.Username, "")
}

// getUserForHook loads username from the provider for use by auth/login hooks.
//
// Unknown users are represented by a placeholder with ID 0. The second return
// value is a copy with group settings applied. oidcTokenFields is attached
// only to the first return value.
func getUserForHook(username string, oidcTokenFields *map[string]any) (User, User, error) {
	u, err := provider.userExists(username, "")
	if err != nil {
		if !errors.Is(err, util.ErrNotFound) {
			return u, u, err
		}
		u = User{
			BaseUser: sdk.BaseUser{
				ID:       0,
				Username: username,
			},
		}
	}
	mergedUser := u.getACopy()
	err = mergedUser.LoadAndApplyGroupSettings()
	if err != nil {
		return u, mergedUser, err
	}

	u.OIDCCustomFields = oidcTokenFields
	return u, mergedUser, err
}

// getUserAndJSONForHook is getUserForHook plus the JSON encoding of the
// non-merged user.
func getUserAndJSONForHook(username string, oidcTokenFields *map[string]any) (User, User, []byte, error) {
	u, mergedUser, err :=
getUserForHook(username, oidcTokenFields)
	if err != nil {
		return u, mergedUser, nil, err
	}
	userAsJSON, err := json.Marshal(u)
	if err != nil {
		return u, mergedUser, userAsJSON, err
	}
	return u, mergedUser, userAsJSON, err
}

// isLastActivityRecent reports whether lastActivity (milliseconds since the
// epoch) is less than minDelay in the past.
//
// Timestamps more than 10 seconds in the future are treated as not recent.
func isLastActivityRecent(lastActivity int64, minDelay time.Duration) bool {
	lastActivityTime := util.GetTimeFromMsecSinceEpoch(lastActivity)
	diff := -time.Until(lastActivityTime)
	if diff < -10*time.Second {
		return false
	}
	return diff < minDelay
}

// isExternalAuthConfigured reports whether an external auth hook or an auth
// plugin handles loginMethod.
//
// For the hook, scope 0 means every method. Otherwise bit 1 selects password
// logins and bit 8 selects TLS certificate logins. If the hook does not
// handle the method, the plugin auth scopes are checked.
func isExternalAuthConfigured(loginMethod string) bool {
	if config.ExternalAuthHook != "" {
		if config.ExternalAuthScope == 0 {
			return true
		}
		switch loginMethod {
		case LoginMethodPassword:
			return config.ExternalAuthScope&1 != 0
		case LoginMethodTLSCertificate:
			return config.ExternalAuthScope&8 != 0
		case LoginMethodTLSCertificateAndPwd:
			return config.ExternalAuthScope&1 != 0 || config.ExternalAuthScope&8 != 0
		}
	}
	switch loginMethod {
	case LoginMethodPassword:
		return plugin.Handler.HasAuthScope(plugin.AuthScopePassword)
	case LoginMethodTLSCertificate:
		return plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate)
	case LoginMethodTLSCertificateAndPwd:
		return plugin.Handler.HasAuthScope(plugin.AuthScopePassword) ||
			plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate)
	default:
		return false
	}
}

// replaceTemplateVars rewrites "{{Name}}" placeholders to Go template field
// references "{{.Name}}".
//
// "{{" followed by a space, "." or "-" is left unchanged, as is a "{{" with
// no closing "}}".
func replaceTemplateVars(input string) string {
	var result strings.Builder
	i := 0
	for i < len(input) {
		if i+2 <= len(input) && input[i:i+2] == "{{" {
			if i+2 < len(input) {
				nextChar := input[i+2]
				if nextChar == ' ' || nextChar == '.' || nextChar == '-' {
					// Don't replace if followed by space, dot or minus.
					result.WriteString("{{")
					i += 2
					continue
				}
			}
			// Find the closing "}}"
			closing := strings.Index(input[i:], "}}")
			if closing != -1 {
				// Replace with {{. only if it's a proper template variable.
				result.WriteString("{{.")
				result.WriteString(input[i+2 : i+closing])
				result.WriteString("}}")
				i += closing + 2
				continue
			}
		}
		result.WriteByte(input[i])
		i++
	}
	return result.String()
}

// updateEventActionPlaceholders returns a copy of actions with the
// placeholders in their options converted by replaceTemplateVars. The
// conversion round-trips the options through JSON.
func updateEventActionPlaceholders(actions []BaseEventAction) ([]BaseEventAction, error) {
	var result []BaseEventAction

	for _, action := range actions {
		options, err := json.Marshal(action.Options)
		if err != nil {
			return nil, err
		}
		convertedOptions := replaceTemplateVars(string(options))
		var opts BaseEventActionOptions
		err = json.Unmarshal([]byte(convertedOptions), &opts)
		if err != nil {
			return nil, err
		}
		action.Options = opts
		result = append(result, action)
	}

	return result, nil
}

// getConfigPath resolves name relative to configDir unless it is empty or
// absolute. Names rejected by util.IsFileInputValid yield "".
func getConfigPath(name, configDir string) string {
	if !util.IsFileInputValid(name) {
		return ""
	}
	if name != "" && !filepath.IsAbs(name) {
		return filepath.Join(configDir, name)
	}
	return name
}

// checkReservedUsernames returns a validation error if username is reserved.
func checkReservedUsernames(username string) error {
	if slices.Contains(reservedUsers, username) {
		return util.NewValidationError("this username is reserved")
	}
	return nil
}

// errSchemaVersionTooOld returns the error used when the database schema
// version cannot be upgraded automatically.
func errSchemaVersionTooOld(version int) error {
	return fmt.Errorf("database schema version %d is too old, please see the upgrading docs: https://docs.sftpgo.com/latest/data-provider/#upgrading", version)
}

// getCmdOutput runs cmd and returns its stdout. Each non-empty stderr line is
// logged as a warning under sender.
//
// stderr is consumed completely before calling cmd.Wait: os/exec closes the
// StderrPipe in Wait, so waiting while a reader is still active can drop
// output and produce spurious read errors.
func getCmdOutput(cmd *exec.Cmd, sender string) ([]byte, error) {
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, err
	}
	err = cmd.Start()
	if err != nil {
		return nil, err
	}

	done := make(chan struct{})
	go func() {
		defer close(done)

		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			if out := scanner.Text(); out != "" {
				logger.Log(logger.LevelWarn, sender, "", "%s", out)
			}
		}
		if err := scanner.Err(); err != nil {
			logger.Log(logger.LevelError, sender, "", "error reading stderr: %v", err)
			// keep draining (e.g. after bufio.ErrTooLong) so the child never
			// blocks on a full stderr pipe
			_, _ = io.Copy(io.Discard, stderr)
		}
	}()

	<-done
	err = cmd.Wait()
	return stdout.Bytes(), err
}

// providerLog logs a message from the data provider at the given level.
func providerLog(level logger.LogLevel, format string, v ...any) {
	logger.Log(level, logSender, "", format, v...)
}

================================================
FILE: internal/dataprovider/eventrule.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package dataprovider

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"path"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/robfig/cron/v3"

	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// Supported event actions.
//
// Values are assigned with iota, so each constant's numeric value depends on
// its position. NOTE(review): if these values are persisted (likely, as they
// identify action types), only append new entries at the end.
const (
	ActionTypeHTTP = iota + 1
	ActionTypeCommand
	ActionTypeEmail
	ActionTypeBackup
	ActionTypeUserQuotaReset
	ActionTypeFolderQuotaReset
	ActionTypeTransferQuotaReset
	ActionTypeDataRetentionCheck
	ActionTypeFilesystem
	// Keeps the following values stable. It is not listed in
	// supportedEventActions, presumably a retired action type.
	actionTypeReserved
	ActionTypePasswordExpirationCheck
	ActionTypeUserExpirationCheck
	ActionTypeIDPAccountCheck
	ActionTypeUserInactivityCheck
	ActionTypeRotateLogs
)

var (
	// supportedEventActions lists the action types accepted by isActionTypeValid.
	// init iterates it to build EventActionTypes, so it also sets that order.
	supportedEventActions = []int{ActionTypeHTTP, ActionTypeCommand, ActionTypeEmail, ActionTypeFilesystem,
		ActionTypeBackup, ActionTypeUserQuotaReset, ActionTypeFolderQuotaReset, ActionTypeTransferQuotaReset,
		ActionTypeDataRetentionCheck, ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck,
		ActionTypeUserInactivityCheck, ActionTypeIDPAccountCheck, ActionTypeRotateLogs}
	// EnabledActionCommands defines the system commands that can be executed via EventManager,
	// an empty list means that no command is allowed to be executed.
	EnabledActionCommands []string
)

// isActionTypeValid reports whether action is one of supportedEventActions.
func isActionTypeValid(action int) bool {
	return slices.Contains(supportedEventActions, action)
}

// getActionTypeAsString returns the i18n label key for action.
// ActionTypeCommand has no explicit case: it, and any unknown value, gets the
// command label through the default branch.
func getActionTypeAsString(action int) string {
	switch action {
	case ActionTypeHTTP:
		return util.I18nActionTypeHTTP
	case ActionTypeEmail:
		return util.I18nActionTypeEmail
	case ActionTypeBackup:
		return util.I18nActionTypeBackup
	case ActionTypeUserQuotaReset:
		return util.I18nActionTypeUserQuotaReset
	case ActionTypeFolderQuotaReset:
		return util.I18nActionTypeFolderQuotaReset
	case ActionTypeTransferQuotaReset:
		return util.I18nActionTypeTransferQuotaReset
	case ActionTypeDataRetentionCheck:
		return util.I18nActionTypeDataRetentionCheck
	case ActionTypeFilesystem:
		return util.I18nActionTypeFilesystem
	case ActionTypePasswordExpirationCheck:
		return util.I18nActionTypePwdExpirationCheck
	case ActionTypeUserExpirationCheck:
		return util.I18nActionTypeUserExpirationCheck
	case ActionTypeUserInactivityCheck:
		return util.I18nActionTypeUserInactivityCheck
	case ActionTypeIDPAccountCheck:
		return util.I18nActionTypeIDPCheck
	case ActionTypeRotateLogs:
		return util.I18nActionTypeRotateLogs
	default:
		return util.I18nActionTypeCommand
	}
}

// Supported event triggers
const (
	// Filesystem events such as upload, download, mkdir ...
EventTriggerFsEvent = iota + 1
	// Provider events such as add, update, delete
	EventTriggerProviderEvent
	EventTriggerSchedule
	EventTriggerIPBlocked
	EventTriggerCertificate
	EventTriggerOnDemand
	EventTriggerIDPLogin
)

var (
	// supportedEventTriggers lists the trigger types accepted by
	// isEventTriggerValid; init iterates it to build EventTriggerTypes.
	supportedEventTriggers = []int{EventTriggerFsEvent, EventTriggerProviderEvent, EventTriggerSchedule,
		EventTriggerIPBlocked, EventTriggerCertificate, EventTriggerIDPLogin, EventTriggerOnDemand}
)

// isEventTriggerValid reports whether trigger is one of supportedEventTriggers.
func isEventTriggerValid(trigger int) bool {
	return slices.Contains(supportedEventTriggers, trigger)
}

// getTriggerTypeAsString returns the i18n label key for trigger.
// EventTriggerSchedule has no explicit case: it, and any unknown value, gets
// the schedule label through the default branch.
func getTriggerTypeAsString(trigger int) string {
	switch trigger {
	case EventTriggerFsEvent:
		return util.I18nTriggerFsEvent
	case EventTriggerProviderEvent:
		return util.I18nTriggerProviderEvent
	case EventTriggerIPBlocked:
		return util.I18nTriggerIPBlockedEvent
	case EventTriggerCertificate:
		return util.I18nTriggerCertificateRenewEvent
	case EventTriggerOnDemand:
		return util.I18nTriggerOnDemandEvent
	case EventTriggerIDPLogin:
		return util.I18nTriggerIDPLoginEvent
	default:
		return util.I18nTriggerScheduleEvent
	}
}

// Supported IDP login events
const (
	IDPLoginAny = iota
	IDPLoginUser
	IDPLoginAdmin
)

var (
	// supportedIDPLoginEvents lists every IDP login event value defined above.
	supportedIDPLoginEvents = []int{IDPLoginAny, IDPLoginUser, IDPLoginAdmin}
)

// Supported filesystem actions.
// Values are assigned with iota, so the declaration order defines them.
const (
	FilesystemActionRename = iota + 1
	FilesystemActionDelete
	FilesystemActionMkdirs
	FilesystemActionExist
	FilesystemActionCompress
	FilesystemActionCopy
)

const (
	// RetentionReportPlaceHolder defines the placeholder for data retention reports
	RetentionReportPlaceHolder = "{{RetentionReports}}"
)

var (
	// supportedFsActions lists the filesystem action types accepted by
	// isFilesystemActionValid; init iterates it to build FsActionTypes.
	supportedFsActions = []int{FilesystemActionRename, FilesystemActionDelete, FilesystemActionMkdirs,
		FilesystemActionCopy, FilesystemActionCompress, FilesystemActionExist}
)

// isFilesystemActionValid reports whether value is one of supportedFsActions.
func isFilesystemActionValid(value int) bool {
	return slices.Contains(supportedFsActions, value)
}

// getFsActionTypeAsString returns the i18n label key for a filesystem action.
// FilesystemActionMkdirs has no explicit case: it, and any unknown value, gets
// the "create dirs" label through the default branch.
func getFsActionTypeAsString(value int) string {
	switch value {
	case FilesystemActionRename:
		return util.I18nActionFsTypeRename
	case FilesystemActionDelete:
		return util.I18nActionFsTypeDelete
	case FilesystemActionExist:
		return util.I18nActionFsTypePathExists
	case FilesystemActionCompress:
		return util.I18nActionFsTypeCompress
	case FilesystemActionCopy:
		return util.I18nActionFsTypeCopy
	default:
		return util.I18nActionFsTypeCreateDirs
	}
}

// TODO: replace the copied strings with shared constants
var (
	// SupportedFsEvents defines the supported filesystem events
	SupportedFsEvents = []string{"upload", "pre-upload", "first-upload", "download", "pre-download",
		"first-download", "delete", "pre-delete", "rename", "mkdir", "rmdir", "copy", "ssh_cmd"}
	// SupportedProviderEvents defines the supported provider events
	SupportedProviderEvents = []string{operationAdd, operationUpdate, operationDelete}
	// SupportedRuleConditionProtocols defines the supported protocols for rule conditions
	SupportedRuleConditionProtocols = []string{"SFTP", "SCP", "SSH", "FTP", "DAV", "HTTP", "HTTPShare",
		"OIDC"}
	// SupporteRuleConditionProviderObjects defines the supported provider objects for rule conditions.
	// NOTE(review): the exported name is misspelled ("Supporte"); renaming it
	// would break callers, so it is kept as is.
	SupporteRuleConditionProviderObjects = []string{actionObjectUser, actionObjectFolder, actionObjectGroup,
		actionObjectAdmin, actionObjectAPIKey, actionObjectShare, actionObjectEventRule, actionObjectEventAction}
	// SupportedHTTPActionMethods defines the supported methods for HTTP actions
	SupportedHTTPActionMethods = []string{http.MethodPost, http.MethodGet, http.MethodPut, http.MethodDelete}
	// allowedSyncFsEvents lists the filesystem events that may be executed synchronously.
	allowedSyncFsEvents = []string{"upload", "pre-upload", "pre-download", "pre-delete"}
	// mandatorySyncFsEvents lists the filesystem events that must be executed synchronously.
	mandatorySyncFsEvents = []string{"pre-upload", "pre-download", "pre-delete"}
)

// enum mappings
var (
	EventActionTypes  []EnumMapping
	EventTriggerTypes []EnumMapping
	FsActionTypes     []EnumMapping
)

// init builds the exported enum mappings (value plus i18n label) from the
// supported* lists, preserving their order.
func init() {
	for _, t := range supportedEventActions {
		EventActionTypes = append(EventActionTypes, EnumMapping{
			Value: t,
			Name:  getActionTypeAsString(t),
		})
	}
	for _, t := range supportedEventTriggers {
		EventTriggerTypes = append(EventTriggerTypes, EnumMapping{
			Value: t,
			Name:
getTriggerTypeAsString(t),
		})
	}
	for _, t := range supportedFsActions {
		FsActionTypes = append(FsActionTypes, EnumMapping{
			Value: t,
			Name:  getFsActionTypeAsString(t),
		})
	}
}

// EnumMapping defines a mapping between enum values and names
type EnumMapping struct {
	Name  string
	Value int
}

// KeyValue defines a key/value pair
type KeyValue struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}

// isNotValid reports whether the key or the value is empty.
func (k *KeyValue) isNotValid() bool {
	return k.Key == "" || k.Value == ""
}

// HTTPPart defines a part for HTTP multipart requests
type HTTPPart struct {
	Name     string     `json:"name,omitempty"`
	Filepath string     `json:"filepath,omitempty"`
	Headers  []KeyValue `json:"headers,omitempty"`
	Body     string     `json:"body,omitempty"`
	Order    int        `json:"-"`
}

// validate checks the part and normalizes it in place.
//
// A name is required, and every header must have a key and a value. A part
// without a Filepath must have a Body. When Filepath is set, Body is cleared
// and Filepath is cleaned with util.CleanPath, unless it is the retention
// report placeholder.
func (p *HTTPPart) validate() error {
	if p.Name == "" {
		return util.NewI18nError(util.NewValidationError("HTTP part name is required"), util.I18nErrorHTTPPartNameRequired)
	}
	for _, kv := range p.Headers {
		if kv.isNotValid() {
			return util.NewValidationError("invalid HTTP part headers")
		}
	}
	if p.Filepath == "" {
		if p.Body == "" {
			return util.NewI18nError(
				util.NewValidationError("HTTP part body is required if no file path is provided"),
				util.I18nErrorHTTPPartBodyRequired,
			)
		}
	} else {
		p.Body = ""
		if p.Filepath != RetentionReportPlaceHolder {
			p.Filepath = util.CleanPath(p.Filepath)
		}
	}
	return nil
}

// EventActionHTTPConfig defines the configuration for an HTTP event target
type EventActionHTTPConfig struct {
	Endpoint        string      `json:"endpoint,omitempty"`
	Username        string      `json:"username,omitempty"`
	Password        *kms.Secret `json:"password,omitempty"`
	Headers         []KeyValue  `json:"headers,omitempty"`
	Timeout         int         `json:"timeout,omitempty"`
	SkipTLSVerify   bool        `json:"skip_tls_verify,omitempty"`
	Method          string      `json:"method,omitempty"`
	QueryParameters []KeyValue  `json:"query_parameters,omitempty"`
	Body            string      `json:"body,omitempty"`
	Parts           []HTTPPart  `json:"parts,omitempty"`
}

// HasJSONBody returns true if the content type header indicates a JSON body.
// Only the first Content-Type header (matched with http.CanonicalHeaderKey) is
// considered, and its value is compared case-insensitively.
func (c *EventActionHTTPConfig) HasJSONBody() bool {
	for _, h := range c.Headers {
		if http.CanonicalHeaderKey(h.Key) == "Content-Type" {
			return strings.Contains(strings.ToLower(h.Value), "application/json")
		}
	}
	return false
}

// isTimeoutNotValid reports whether Timeout is outside [1, 180]. The timeout is
// not checked when files are uploaded: GetContext applies no timeout then.
func (c *EventActionHTTPConfig) isTimeoutNotValid() bool {
	if c.HasMultipartFiles() {
		return false
	}
	return c.Timeout < 1 || c.Timeout > 180
}

// validateMultiparts validates every part and rejects duplicate file paths.
// When parts are defined, the request body must be empty and no Content-Type
// header may be set, because both are derived from the parts.
func (c *EventActionHTTPConfig) validateMultiparts() error {
	filePaths := make(map[string]bool)
	for idx := range c.Parts {
		if err := c.Parts[idx].validate(); err != nil {
			return err
		}
		if filePath := c.Parts[idx].Filepath; filePath != "" {
			if filePaths[filePath] {
				// A duplicated path is a user input error: report it as a
				// validation error, like every other check in this file.
				return util.NewI18nError(
					util.NewValidationError(fmt.Sprintf("filepath %q is duplicated", filePath)),
					util.I18nErrorPathDuplicated,
				)
			}
			filePaths[filePath] = true
		}
	}
	if len(c.Parts) > 0 {
		if c.Body != "" {
			return util.NewI18nError(
				util.NewValidationError("multipart requests require no body. The request body is build from the specified parts"),
				util.I18nErrorMultipartBody,
			)
		}
		for _, k := range c.Headers {
			if strings.EqualFold(k.Key, "content-type") {
				return util.NewI18nError(
					util.NewValidationError("content type is automatically set for multipart requests"),
					util.I18nErrorMultipartCType,
				)
			}
		}
	}
	return nil
}

// validate checks the HTTP configuration; the first failing check is returned.
//
// It requires an http:// or https:// endpoint, a valid timeout, and complete
// headers. It then validates the multipart settings and rejects a redacted
// password. A plain-text password is encrypted using additionalData as its
// additional data. Finally the method must be supported and the query
// parameters complete.
func (c *EventActionHTTPConfig) validate(additionalData string) error {
	if c.Endpoint == "" {
		return util.NewI18nError(util.NewValidationError("HTTP endpoint is required"), util.I18nErrorURLRequired)
	}
	if !util.IsStringPrefixInSlice(c.Endpoint, []string{"http://", "https://"}) {
		return util.NewI18nError(
			util.NewValidationError("invalid HTTP endpoint schema: http and https are supported"),
			util.I18nErrorURLInvalid,
		)
	}
	if c.isTimeoutNotValid() {
		return util.NewValidationError(fmt.Sprintf("invalid HTTP timeout %d", c.Timeout))
	}
	for _, kv := range c.Headers {
		if kv.isNotValid() {
			return util.NewValidationError("invalid HTTP headers")
		}
	}
	if err := c.validateMultiparts(); err != nil {
		return err
	}
	if c.Password.IsRedacted() {
		return util.NewValidationError("cannot save HTTP configuration with a redacted secret")
	}
	if c.Password.IsPlain() {
		c.Password.SetAdditionalData(additionalData)
		err := c.Password.Encrypt()
		if err != nil {
			return util.NewValidationError(fmt.Sprintf("could not encrypt HTTP password: %v", err))
		}
	}
	if !slices.Contains(SupportedHTTPActionMethods, c.Method) {
		return util.NewValidationError(fmt.Sprintf("unsupported HTTP method: %s", c.Method))
	}
	for _, kv := range c.QueryParameters {
		if kv.isNotValid() {
			return util.NewValidationError("invalid HTTP query parameters")
		}
	}
	return nil
}

// GetContext returns the context and the cancel func to use for the HTTP request.
// Requests uploading files get a cancelable context without a deadline; all
// others time out after Timeout seconds.
func (c *EventActionHTTPConfig) GetContext() (context.Context, context.CancelFunc) {
	if c.HasMultipartFiles() {
		return context.WithCancel(context.Background())
	}
	return context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second)
}

// HasObjectData returns true if the {{ObjectData}} or {{ObjectDataString}}
// placeholder is used in the body or in any part body.
func (c *EventActionHTTPConfig) HasObjectData() bool {
	if strings.Contains(c.Body, "{{ObjectData}}") || strings.Contains(c.Body, "{{ObjectDataString}}") {
		return true
	}
	for _, part := range c.Parts {
		if strings.Contains(part.Body, "{{ObjectData}}") || strings.Contains(part.Body, "{{ObjectDataString}}") {
			return true
		}
	}
	return false
}

// HasMultipartFiles returns true if at least a file must be uploaded via a multipart request.
// The retention report placeholder does not count as a file.
func (c *EventActionHTTPConfig) HasMultipartFiles() bool {
	for _, part := range c.Parts {
		if part.Filepath != "" && part.Filepath != RetentionReportPlaceHolder {
			return true
		}
	}
	return false
}

// TryDecryptPassword decrypts the password if it is set and encrypted.
func (c *EventActionHTTPConfig) TryDecryptPassword() error {
	if c.Password != nil && !c.Password.IsEmpty() {
		if err := c.Password.TryDecrypt(); err != nil {
			return fmt.Errorf("unable to decrypt HTTP password: %w", err)
		}
	}
	return nil
}

// GetHTTPClient returns an HTTP client based on the config
func (c *EventActionHTTPConfig) GetHTTPClient() *http.Client {
	client :=
&http.Client{} if c.SkipTLSVerify { transport := http.DefaultTransport.(*http.Transport).Clone() if transport.TLSClientConfig != nil { transport.TLSClientConfig.InsecureSkipVerify = true } else { transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } } client.Transport = transport } return client } // IsActionCommandAllowed returns true if the specified command is allowed func IsActionCommandAllowed(cmd string) bool { return slices.Contains(EnabledActionCommands, cmd) } // EventActionCommandConfig defines the configuration for a command event target type EventActionCommandConfig struct { Cmd string `json:"cmd,omitempty"` Args []string `json:"args,omitempty"` Timeout int `json:"timeout,omitempty"` EnvVars []KeyValue `json:"env_vars,omitempty"` } func (c *EventActionCommandConfig) validate() error { if c.Cmd == "" { return util.NewI18nError(util.NewValidationError("command is required"), util.I18nErrorCommandRequired) } if !IsActionCommandAllowed(c.Cmd) { return util.NewValidationError(fmt.Sprintf("command %q is not allowed", c.Cmd)) } if !filepath.IsAbs(c.Cmd) { return util.NewI18nError( util.NewValidationError("invalid command, it must be an absolute path"), util.I18nErrorCommandInvalid, ) } if c.Timeout < 1 || c.Timeout > 120 { return util.NewValidationError(fmt.Sprintf("invalid command action timeout %d", c.Timeout)) } for _, kv := range c.EnvVars { if kv.isNotValid() { return util.NewValidationError("invalid command env vars") } } c.Args = util.RemoveDuplicates(c.Args, true) for _, arg := range c.Args { if arg == "" { return util.NewValidationError("invalid command args") } } return nil } // GetArgumentsAsString returns the list of command arguments as comma separated string func (c EventActionCommandConfig) GetArgumentsAsString() string { return strings.Join(c.Args, ",") } // EventActionEmailConfig defines the configuration options for SMTP event actions type EventActionEmailConfig struct { Recipients []string `json:"recipients,omitempty"` Bcc 
[]string `json:"bcc,omitempty"` Subject string `json:"subject,omitempty"` Body string `json:"body,omitempty"` Attachments []string `json:"attachments,omitempty"` ContentType int `json:"content_type,omitempty"` } // GetRecipientsAsString returns the list of recipients as comma separated string func (c EventActionEmailConfig) GetRecipientsAsString() string { return strings.Join(c.Recipients, ",") } // GetBccAsString returns the list of bcc as comma separated string func (c EventActionEmailConfig) GetBccAsString() string { return strings.Join(c.Bcc, ",") } // GetAttachmentsAsString returns the list of attachments as comma separated string func (c EventActionEmailConfig) GetAttachmentsAsString() string { return strings.Join(c.Attachments, ",") } func (c *EventActionEmailConfig) hasFilesAttachments() bool { for _, a := range c.Attachments { if a != RetentionReportPlaceHolder { return true } } return false } func (c *EventActionEmailConfig) validate() error { if len(c.Recipients) == 0 { return util.NewI18nError( util.NewValidationError("at least one email recipient is required"), util.I18nErrorEmailRecipientRequired, ) } c.Recipients = util.RemoveDuplicates(c.Recipients, false) for _, r := range c.Recipients { if r == "" { return util.NewValidationError("invalid email recipients") } } c.Bcc = util.RemoveDuplicates(c.Bcc, false) for _, r := range c.Bcc { if r == "" { return util.NewValidationError("invalid email bcc") } } if c.Subject == "" { return util.NewI18nError( util.NewValidationError("email subject is required"), util.I18nErrorEmailSubjectRequired, ) } if c.Body == "" { return util.NewI18nError( util.NewValidationError("email body is required"), util.I18nErrorEmailBodyRequired, ) } if c.ContentType < 0 || c.ContentType > 1 { return util.NewValidationError("invalid email content type") } for idx, val := range c.Attachments { val = strings.TrimSpace(val) if val == "" { return util.NewValidationError("invalid path to attach") } if val == RetentionReportPlaceHolder { 
c.Attachments[idx] = val } else { c.Attachments[idx] = util.CleanPath(val) } } c.Attachments = util.RemoveDuplicates(c.Attachments, false) return nil } // FolderRetention defines a folder retention configuration type FolderRetention struct { // Path is the virtual directory path, if no other specific retention is defined, // the retention applies for sub directories too. For example if retention is defined // for the paths "/" and "/sub" then the retention for "/" is applied for any file outside // the "/sub" directory Path string `json:"path"` // Retention time in hours. 0 means exclude this path Retention int `json:"retention"` // DeleteEmptyDirs defines if empty directories will be deleted. // The user need the delete permission DeleteEmptyDirs bool `json:"delete_empty_dirs,omitempty"` } // Validate returns an error if the configuration is not valid func (f *FolderRetention) Validate() error { f.Path = util.CleanPath(f.Path) if f.Retention < 0 { return util.NewValidationError(fmt.Sprintf("invalid folder retention %v, it must be greater or equal to zero", f.Retention)) } return nil } // EventActionDataRetentionConfig defines the configuration for a data retention check type EventActionDataRetentionConfig struct { Folders []FolderRetention `json:"folders,omitempty"` } func (c *EventActionDataRetentionConfig) validate() error { folderPaths := make(map[string]bool) nothingToDo := true for idx := range c.Folders { f := &c.Folders[idx] if err := f.Validate(); err != nil { return err } if f.Retention > 0 { nothingToDo = false } if _, ok := folderPaths[f.Path]; ok { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("duplicated folder path %q", f.Path)), util.I18nErrorPathDuplicated, ) } folderPaths[f.Path] = true } if nothingToDo { return util.NewI18nError( util.NewValidationError("nothing to delete!"), util.I18nErrorRetentionDirRequired, ) } return nil } // EventActionFsCompress defines the configuration for the compress filesystem action type 
EventActionFsCompress struct {
	// Archive path
	Name string `json:"name,omitempty"`
	// Paths to compress
	Paths []string `json:"paths,omitempty"`
}

// validate checks the compress configuration and normalizes it in place.
//
// The archive name is required. It is trimmed and cleaned with util.CleanPath,
// and must not resolve to "/". At least one path is required; each path is
// trimmed, must be non-empty, and is cleaned. Duplicate paths are then removed.
func (c *EventActionFsCompress) validate() error {
	if c.Name == "" {
		return util.NewI18nError(util.NewValidationError("archive name is mandatory"), util.I18nErrorArchiveNameRequired)
	}
	c.Name = util.CleanPath(strings.TrimSpace(c.Name))
	switch {
	case c.Name == "/":
		return util.NewI18nError(util.NewValidationError("invalid archive name"), util.I18nErrorRootNotAllowed)
	case len(c.Paths) == 0:
		return util.NewI18nError(util.NewValidationError("no path to compress specified"), util.I18nErrorPathRequired)
	}
	for i := range c.Paths {
		trimmed := strings.TrimSpace(c.Paths[i])
		if trimmed == "" {
			return util.NewValidationError("invalid path to compress")
		}
		c.Paths[i] = util.CleanPath(trimmed)
	}
	c.Paths = util.RemoveDuplicates(c.Paths, false)
	return nil
}

// RenameConfig defines the configuration for a filesystem rename
type RenameConfig struct {
	// key is the source and target the value
	KeyValue
	// This setting only applies to storage providers that support
	// changing modification times.
	UpdateModTime bool `json:"update_modtime,omitempty"`
}

// EventActionFilesystemConfig defines the configuration for filesystem actions
type EventActionFilesystemConfig struct {
	// Filesystem actions, see the above enum
	Type int `json:"type,omitempty"`
	// files/dirs to rename
	Renames []RenameConfig `json:"renames,omitempty"`
	// directories to create
	MkDirs []string `json:"mkdirs,omitempty"`
	// files/dirs to delete
	Deletes []string `json:"deletes,omitempty"`
	// file/dirs to check for existence
	Exist []string `json:"exist,omitempty"`
	// files/dirs to copy, key is the source and target the value
	Copy []KeyValue `json:"copy,omitempty"`
	// paths to compress and archive name
	Compress EventActionFsCompress `json:"compress"`
}

// GetDeletesAsString returns the list of items to delete as comma separated string.
// Using a pointer receiver will not work in web templates func (c EventActionFilesystemConfig) GetDeletesAsString() string { return strings.Join(c.Deletes, ",") } // GetMkDirsAsString returns the list of directories to create as comma separated string. // Using a pointer receiver will not work in web templates func (c EventActionFilesystemConfig) GetMkDirsAsString() string { return strings.Join(c.MkDirs, ",") } // GetExistAsString returns the list of items to check for existence as comma separated string. // Using a pointer receiver will not work in web templates func (c EventActionFilesystemConfig) GetExistAsString() string { return strings.Join(c.Exist, ",") } // GetCompressPathsAsString returns the list of items to compress as comma separated string. // Using a pointer receiver will not work in web templates func (c EventActionFilesystemConfig) GetCompressPathsAsString() string { return strings.Join(c.Compress.Paths, ",") } func (c *EventActionFilesystemConfig) validateRenames() error { if len(c.Renames) == 0 { return util.NewI18nError(util.NewValidationError("no path to rename specified"), util.I18nErrorPathRequired) } for idx, cfg := range c.Renames { key := strings.TrimSpace(cfg.Key) value := strings.TrimSpace(cfg.Value) if key == "" || value == "" { return util.NewValidationError("invalid paths to rename") } key = util.CleanPath(key) value = util.CleanPath(value) if key == value { return util.NewI18nError( util.NewValidationError("rename source and target cannot be equal"), util.I18nErrorSourceDestMatch, ) } if key == "/" || value == "/" { return util.NewI18nError( util.NewValidationError("renaming the root directory is not allowed"), util.I18nErrorRootNotAllowed, ) } c.Renames[idx] = RenameConfig{ KeyValue: KeyValue{ Key: key, Value: value, }, UpdateModTime: cfg.UpdateModTime, } } return nil } func (c *EventActionFilesystemConfig) validateCopy() error { if len(c.Copy) == 0 { return util.NewI18nError(util.NewValidationError("no path to copy specified"), 
util.I18nErrorPathRequired) } for idx, kv := range c.Copy { key := strings.TrimSpace(kv.Key) value := strings.TrimSpace(kv.Value) if key == "" || value == "" { return util.NewValidationError("invalid paths to copy") } key = util.CleanPath(key) value = util.CleanPath(value) if key == value { return util.NewI18nError( util.NewValidationError("copy source and target cannot be equal"), util.I18nErrorSourceDestMatch, ) } if key == "/" || value == "/" { return util.NewI18nError( util.NewValidationError("copying the root directory is not allowed"), util.I18nErrorRootNotAllowed, ) } if strings.HasSuffix(c.Copy[idx].Key, "/") { key += "/" } if strings.HasSuffix(c.Copy[idx].Value, "/") { value += "/" } c.Copy[idx] = KeyValue{ Key: key, Value: value, } } return nil } func (c *EventActionFilesystemConfig) validateDeletes() error { if len(c.Deletes) == 0 { return util.NewI18nError(util.NewValidationError("no path to delete specified"), util.I18nErrorPathRequired) } for idx, val := range c.Deletes { val = strings.TrimSpace(val) if val == "" { return util.NewValidationError("invalid path to delete") } c.Deletes[idx] = util.CleanPath(val) } c.Deletes = util.RemoveDuplicates(c.Deletes, false) return nil } func (c *EventActionFilesystemConfig) validateMkdirs() error { if len(c.MkDirs) == 0 { return util.NewI18nError(util.NewValidationError("no directory to create specified"), util.I18nErrorPathRequired) } for idx, val := range c.MkDirs { val = strings.TrimSpace(val) if val == "" { return util.NewValidationError("invalid directory to create") } c.MkDirs[idx] = util.CleanPath(val) } c.MkDirs = util.RemoveDuplicates(c.MkDirs, false) return nil } func (c *EventActionFilesystemConfig) validateExist() error { if len(c.Exist) == 0 { return util.NewI18nError(util.NewValidationError("no path to check for existence specified"), util.I18nErrorPathRequired) } for idx, val := range c.Exist { val = strings.TrimSpace(val) if val == "" { return util.NewValidationError("invalid path to check for 
existence") } c.Exist[idx] = util.CleanPath(val) } c.Exist = util.RemoveDuplicates(c.Exist, false) return nil } func (c *EventActionFilesystemConfig) validate() error { if !isFilesystemActionValid(c.Type) { return util.NewValidationError(fmt.Sprintf("invalid filesystem action type: %d", c.Type)) } switch c.Type { case FilesystemActionRename: c.MkDirs = nil c.Deletes = nil c.Exist = nil c.Copy = nil c.Compress = EventActionFsCompress{} if err := c.validateRenames(); err != nil { return err } case FilesystemActionDelete: c.Renames = nil c.MkDirs = nil c.Exist = nil c.Copy = nil c.Compress = EventActionFsCompress{} if err := c.validateDeletes(); err != nil { return err } case FilesystemActionMkdirs: c.Renames = nil c.Deletes = nil c.Exist = nil c.Copy = nil c.Compress = EventActionFsCompress{} if err := c.validateMkdirs(); err != nil { return err } case FilesystemActionExist: c.Renames = nil c.Deletes = nil c.MkDirs = nil c.Copy = nil c.Compress = EventActionFsCompress{} if err := c.validateExist(); err != nil { return err } case FilesystemActionCompress: c.Renames = nil c.MkDirs = nil c.Deletes = nil c.Exist = nil c.Copy = nil if err := c.Compress.validate(); err != nil { return err } case FilesystemActionCopy: c.Renames = nil c.Deletes = nil c.MkDirs = nil c.Exist = nil c.Compress = EventActionFsCompress{} if err := c.validateCopy(); err != nil { return err } } return nil } func (c *EventActionFilesystemConfig) getACopy() EventActionFilesystemConfig { mkdirs := make([]string, len(c.MkDirs)) copy(mkdirs, c.MkDirs) deletes := make([]string, len(c.Deletes)) copy(deletes, c.Deletes) exist := make([]string, len(c.Exist)) copy(exist, c.Exist) compressPaths := make([]string, len(c.Compress.Paths)) copy(compressPaths, c.Compress.Paths) return EventActionFilesystemConfig{ Type: c.Type, Renames: cloneRenameConfigs(c.Renames), MkDirs: mkdirs, Deletes: deletes, Exist: exist, Copy: cloneKeyValues(c.Copy), Compress: EventActionFsCompress{ Paths: compressPaths, Name: 
c.Compress.Name, }, } } // EventActionPasswordExpiration defines the configuration for password expiration actions type EventActionPasswordExpiration struct { // An email notification will be generated for users whose password expires in a number // of days less than or equal to this threshold Threshold int `json:"threshold,omitempty"` } func (c *EventActionPasswordExpiration) validate() error { if c.Threshold <= 0 { return util.NewValidationError("threshold must be greater than 0") } return nil } // EventActionUserInactivity defines the configuration for user inactivity checks. type EventActionUserInactivity struct { // DisableThreshold defines inactivity in days, since the last login before disabling the account DisableThreshold int `json:"disable_threshold,omitempty"` // DeleteThreshold defines inactivity in days, since the last login before deleting the account DeleteThreshold int `json:"delete_threshold,omitempty"` } func (c *EventActionUserInactivity) validate() error { if c.DeleteThreshold < 0 { c.DeleteThreshold = 0 } if c.DisableThreshold < 0 { c.DisableThreshold = 0 } if c.DisableThreshold == 0 && c.DeleteThreshold == 0 { return util.NewI18nError( util.NewValidationError("at least a threshold must be defined"), util.I18nActionThresholdRequired, ) } if c.DeleteThreshold > 0 && c.DisableThreshold > 0 { if c.DeleteThreshold <= c.DisableThreshold { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("deletion threshold %d must be greater than deactivation threshold: %d", c.DeleteThreshold, c.DisableThreshold)), util.I18nActionThresholdsInvalid, ) } } return nil } // EventActionIDPAccountCheck defines the check to execute after a successful IDP login type EventActionIDPAccountCheck struct { // 0 create/update, 1 create the account if it doesn't exist Mode int `json:"mode,omitempty"` TemplateUser string `json:"template_user,omitempty"` TemplateAdmin string `json:"template_admin,omitempty"` } func (c *EventActionIDPAccountCheck) validate() error { if 
c.TemplateAdmin == "" && c.TemplateUser == "" {
		return util.NewI18nError(
			util.NewValidationError("at least a template must be set"),
			util.I18nErrorIDPTemplateRequired,
		)
	}
	if c.Mode < 0 || c.Mode > 1 {
		return util.NewValidationError(fmt.Sprintf("invalid account check mode: %d", c.Mode))
	}
	return nil
}

// BaseEventActionOptions defines the supported configuration options for a base event actions.
// Only the sub-configuration matching the action type is kept: validate resets
// all the others.
type BaseEventActionOptions struct {
	HTTPConfig           EventActionHTTPConfig          `json:"http_config"`
	CmdConfig            EventActionCommandConfig       `json:"cmd_config"`
	EmailConfig          EventActionEmailConfig         `json:"email_config"`
	RetentionConfig      EventActionDataRetentionConfig `json:"retention_config"`
	FsConfig             EventActionFilesystemConfig    `json:"fs_config"`
	PwdExpirationConfig  EventActionPasswordExpiration  `json:"pwd_expiration_config"`
	UserInactivityConfig EventActionUserInactivity      `json:"user_inactivity_config"`
	IDPConfig            EventActionIDPAccountCheck     `json:"idp_config"`
}

// getACopy returns a copy of the options that shares no slices with the
// receiver.
//
// It first calls SetEmptySecretsIfNil on the receiver, so the HTTP password is
// never nil when it is cloned. String slices are copied into newly allocated,
// non-nil slices. Key/value lists go through cloneKeyValues. The copied HTTP
// parts carry Name, Filepath, Headers and Body only; Order is not copied.
func (o *BaseEventActionOptions) getACopy() BaseEventActionOptions {
	o.SetEmptySecretsIfNil()
	emailRecipients := make([]string, len(o.EmailConfig.Recipients))
	copy(emailRecipients, o.EmailConfig.Recipients)
	emailBcc := make([]string, len(o.EmailConfig.Bcc))
	copy(emailBcc, o.EmailConfig.Bcc)
	emailAttachments := make([]string, len(o.EmailConfig.Attachments))
	copy(emailAttachments, o.EmailConfig.Attachments)
	cmdArgs := make([]string, len(o.CmdConfig.Args))
	copy(cmdArgs, o.CmdConfig.Args)
	folders := make([]FolderRetention, 0, len(o.RetentionConfig.Folders))
	for _, folder := range o.RetentionConfig.Folders {
		folders = append(folders, FolderRetention{
			Path:            folder.Path,
			Retention:       folder.Retention,
			DeleteEmptyDirs: folder.DeleteEmptyDirs,
		})
	}
	httpParts := make([]HTTPPart, 0, len(o.HTTPConfig.Parts))
	for _, part := range o.HTTPConfig.Parts {
		httpParts = append(httpParts, HTTPPart{
			Name:     part.Name,
			Filepath: part.Filepath,
			Headers:  cloneKeyValues(part.Headers),
			Body:     part.Body,
		})
	}
	return BaseEventActionOptions{
		HTTPConfig: EventActionHTTPConfig{
			Endpoint:        o.HTTPConfig.Endpoint,
			Username:        o.HTTPConfig.Username,
			Password:        o.HTTPConfig.Password.Clone(),
			Headers:         cloneKeyValues(o.HTTPConfig.Headers),
			Timeout:         o.HTTPConfig.Timeout,
			SkipTLSVerify:   o.HTTPConfig.SkipTLSVerify,
			Method:          o.HTTPConfig.Method,
			QueryParameters: cloneKeyValues(o.HTTPConfig.QueryParameters),
			Body:            o.HTTPConfig.Body,
			Parts:           httpParts,
		},
		CmdConfig: EventActionCommandConfig{
			Cmd:     o.CmdConfig.Cmd,
			Args:    cmdArgs,
			Timeout: o.CmdConfig.Timeout,
			EnvVars: cloneKeyValues(o.CmdConfig.EnvVars),
		},
		EmailConfig: EventActionEmailConfig{
			Recipients:  emailRecipients,
			Bcc:         emailBcc,
			Subject:     o.EmailConfig.Subject,
			ContentType: o.EmailConfig.ContentType,
			Body:        o.EmailConfig.Body,
			Attachments: emailAttachments,
		},
		RetentionConfig: EventActionDataRetentionConfig{
			Folders: folders,
		},
		PwdExpirationConfig: EventActionPasswordExpiration{
			Threshold: o.PwdExpirationConfig.Threshold,
		},
		UserInactivityConfig: EventActionUserInactivity{
			DisableThreshold: o.UserInactivityConfig.DisableThreshold,
			DeleteThreshold:  o.UserInactivityConfig.DeleteThreshold,
		},
		IDPConfig: EventActionIDPAccountCheck{
			Mode:          o.IDPConfig.Mode,
			TemplateUser:  o.IDPConfig.TemplateUser,
			TemplateAdmin: o.IDPConfig.TemplateAdmin,
		},
		FsConfig: o.FsConfig.getACopy(),
	}
}

// SetEmptySecretsIfNil sets the secrets to empty if nil
func (o *BaseEventActionOptions) SetEmptySecretsIfNil() {
	if o.HTTPConfig.Password == nil {
		o.HTTPConfig.Password = kms.NewEmptySecret()
	}
}

// setNilSecretsIfEmpty is the inverse of SetEmptySecretsIfNil: an empty HTTP
// password is replaced with nil.
func (o *BaseEventActionOptions) setNilSecretsIfEmpty() {
	if o.HTTPConfig.Password != nil && o.HTTPConfig.Password.IsEmpty() {
		o.HTTPConfig.Password = nil
	}
}

// hideConfidentialData calls Hide on the HTTP password, if set.
func (o *BaseEventActionOptions) hideConfidentialData() {
	if o.HTTPConfig.Password != nil {
		o.HTTPConfig.Password.Hide()
	}
}

// validate resets every sub-configuration not used by action and validates the
// remaining one. For HTTP actions, name is passed to HTTPConfig.validate as the
// additional data used when encrypting the password.
func (o *BaseEventActionOptions) validate(action int, name string) error {
	o.SetEmptySecretsIfNil()
	switch action {
	case ActionTypeHTTP:
		o.CmdConfig = EventActionCommandConfig{}
		o.EmailConfig = EventActionEmailConfig{}
		o.RetentionConfig =
EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} return o.HTTPConfig.validate(name) case ActionTypeCommand: o.HTTPConfig = EventActionHTTPConfig{} o.EmailConfig = EventActionEmailConfig{} o.RetentionConfig = EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} return o.CmdConfig.validate() case ActionTypeEmail: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.RetentionConfig = EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} return o.EmailConfig.validate() case ActionTypeDataRetentionCheck: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.EmailConfig = EventActionEmailConfig{} o.FsConfig = EventActionFilesystemConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} return o.RetentionConfig.validate() case ActionTypeFilesystem: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.EmailConfig = EventActionEmailConfig{} o.RetentionConfig = EventActionDataRetentionConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} return o.FsConfig.validate() case ActionTypePasswordExpirationCheck: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.EmailConfig = EventActionEmailConfig{} o.RetentionConfig = 
EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} return o.PwdExpirationConfig.validate() case ActionTypeUserInactivityCheck: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.EmailConfig = EventActionEmailConfig{} o.RetentionConfig = EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.IDPConfig = EventActionIDPAccountCheck{} o.PwdExpirationConfig = EventActionPasswordExpiration{} return o.UserInactivityConfig.validate() case ActionTypeIDPAccountCheck: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.EmailConfig = EventActionEmailConfig{} o.RetentionConfig = EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.UserInactivityConfig = EventActionUserInactivity{} return o.IDPConfig.validate() default: o.HTTPConfig = EventActionHTTPConfig{} o.CmdConfig = EventActionCommandConfig{} o.EmailConfig = EventActionEmailConfig{} o.RetentionConfig = EventActionDataRetentionConfig{} o.FsConfig = EventActionFilesystemConfig{} o.PwdExpirationConfig = EventActionPasswordExpiration{} o.IDPConfig = EventActionIDPAccountCheck{} o.UserInactivityConfig = EventActionUserInactivity{} } return nil } // BaseEventAction defines the common fields for an event action type BaseEventAction struct { // Data provider unique identifier ID int64 `json:"id"` // Action name Name string `json:"name"` // optional description Description string `json:"description,omitempty"` // ActionType, see the above enum Type int `json:"type"` // Configuration options specific for the action type Options BaseEventActionOptions `json:"options"` // list of rule names associated with this event action Rules []string `json:"rules,omitempty"` } func (a *BaseEventAction) getACopy() BaseEventAction { rules := make([]string, 
len(a.Rules)) copy(rules, a.Rules) return BaseEventAction{ ID: a.ID, Name: a.Name, Description: a.Description, Type: a.Type, Options: a.Options.getACopy(), Rules: rules, } } // GetTypeAsString returns the action type as string func (a *BaseEventAction) GetTypeAsString() string { return getActionTypeAsString(a.Type) } // GetRulesAsString returns the list of rules as comma separated string func (a *BaseEventAction) GetRulesAsString() string { return strings.Join(a.Rules, ",") } // PrepareForRendering prepares a BaseEventAction for rendering. // It hides confidential data and set to nil the empty secrets // so they are not serialized func (a *BaseEventAction) PrepareForRendering() { a.Options.setNilSecretsIfEmpty() a.Options.hideConfidentialData() } // RenderAsJSON implements the renderer interface used within plugins func (a *BaseEventAction) RenderAsJSON(reload bool) ([]byte, error) { if reload { action, err := provider.eventActionExists(a.Name) if err != nil { providerLog(logger.LevelError, "unable to reload event action before rendering as json: %v", err) return nil, err } action.PrepareForRendering() return json.Marshal(action) } a.PrepareForRendering() return json.Marshal(a) } func (a *BaseEventAction) validate() error { if a.Name == "" { return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) } if !util.IsNameValid(a.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if config.NamingRules&1 == 0 && !usernameRegex.MatchString(a.Name) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", a.Name)), util.I18nErrorInvalidUser, ) } if !isActionTypeValid(a.Type) { return util.NewValidationError(fmt.Sprintf("invalid action type: %d", a.Type)) } return a.Options.validate(a.Type, a.Name) } // EventActionOptions defines the supported configuration options for an event action type EventActionOptions struct { 
IsFailureAction bool `json:"is_failure_action"`
	StopOnFailure   bool `json:"stop_on_failure"`
	ExecuteSync     bool `json:"execute_sync"`
}

// EventAction defines an event action
type EventAction struct {
	BaseEventAction
	// Order defines the execution order
	Order   int                `json:"order,omitempty"`
	Options EventActionOptions `json:"relation_options"`
}

// getACopy returns a copy of the action. The relation options only hold
// booleans, so a plain value copy duplicates them.
func (a *EventAction) getACopy() EventAction {
	return EventAction{
		BaseEventAction: a.BaseEventAction.getACopy(),
		Order:           a.Order,
		Options:         a.Options,
	}
}

// validateAssociation checks that the relation options fit the rule trigger.
// Sync execution is rejected for failure actions. It is accepted for
// Identity Provider logins and for filesystem events, the latter only when
// every event is in allowedSyncFsEvents.
func (a *EventAction) validateAssociation(trigger int, fsEvents []string) error {
	if !a.Options.ExecuteSync {
		return nil
	}
	if a.Options.IsFailureAction {
		return util.NewI18nError(
			util.NewValidationError("sync execution is not supported for failure actions"),
			util.I18nErrorEvSyncFailureActions,
		)
	}
	switch trigger {
	case EventTriggerIDPLogin:
		return nil
	case EventTriggerFsEvent:
		for _, event := range fsEvents {
			if !slices.Contains(allowedSyncFsEvents, event) {
				return util.NewI18nError(
					util.NewValidationError("sync execution is only supported for upload and pre-* events"),
					util.I18nErrorEvSyncUnsupportedFs,
				)
			}
		}
		return nil
	default:
		return util.NewI18nError(
			util.NewValidationError("sync execution is only supported for some filesystem events and Identity Provider logins"),
			util.I18nErrorEvSyncUnsupported,
		)
	}
}

// ConditionPattern defines a pattern for condition filters
type ConditionPattern struct {
	Pattern      string `json:"pattern,omitempty"`
	InverseMatch bool   `json:"inverse_match,omitempty"`
}

// validate rejects empty patterns and patterns that path.Match reports as malformed.
func (p *ConditionPattern) validate() error {
	if p.Pattern == "" {
		return util.NewValidationError("empty condition pattern not allowed")
	}
	// match against a throwaway name just to surface pattern syntax errors
	if _, err := path.Match(p.Pattern, "abc"); err != nil {
		return util.NewValidationError(fmt.Sprintf("invalid condition pattern %q", p.Pattern))
	}
	return nil
}

// ConditionOptions defines options for event conditions
type ConditionOptions struct {
	// Usernames or folder names
	Names []ConditionPattern `json:"names,omitempty"`
	// Group names
	GroupNames []ConditionPattern `json:"group_names,omitempty"`
	// Role names
	RoleNames []ConditionPattern `json:"role_names,omitempty"`
	// Virtual paths
	FsPaths         []ConditionPattern `json:"fs_paths,omitempty"`
	Protocols       []string           `json:"protocols,omitempty"`
	ProviderObjects []string           `json:"provider_objects,omitempty"`
	MinFileSize     int64              `json:"min_size,omitempty"`
	MaxFileSize     int64              `json:"max_size,omitempty"`
	EventStatuses   []int              `json:"event_statuses,omitempty"`
	// allow to execute scheduled tasks concurrently from multiple instances
	ConcurrentExecution bool `json:"concurrent_execution,omitempty"`
}

// getACopy returns a copy of the options that shares no slice with the receiver.
func (f *ConditionOptions) getACopy() ConditionOptions {
	protocolsCopy := make([]string, len(f.Protocols))
	copy(protocolsCopy, f.Protocols)
	objectsCopy := make([]string, len(f.ProviderObjects))
	copy(objectsCopy, f.ProviderObjects)
	statusesCopy := make([]int, len(f.EventStatuses))
	copy(statusesCopy, f.EventStatuses)

	return ConditionOptions{
		Names:               cloneConditionPatterns(f.Names),
		GroupNames:          cloneConditionPatterns(f.GroupNames),
		RoleNames:           cloneConditionPatterns(f.RoleNames),
		FsPaths:             cloneConditionPatterns(f.FsPaths),
		Protocols:           protocolsCopy,
		ProviderObjects:     objectsCopy,
		MinFileSize:         f.MinFileSize,
		MaxFileSize:         f.MaxFileSize,
		EventStatuses:       statusesCopy,
		ConcurrentExecution: f.ConcurrentExecution,
	}
}

// validateStatuses accepts only event statuses in the range [0, 3].
func (f *ConditionOptions) validateStatuses() error {
	for _, status := range f.EventStatuses {
		if status >= 0 && status <= 3 {
			continue
		}
		return util.NewValidationError(fmt.Sprintf("invalid event_status %d", status))
	}
	return nil
}

// validate checks patterns, protocols, provider objects, size limits and
// statuses. ConcurrentExecution is forced off when the provider is not shared.
func (f *ConditionOptions) validate() error {
	for _, patterns := range [][]ConditionPattern{f.Names, f.GroupNames, f.RoleNames, f.FsPaths} {
		if err := validateConditionPatterns(patterns); err != nil {
			return err
		}
	}
	for _, proto := range f.Protocols {
		if !slices.Contains(SupportedRuleConditionProtocols, proto) {
			return util.NewValidationError(fmt.Sprintf("unsupported rule condition protocol: %q", proto))
		}
	}
	for _, obj := range f.ProviderObjects {
		if !slices.Contains(SupporteRuleConditionProviderObjects, obj) {
			return util.NewValidationError(fmt.Sprintf("unsupported provider object: %q", obj))
		}
	}
	// the size range is only checked when both limits are set
	if f.MinFileSize > 0 && f.MaxFileSize > 0 && f.MaxFileSize <= f.MinFileSize {
		return util.NewValidationError(fmt.Sprintf("invalid max file size %s, it is lesser or equal than min file size %s",
			util.ByteCountSI(f.MaxFileSize), util.ByteCountSI(f.MinFileSize)))
	}
	if err := f.validateStatuses(); err != nil {
		return err
	}
	if config.IsShared == 0 {
		f.ConcurrentExecution = false
	}
	return nil
}

// Schedule defines an event schedule
type Schedule struct {
	Hours      string `json:"hour"`
	DayOfWeek  string `json:"day_of_week"`
	DayOfMonth string `json:"day_of_month"`
	Month      string `json:"month"`
}

// GetCronSpec returns the cron compatible schedule string.
// The minute field is always 0.
func (s *Schedule) GetCronSpec() string {
	return strings.Join([]string{"0", s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek}, " ")
}

// validate ensures the generated cron spec is accepted by cron.ParseStandard.
func (s *Schedule) validate() error {
	if _, err := cron.ParseStandard(s.GetCronSpec()); err != nil {
		return util.NewValidationError(fmt.Sprintf("invalid schedule, hour: %q, day of month: %q, month: %q, day of week: %q",
			s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek))
	}
	return nil
}

// EventConditions defines the conditions for an event rule
type EventConditions struct {
	// Only one between FsEvents, ProviderEvents and Schedule is allowed
	FsEvents       []string   `json:"fs_events,omitempty"`
	ProviderEvents []string   `json:"provider_events,omitempty"`
	Schedules      []Schedule `json:"schedules,omitempty"`
	// 0 any, 1 user, 2 admin
	IDPLoginEvent int              `json:"idp_login_event,omitempty"`
	Options       ConditionOptions `json:"options"`
}

// getACopy returns a copy of the conditions that shares no slice with the receiver.
func (c *EventConditions) getACopy() EventConditions {
	fsEvents := make([]string, len(c.FsEvents))
	copy(fsEvents, c.FsEvents)
	providerEvents := make([]string, len(c.ProviderEvents))
	copy(providerEvents, c.ProviderEvents)
	// Schedule only holds strings, so copying the values duplicates them
	schedules := make([]Schedule, len(c.Schedules))
	copy(schedules, c.Schedules)

	return EventConditions{
		FsEvents:       fsEvents,
		ProviderEvents: providerEvents,
		Schedules:      schedules,
		IDPLoginEvent:  c.IDPLoginEvent,
		Options:        c.Options.getACopy(),
	}
}

// validateSchedules requires at least one schedule and validates each of them.
func (c *EventConditions) validateSchedules() error {
	if len(c.Schedules) == 0 {
		return util.NewI18nError(
			util.NewValidationError("at least one schedule is required"),
			util.I18nErrorRuleScheduleRequired,
		)
	}
	for idx := range c.Schedules {
		if err := c.Schedules[idx].validate(); err != nil {
			return util.NewI18nError(err, util.I18nErrorRuleScheduleInvalid)
		}
	}
	return nil
}

// validate resets the condition fields that do not apply to the given trigger,
// validates the trigger specific ones and finally validates the options.
func (c *EventConditions) validate(trigger int) error {
	switch trigger {
	case EventTriggerFsEvent:
		c.ProviderEvents = nil
		c.Schedules = nil
		c.Options.ProviderObjects = nil
		c.IDPLoginEvent = 0
		if len(c.FsEvents) == 0 {
			return util.NewI18nError(
				util.NewValidationError("at least one filesystem event is required"),
				util.I18nErrorRuleFsEventRequired,
			)
		}
		for _, event := range c.FsEvents {
			if !slices.Contains(SupportedFsEvents, event) {
				return util.NewValidationError(fmt.Sprintf("unsupported fs event: %q", event))
			}
		}
	case EventTriggerProviderEvent:
		c.FsEvents = nil
		c.Schedules = nil
		c.Options.FsPaths = nil
		c.Options.Protocols = nil
		c.Options.EventStatuses = nil
		c.Options.MinFileSize = 0
		c.Options.MaxFileSize = 0
		c.IDPLoginEvent = 0
		if len(c.ProviderEvents) == 0 {
			return util.NewI18nError(
				util.NewValidationError("at least one provider event is required"),
				util.I18nErrorRuleProviderEventRequired,
			)
		}
		for _, event := range c.ProviderEvents {
			if !slices.Contains(SupportedProviderEvents, event) {
				return util.NewValidationError(fmt.Sprintf("unsupported provider event: %q", event))
			}
		}
	case EventTriggerSchedule:
c.FsEvents = nil c.ProviderEvents = nil c.Options.FsPaths = nil c.Options.Protocols = nil c.Options.EventStatuses = nil c.Options.MinFileSize = 0 c.Options.MaxFileSize = 0 c.Options.ProviderObjects = nil c.IDPLoginEvent = 0 if err := c.validateSchedules(); err != nil { return err } case EventTriggerIPBlocked, EventTriggerCertificate: c.FsEvents = nil c.ProviderEvents = nil c.Options.Names = nil c.Options.GroupNames = nil c.Options.RoleNames = nil c.Options.FsPaths = nil c.Options.Protocols = nil c.Options.EventStatuses = nil c.Options.MinFileSize = 0 c.Options.MaxFileSize = 0 c.Schedules = nil c.IDPLoginEvent = 0 case EventTriggerOnDemand: c.FsEvents = nil c.ProviderEvents = nil c.Options.FsPaths = nil c.Options.Protocols = nil c.Options.EventStatuses = nil c.Options.MinFileSize = 0 c.Options.MaxFileSize = 0 c.Options.ProviderObjects = nil c.Schedules = nil c.IDPLoginEvent = 0 c.Options.ConcurrentExecution = false case EventTriggerIDPLogin: c.FsEvents = nil c.ProviderEvents = nil c.Options.GroupNames = nil c.Options.RoleNames = nil c.Options.FsPaths = nil c.Options.Protocols = nil c.Options.EventStatuses = nil c.Options.MinFileSize = 0 c.Options.MaxFileSize = 0 c.Schedules = nil if !slices.Contains(supportedIDPLoginEvents, c.IDPLoginEvent) { return util.NewValidationError(fmt.Sprintf("invalid Identity Provider login event %d", c.IDPLoginEvent)) } default: c.FsEvents = nil c.ProviderEvents = nil c.Options.GroupNames = nil c.Options.RoleNames = nil c.Options.FsPaths = nil c.Options.Protocols = nil c.Options.EventStatuses = nil c.Options.MinFileSize = 0 c.Options.MaxFileSize = 0 c.Schedules = nil c.IDPLoginEvent = 0 } return c.Options.validate() } // EventRule defines the trigger, conditions and actions for an event type EventRule struct { // Data provider unique identifier ID int64 `json:"id"` // Rule name Name string `json:"name"` // 1 enabled, 0 disabled Status int `json:"status"` // optional description Description string `json:"description,omitempty"` // Creation 
time as unix timestamp in milliseconds CreatedAt int64 `json:"created_at"` // last update time as unix timestamp in milliseconds UpdatedAt int64 `json:"updated_at"` // Event trigger Trigger int `json:"trigger"` // Event conditions Conditions EventConditions `json:"conditions"` // actions to execute Actions []EventAction `json:"actions"` // in multi node setups we mark the rule as deleted to be able to update the cache DeletedAt int64 `json:"-"` } func (r *EventRule) getACopy() EventRule { actions := make([]EventAction, 0, len(r.Actions)) for _, action := range r.Actions { actions = append(actions, action.getACopy()) } return EventRule{ ID: r.ID, Name: r.Name, Status: r.Status, Description: r.Description, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, Trigger: r.Trigger, Conditions: r.Conditions.getACopy(), Actions: actions, DeletedAt: r.DeletedAt, } } // GuardFromConcurrentExecution returns true if the rule cannot be executed concurrently // from multiple instances func (r *EventRule) GuardFromConcurrentExecution() bool { if config.IsShared == 0 { return false } return !r.Conditions.Options.ConcurrentExecution } // GetTriggerAsString returns the rule trigger as string func (r *EventRule) GetTriggerAsString() string { return getTriggerTypeAsString(r.Trigger) } // GetActionsAsString returns the list of action names as comma separated string func (r *EventRule) GetActionsAsString() string { actions := make([]string, 0, len(r.Actions)) for _, action := range r.Actions { actions = append(actions, action.Name) } return strings.Join(actions, ",") } func (r *EventRule) isStatusValid() bool { return r.Status >= 0 && r.Status <= 1 } func (r *EventRule) validate() error { //nolint:gocyclo if r.Name == "" { return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) } if !util.IsNameValid(r.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if config.NamingRules&1 == 0 && !usernameRegex.MatchString(r.Name) { 
return util.NewI18nError(
			util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", r.Name)),
			// this is a rule name, not a username: use the generic name key,
			// consistent with the group name validation
			util.I18nErrorInvalidName,
		)
	}
	if !r.isStatusValid() {
		return util.NewValidationError(fmt.Sprintf("invalid event rule status: %d", r.Status))
	}
	if !isEventTriggerValid(r.Trigger) {
		return util.NewValidationError(fmt.Sprintf("invalid event rule trigger: %d", r.Trigger))
	}
	if err := r.Conditions.validate(r.Trigger); err != nil {
		return err
	}
	if len(r.Actions) == 0 {
		return util.NewI18nError(util.NewValidationError("at least one action is required"), util.I18nErrorRuleActionRequired)
	}
	// action names and execution orders must be unique within the rule
	actionNames := make(map[string]bool)
	actionOrders := make(map[int]bool)
	failureActions := 0
	hasSyncAction := false
	for idx := range r.Actions {
		action := &r.Actions[idx]
		if action.Name == "" {
			return util.NewValidationError(fmt.Sprintf("invalid action at position %d, name not specified", idx))
		}
		if actionNames[action.Name] {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("duplicated action %q", action.Name)),
				util.I18nErrorRuleDuplicateActions,
			)
		}
		if actionOrders[action.Order] {
			return util.NewValidationError(fmt.Sprintf("duplicated order %d for action %q", action.Order, action.Name))
		}
		if err := action.validateAssociation(r.Trigger, r.Conditions.FsEvents); err != nil {
			return err
		}
		if action.Options.IsFailureAction {
			failureActions++
		}
		if action.Options.ExecuteSync {
			hasSyncAction = true
		}
		actionNames[action.Name] = true
		actionOrders[action.Order] = true
	}
	if len(r.Actions) == failureActions {
		return util.NewI18nError(
			util.NewValidationError("at least a non-failure action is required"),
			util.I18nErrorRuleFailureActionsOnly,
		)
	}
	// without any sync action, reject filesystem events that require one
	if !hasSyncAction {
		return r.validateMandatorySyncActions()
	}
	return nil
}

// validateMandatorySyncActions returns an error for a filesystem event rule
// that listens to any event in mandatorySyncFsEvents. validate calls it when
// the rule has no sync action.
func (r *EventRule) validateMandatorySyncActions() error {
	if r.Trigger != EventTriggerFsEvent {
		return nil
	}
	for _, ev := range r.Conditions.FsEvents {
		if slices.Contains(mandatorySyncFsEvents, ev) {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("event %q requires at least a sync action", ev)),
				util.I18nErrorRuleSyncActionRequired,
				util.I18nErrorArgs(map[string]any{
					"val": ev,
				}),
			)
		}
	}
	return nil
}

// checkIPBlockedAndCertificateActions rejects the action types listed below,
// which are not supported for IP blocked and certificate triggers.
func (r *EventRule) checkIPBlockedAndCertificateActions() error {
	unavailableActions := []int{ActionTypeUserQuotaReset, ActionTypeFolderQuotaReset, ActionTypeTransferQuotaReset,
		ActionTypeDataRetentionCheck, ActionTypeFilesystem, ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck}
	for _, action := range r.Actions {
		if slices.Contains(unavailableActions, action.Type) {
			return fmt.Errorf("action %q, type %q is not supported for event trigger %q",
				action.Name, getActionTypeAsString(action.Type), getTriggerTypeAsString(r.Trigger))
		}
	}
	return nil
}

// checkProviderEventActions rejects actions that cannot operate on the
// provider object type affected by the event.
func (r *EventRule) checkProviderEventActions(providerObjectType string) error {
	// user quota reset, transfer quota reset, data retention check and filesystem actions
	// can be executed only if we modify a user. They will be executed for the
	// affected user. Folder quota reset can be executed only for folders.
userSpecificActions := []int{ActionTypeUserQuotaReset, ActionTypeTransferQuotaReset, ActionTypeDataRetentionCheck, ActionTypeFilesystem, ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck} for _, action := range r.Actions { if slices.Contains(userSpecificActions, action.Type) && providerObjectType != actionObjectUser { return fmt.Errorf("action %q, type %q is only supported for provider user events", action.Name, getActionTypeAsString(action.Type)) } if action.Type == ActionTypeFolderQuotaReset && providerObjectType != actionObjectFolder { return fmt.Errorf("action %q, type %q is only supported for provider folder events", action.Name, getActionTypeAsString(action.Type)) } } return nil } func (r *EventRule) hasUserAssociated(providerObjectType string) bool { switch r.Trigger { case EventTriggerProviderEvent: return providerObjectType == actionObjectUser case EventTriggerFsEvent: return true default: if len(r.Actions) > 0 { // should we allow schedules where backup is not the first action? 
// maybe we could pass the action index and check before that index return r.Actions[0].Type == ActionTypeBackup } } return false } func (r *EventRule) checkActions(providerObjectType string) error { numSyncAction := 0 hasIDPAccountCheck := false for _, action := range r.Actions { if action.Options.ExecuteSync { numSyncAction++ } if action.Type == ActionTypeEmail && action.BaseEventAction.Options.EmailConfig.hasFilesAttachments() { if !r.hasUserAssociated(providerObjectType) { return errors.New("cannot send an email with attachments for a rule with no user associated") } } if action.Type == ActionTypeHTTP && action.BaseEventAction.Options.HTTPConfig.HasMultipartFiles() { if !r.hasUserAssociated(providerObjectType) { return errors.New("cannot upload file/s for a rule with no user associated") } } if action.Type == ActionTypeIDPAccountCheck { if r.Trigger != EventTriggerIDPLogin { return errors.New("IDP account check action is only supported for IDP login trigger") } if !action.Options.ExecuteSync { return errors.New("IDP account check must be a sync action") } hasIDPAccountCheck = true } } if hasIDPAccountCheck && numSyncAction != 1 { return errors.New("IDP account check must be the only sync action") } return nil } // CheckActionsConsistency returns an error if the actions cannot be executed func (r *EventRule) CheckActionsConsistency(providerObjectType string) error { switch r.Trigger { case EventTriggerProviderEvent: if err := r.checkProviderEventActions(providerObjectType); err != nil { return err } case EventTriggerFsEvent: // folder quota reset cannot be executed for _, action := range r.Actions { if action.Type == ActionTypeFolderQuotaReset { return fmt.Errorf("action %q, type %q is not supported for filesystem events", action.Name, getActionTypeAsString(action.Type)) } } case EventTriggerIPBlocked, EventTriggerCertificate: if err := r.checkIPBlockedAndCertificateActions(); err != nil { return err } } return r.checkActions(providerObjectType) } // 
PrepareForRendering prepares an EventRule for rendering. // It hides confidential data and set to nil the empty secrets // so they are not serialized func (r *EventRule) PrepareForRendering() { for idx := range r.Actions { r.Actions[idx].PrepareForRendering() } } // RenderAsJSON implements the renderer interface used within plugins func (r *EventRule) RenderAsJSON(reload bool) ([]byte, error) { if reload { rule, err := provider.eventRuleExists(r.Name) if err != nil { providerLog(logger.LevelError, "unable to reload event rule before rendering as json: %v", err) return nil, err } rule.PrepareForRendering() return json.Marshal(rule) } r.PrepareForRendering() return json.Marshal(r) } func cloneRenameConfigs(renames []RenameConfig) []RenameConfig { res := make([]RenameConfig, 0, len(renames)) for _, c := range renames { res = append(res, RenameConfig{ KeyValue: KeyValue{ Key: c.Key, Value: c.Value, }, UpdateModTime: c.UpdateModTime, }) } return res } func cloneKeyValues(keyVals []KeyValue) []KeyValue { res := make([]KeyValue, 0, len(keyVals)) for _, kv := range keyVals { res = append(res, KeyValue{ Key: kv.Key, Value: kv.Value, }) } return res } func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern { res := make([]ConditionPattern, 0, len(patterns)) for _, p := range patterns { res = append(res, ConditionPattern{ Pattern: p.Pattern, InverseMatch: p.InverseMatch, }) } return res } func validateConditionPatterns(patterns []ConditionPattern) error { for _, name := range patterns { if err := name.validate(); err != nil { return err } } return nil } // Task stores the state for a scheduled task type Task struct { Name string `json:"name"` UpdateAt int64 `json:"updated_at"` Version int64 `json:"version"` } ================================================ FILE: internal/dataprovider/group.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it 
under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "encoding/json" "fmt" "path/filepath" "strings" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // GroupUserSettings defines the settings to apply to users type GroupUserSettings struct { sdk.BaseGroupUserSettings // Filesystem configuration details FsConfig vfs.Filesystem `json:"filesystem"` } // Group defines an SFTPGo group. 
// Groups are used to easily configure similar users type Group struct { sdk.BaseGroup // settings to apply to users for whom this is a primary group UserSettings GroupUserSettings `json:"user_settings,omitempty"` // Mapping between virtual paths and virtual folders VirtualFolders []vfs.VirtualFolder `json:"virtual_folders,omitempty"` } // GetPermissions returns the permissions as list func (g *Group) GetPermissions() []sdk.DirectoryPermissions { result := make([]sdk.DirectoryPermissions, 0, len(g.UserSettings.Permissions)) for k, v := range g.UserSettings.Permissions { result = append(result, sdk.DirectoryPermissions{ Path: k, Permissions: v, }) } return result } // GetAllowedIPAsString returns the allowed IP as comma separated string func (g *Group) GetAllowedIPAsString() string { return strings.Join(g.UserSettings.Filters.AllowedIP, ",") } // GetDeniedIPAsString returns the denied IP as comma separated string func (g *Group) GetDeniedIPAsString() string { return strings.Join(g.UserSettings.Filters.DeniedIP, ",") } // HasExternalAuth returns true if the external authentication is globally enabled // and it is not disabled for this group func (g *Group) HasExternalAuth() bool { if g.UserSettings.Filters.Hooks.ExternalAuthDisabled { return false } if config.ExternalAuthHook != "" { return true } return plugin.Handler.HasAuthenticators() } // SetEmptySecretsIfNil sets the secrets to empty if nil func (g *Group) SetEmptySecretsIfNil() { g.UserSettings.FsConfig.SetEmptySecretsIfNil() for idx := range g.VirtualFolders { vfolder := &g.VirtualFolders[idx] vfolder.FsConfig.SetEmptySecretsIfNil() } } // PrepareForRendering prepares a group for rendering. 
// It hides confidential data and set to nil the empty secrets // so they are not serialized func (g *Group) PrepareForRendering() { g.UserSettings.FsConfig.HideConfidentialData() g.UserSettings.FsConfig.SetNilSecretsIfEmpty() for idx := range g.VirtualFolders { folder := &g.VirtualFolders[idx] folder.PrepareForRendering() } } // RenderAsJSON implements the renderer interface used within plugins func (g *Group) RenderAsJSON(reload bool) ([]byte, error) { if reload { group, err := provider.groupExists(g.Name) if err != nil { providerLog(logger.LevelError, "unable to reload group before rendering as json: %v", err) return nil, err } group.PrepareForRendering() return json.Marshal(group) } g.PrepareForRendering() return json.Marshal(g) } // GetEncryptionAdditionalData returns the additional data to use for AEAD func (g *Group) GetEncryptionAdditionalData() string { return fmt.Sprintf("group_%v", g.Name) } // HasRedactedSecret returns true if the user has a redacted secret func (g *Group) hasRedactedSecret() bool { for idx := range g.VirtualFolders { folder := &g.VirtualFolders[idx] if folder.HasRedactedSecret() { return true } } return g.UserSettings.FsConfig.HasRedactedSecret() } func (g *Group) applyNamingRules() { g.Name = config.convertName(g.Name) for idx := range g.VirtualFolders { g.VirtualFolders[idx].Name = config.convertName(g.VirtualFolders[idx].Name) } } func (g *Group) validate() error { g.SetEmptySecretsIfNil() g.applyNamingRules() if g.Name == "" { return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) } if !util.IsNameValid(g.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if config.NamingRules&1 == 0 && !usernameRegex.MatchString(g.Name) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", g.Name)), util.I18nErrorInvalidName, ) } if g.hasRedactedSecret() { return 
util.NewValidationError("cannot save a group with a redacted secret") } vfolders, err := validateAssociatedVirtualFolders(g.VirtualFolders) if err != nil { return err } g.VirtualFolders = vfolders return g.validateUserSettings() } func (g *Group) validateUserSettings() error { if g.UserSettings.HomeDir != "" { g.UserSettings.HomeDir = filepath.Clean(g.UserSettings.HomeDir) if !filepath.IsAbs(g.UserSettings.HomeDir) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("home_dir must be an absolute path, actual value: %v", g.UserSettings.HomeDir)), util.I18nErrorInvalidHomeDir, ) } } if err := g.UserSettings.FsConfig.Validate(g.GetEncryptionAdditionalData()); err != nil { return err } if g.UserSettings.TotalDataTransfer > 0 { // if a total data transfer is defined we reset the separate upload and download limits g.UserSettings.UploadDataTransfer = 0 g.UserSettings.DownloadDataTransfer = 0 } if len(g.UserSettings.Permissions) > 0 { permissions, err := validateUserPermissions(g.UserSettings.Permissions) if err != nil { return util.NewI18nError(err, util.I18nErrorGenericPermission) } g.UserSettings.Permissions = permissions } g.UserSettings.Filters.TLSCerts = nil if err := validateBaseFilters(&g.UserSettings.Filters); err != nil { return err } if !g.HasExternalAuth() { g.UserSettings.Filters.ExternalAuthCacheTime = 0 } g.UserSettings.Filters.UserType = "" return nil } func (g *Group) getACopy() Group { users := make([]string, len(g.Users)) copy(users, g.Users) admins := make([]string, len(g.Admins)) copy(admins, g.Admins) virtualFolders := make([]vfs.VirtualFolder, 0, len(g.VirtualFolders)) for idx := range g.VirtualFolders { vfolder := g.VirtualFolders[idx].GetACopy() virtualFolders = append(virtualFolders, vfolder) } permissions := make(map[string][]string) for k, v := range g.UserSettings.Permissions { perms := make([]string, len(v)) copy(perms, v) permissions[k] = perms } return Group{ BaseGroup: sdk.BaseGroup{ ID: g.ID, Name: g.Name, Description: 
g.Description, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, Users: users, Admins: admins, }, UserSettings: GroupUserSettings{ BaseGroupUserSettings: sdk.BaseGroupUserSettings{ HomeDir: g.UserSettings.HomeDir, MaxSessions: g.UserSettings.MaxSessions, QuotaSize: g.UserSettings.QuotaSize, QuotaFiles: g.UserSettings.QuotaFiles, Permissions: permissions, UploadBandwidth: g.UserSettings.UploadBandwidth, DownloadBandwidth: g.UserSettings.DownloadBandwidth, UploadDataTransfer: g.UserSettings.UploadDataTransfer, DownloadDataTransfer: g.UserSettings.DownloadDataTransfer, TotalDataTransfer: g.UserSettings.TotalDataTransfer, ExpiresIn: g.UserSettings.ExpiresIn, Filters: copyBaseUserFilters(g.UserSettings.Filters), }, FsConfig: g.UserSettings.FsConfig.GetACopy(), }, VirtualFolders: virtualFolders, } } ================================================ FILE: internal/dataprovider/iplist.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider import ( "encoding/json" "fmt" "net" "net/netip" "slices" "strings" "sync" "sync/atomic" "github.com/yl2chen/cidranger" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( // maximum number of entries to match in memory // if the list contains more elements than this limit a // database query will be executed ipListMemoryLimit = 15000 ) var ( inMemoryLists map[IPListType]*IPList ) func init() { inMemoryLists = map[IPListType]*IPList{} } // IPListType is the enumerable for the supported IP list types type IPListType int // AsString returns the string representation for the list type func (t IPListType) AsString() string { switch t { case IPListTypeAllowList: return "Allow list" case IPListTypeDefender: return "Defender" case IPListTypeRateLimiterSafeList: return "Rate limiters safe list" default: return "" } } // Supported IP list types const ( IPListTypeAllowList IPListType = iota + 1 IPListTypeDefender IPListTypeRateLimiterSafeList ) // Supported IP list modes const ( ListModeAllow = iota + 1 ListModeDeny ) const ( ipTypeV4 = iota + 1 ipTypeV6 ) var ( supportedIPListType = []IPListType{IPListTypeAllowList, IPListTypeDefender, IPListTypeRateLimiterSafeList} ) // CheckIPListType returns an error if the provided IP list type is not valid func CheckIPListType(t IPListType) error { if !slices.Contains(supportedIPListType, t) { return util.NewValidationError(fmt.Sprintf("invalid list type %d", t)) } return nil } // IPListEntry defines an entry for the IP addresses list type IPListEntry struct { IPOrNet string `json:"ipornet"` Description string `json:"description,omitempty"` Type IPListType `json:"type"` Mode int `json:"mode"` // Defines the protocols the entry applies to // - 0 all the supported protocols // - 1 SSH // - 2 FTP // - 4 WebDAV // - 8 HTTP // Protocols can be combined Protocols int `json:"protocols"` First []byte `json:"first,omitempty"` Last []byte `json:"last,omitempty"` IPType int 
`json:"ip_type,omitempty"` // Creation time as unix timestamp in milliseconds CreatedAt int64 `json:"created_at"` // last update time as unix timestamp in milliseconds UpdatedAt int64 `json:"updated_at"` // in multi node setups we mark the rule as deleted to be able to update the cache DeletedAt int64 `json:"-"` } // PrepareForRendering prepares an IP list entry for rendering. // It hides internal fields func (e *IPListEntry) PrepareForRendering() { e.First = nil e.Last = nil e.IPType = 0 } // HasProtocol returns true if the specified protocol is defined func (e *IPListEntry) HasProtocol(proto string) bool { switch proto { case protocolSSH: return e.Protocols&1 != 0 case protocolFTP: return e.Protocols&2 != 0 case protocolWebDAV: return e.Protocols&4 != 0 case protocolHTTP: return e.Protocols&8 != 0 default: return false } } // RenderAsJSON implements the renderer interface used within plugins func (e *IPListEntry) RenderAsJSON(reload bool) ([]byte, error) { if reload { entry, err := provider.ipListEntryExists(e.IPOrNet, e.Type) if err != nil { providerLog(logger.LevelError, "unable to reload IP list entry before rendering as json: %v", err) return nil, err } entry.PrepareForRendering() return json.Marshal(entry) } e.PrepareForRendering() return json.Marshal(e) } func (e *IPListEntry) getKey() string { return fmt.Sprintf("%d_%s", e.Type, e.IPOrNet) } func (e *IPListEntry) getName() string { return e.Type.AsString() + "-" + e.IPOrNet } func (e *IPListEntry) getFirst() netip.Addr { if e.IPType == ipTypeV4 { var a4 [4]byte copy(a4[:], e.First) return netip.AddrFrom4(a4) } var a16 [16]byte copy(a16[:], e.First) return netip.AddrFrom16(a16) } func (e *IPListEntry) getLast() netip.Addr { if e.IPType == ipTypeV4 { var a4 [4]byte copy(a4[:], e.Last) return netip.AddrFrom4(a4) } var a16 [16]byte copy(a16[:], e.Last) return netip.AddrFrom16(a16) } func (e *IPListEntry) checkProtocols() { for _, proto := range ValidProtocols { if !e.HasProtocol(proto) { return } } e.Protocols 
= 0 } func (e *IPListEntry) validate() error { if err := CheckIPListType(e.Type); err != nil { return err } e.checkProtocols() switch e.Type { case IPListTypeDefender: if e.Mode < ListModeAllow || e.Mode > ListModeDeny { return util.NewValidationError(fmt.Sprintf("invalid list mode: %d", e.Mode)) } default: if e.Mode != ListModeAllow { return util.NewValidationError("invalid list mode") } } e.PrepareForRendering() if !strings.Contains(e.IPOrNet, "/") { // parse as IP parsed, err := netip.ParseAddr(e.IPOrNet) if err != nil { return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid IP %q", e.IPOrNet)), util.I18nErrorIPInvalid) } if parsed.Is4() { e.IPOrNet += "/32" } else if parsed.Is4In6() { e.IPOrNet = netip.AddrFrom4(parsed.As4()).String() + "/32" } else { e.IPOrNet += "/128" } } prefix, err := netip.ParsePrefix(e.IPOrNet) if err != nil { return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network %q: %v", e.IPOrNet, err)), util.I18nErrorNetInvalid) } prefix = prefix.Masked() if prefix.Addr().Is4In6() { e.IPOrNet = fmt.Sprintf("%s/%d", netip.AddrFrom4(prefix.Addr().As4()).String(), prefix.Bits()-96) } // TODO: to remove when the in memory ranger switch to netip _, _, err = net.ParseCIDR(e.IPOrNet) if err != nil { return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network: %v", err)), util.I18nErrorNetInvalid) } if prefix.Addr().Is4() || prefix.Addr().Is4In6() { e.IPType = ipTypeV4 first := prefix.Addr().As4() last := util.GetLastIPForPrefix(prefix).As4() e.First = first[:] e.Last = last[:] } else { e.IPType = ipTypeV6 first := prefix.Addr().As16() last := util.GetLastIPForPrefix(prefix).As16() e.First = first[:] e.Last = last[:] } return nil } func (e *IPListEntry) getACopy() IPListEntry { first := make([]byte, len(e.First)) copy(first, e.First) last := make([]byte, len(e.Last)) copy(last, e.Last) return IPListEntry{ IPOrNet: e.IPOrNet, Description: e.Description, Type: e.Type, Mode: e.Mode, First: first, Last: 
last, IPType: e.IPType, Protocols: e.Protocols, CreatedAt: e.CreatedAt, UpdatedAt: e.UpdatedAt, DeletedAt: e.DeletedAt, } } // getAsRangerEntry returns the entry as cidranger.RangerEntry func (e *IPListEntry) getAsRangerEntry() (cidranger.RangerEntry, error) { _, network, err := net.ParseCIDR(e.IPOrNet) if err != nil { return nil, err } entry := e.getACopy() return &rangerEntry{ entry: &entry, network: *network, }, nil } func (e IPListEntry) satisfySearchConstraints(filter, from, order string) bool { if filter != "" && !strings.HasPrefix(e.IPOrNet, filter) { return false } if from != "" { if order == OrderASC { return e.IPOrNet > from } return e.IPOrNet < from } return true } type rangerEntry struct { entry *IPListEntry network net.IPNet } func (e *rangerEntry) Network() net.IPNet { return e.network } // IPList defines an IP list type IPList struct { isInMemory atomic.Bool listType IPListType mu sync.RWMutex Ranges cidranger.Ranger } func (l *IPList) addEntry(e *IPListEntry) { if l.listType != e.Type { return } if !l.isInMemory.Load() { return } entry, err := e.getAsRangerEntry() if err != nil { providerLog(logger.LevelError, "unable to get entry to add %q for list type %d, disabling memory mode, err: %v", e.IPOrNet, l.listType, err) l.isInMemory.Store(false) return } l.mu.Lock() defer l.mu.Unlock() if err := l.Ranges.Insert(entry); err != nil { providerLog(logger.LevelError, "unable to add entry %q for list type %d, disabling memory mode, err: %v", e.IPOrNet, l.listType, err) l.isInMemory.Store(false) return } if l.Ranges.Len() >= ipListMemoryLimit { providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType) l.isInMemory.Store(false) } } func (l *IPList) removeEntry(e *IPListEntry) { if l.listType != e.Type { return } if !l.isInMemory.Load() { return } entry, err := e.getAsRangerEntry() if err != nil { providerLog(logger.LevelError, "unable to get entry to remove %q for list type %d, disabling memory mode, err: %v", 
e.IPOrNet, l.listType, err) l.isInMemory.Store(false) return } l.mu.Lock() defer l.mu.Unlock() if _, err := l.Ranges.Remove(entry.Network()); err != nil { providerLog(logger.LevelError, "unable to remove entry %q for list type %d, disabling memory mode, err: %v", e.IPOrNet, l.listType, err) l.isInMemory.Store(false) } } func (l *IPList) updateEntry(e *IPListEntry) { if l.listType != e.Type { return } if !l.isInMemory.Load() { return } entry, err := e.getAsRangerEntry() if err != nil { providerLog(logger.LevelError, "unable to get entry to update %q for list type %d, disabling memory mode, err: %v", e.IPOrNet, l.listType, err) l.isInMemory.Store(false) return } l.mu.Lock() defer l.mu.Unlock() if _, err := l.Ranges.Remove(entry.Network()); err != nil { providerLog(logger.LevelError, "unable to remove entry to update %q for list type %d, disabling memory mode, err: %v", e.IPOrNet, l.listType, err) l.isInMemory.Store(false) return } if err := l.Ranges.Insert(entry); err != nil { providerLog(logger.LevelError, "unable to add entry to update %q for list type %d, disabling memory mode, err: %v", e.IPOrNet, l.listType, err) l.isInMemory.Store(false) } if l.Ranges.Len() >= ipListMemoryLimit { providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType) l.isInMemory.Store(false) } } // DisableMemoryMode disables memory mode forcing database queries func (l *IPList) DisableMemoryMode() { l.isInMemory.Store(false) } // IsListed checks if there is a match for the specified IP and protocol. 
// If there are multiple matches, the first one is returned, in no particular order, // so the behavior is undefined func (l *IPList) IsListed(ip, protocol string) (bool, int, error) { if l.isInMemory.Load() { l.mu.RLock() defer l.mu.RUnlock() if l.Ranges.Len() == 0 { return false, 0, nil } parsedIP := net.ParseIP(ip) if parsedIP == nil { return false, 0, fmt.Errorf("invalid IP %s", ip) } entries, err := l.Ranges.ContainingNetworks(parsedIP) if err != nil { return false, 0, fmt.Errorf("unable to find containing networks for ip %q: %w", ip, err) } for _, e := range entries { entry, ok := e.(*rangerEntry) if ok { if entry.entry.Protocols == 0 || entry.entry.HasProtocol(protocol) { return true, entry.entry.Mode, nil } } } return false, 0, nil } entries, err := provider.getListEntriesForIP(ip, l.listType) if err != nil { return false, 0, err } for _, e := range entries { if e.Protocols == 0 || e.HasProtocol(protocol) { return true, e.Mode, nil } } return false, 0, nil } // NewIPList returns a new IP list for the specified type func NewIPList(listType IPListType) (*IPList, error) { delete(inMemoryLists, listType) count, err := provider.countIPListEntries(listType) if err != nil { return nil, err } if count < ipListMemoryLimit { providerLog(logger.LevelInfo, "using in-memory matching for list type %d, num entries: %d", listType, count) entries, err := provider.getIPListEntries(listType, "", "", OrderASC, 0) if err != nil { return nil, err } ipList := &IPList{ listType: listType, Ranges: cidranger.NewPCTrieRanger(), } for idx := range entries { e := entries[idx] entry, err := e.getAsRangerEntry() if err != nil { return nil, fmt.Errorf("unable to get ranger for entry %q: %w", e.IPOrNet, err) } if err := ipList.Ranges.Insert(entry); err != nil { return nil, fmt.Errorf("unable to add ranger for entry %q: %w", e.IPOrNet, err) } } ipList.isInMemory.Store(true) inMemoryLists[listType] = ipList return ipList, nil } providerLog(logger.LevelInfo, "list type %d has %d entries, 
in-memory matching disabled", listType, count) ipList := &IPList{ listType: listType, Ranges: nil, } ipList.isInMemory.Store(false) return ipList, nil } ================================================ FILE: internal/dataprovider/memory.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "bytes" "crypto/x509" "errors" "fmt" "net/netip" "os" "path/filepath" "slices" "sort" "strconv" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( errMemoryProviderClosed = errors.New("memory provider is closed") ) type memoryProviderHandle struct { // configuration file to use for loading users configFile string sync.Mutex isClosed bool // slice with ordered usernames usernames []string // map for users, username is the key users map[string]User // slice with ordered group names groupnames []string // map for group, group name is the key groups map[string]Group // map for virtual folders, folder name is the key vfolders map[string]vfs.BaseVirtualFolder // slice with ordered folder names vfoldersNames []string // map for admins, username is the key admins map[string]Admin // slice with ordered admins adminsUsernames []string // map for API keys, keyID is the key apiKeys map[string]APIKey // slice with ordered API keys KeyID apiKeysIDs []string // 
map for shares, shareID is the key shares map[string]Share // slice with ordered shares shareID sharesIDs []string // map for event actions, name is the key actions map[string]BaseEventAction // slice with ordered actions actionsNames []string // map for event actions, name is the key rules map[string]EventRule // slice with ordered rules rulesNames []string // map for roles, name is the key roles map[string]Role // slice with ordered roles roleNames []string // map for IP List entry ipListEntries map[string]IPListEntry // slice with ordered IP list entries ipListEntriesKeys []string // configurations configs Configs } // MemoryProvider defines the auth provider for a memory store type MemoryProvider struct { dbHandle *memoryProviderHandle } func initializeMemoryProvider(basePath string) error { configFile := "" if util.IsFileInputValid(config.Name) { configFile = config.Name if !filepath.IsAbs(configFile) { configFile = filepath.Join(basePath, configFile) } } provider = &MemoryProvider{ dbHandle: &memoryProviderHandle{ isClosed: false, usernames: []string{}, users: make(map[string]User), groupnames: []string{}, groups: make(map[string]Group), vfolders: make(map[string]vfs.BaseVirtualFolder), vfoldersNames: []string{}, admins: make(map[string]Admin), adminsUsernames: []string{}, apiKeys: make(map[string]APIKey), apiKeysIDs: []string{}, shares: make(map[string]Share), sharesIDs: []string{}, actions: make(map[string]BaseEventAction), actionsNames: []string{}, rules: make(map[string]EventRule), rulesNames: []string{}, roles: map[string]Role{}, roleNames: []string{}, ipListEntries: map[string]IPListEntry{}, ipListEntriesKeys: []string{}, configs: Configs{}, configFile: configFile, }, } return provider.reloadConfig() } func (p *MemoryProvider) checkAvailability() error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } return nil } func (p *MemoryProvider) close() error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if 
p.dbHandle.isClosed { return errMemoryProviderClosed } p.dbHandle.isClosed = true return nil } func (p *MemoryProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { var user User if tlsCert == nil { return user, errors.New("TLS certificate cannot be null or empty") } user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) return user, err } return checkUserAndTLSCertificate(&user, protocol, tlsCert) } func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) return user, err } return checkUserAndPass(&user, password, ip, protocol) } func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte, isSSHCert bool) (User, string, error) { var user User if len(pubKey) == 0 { return user, "", errors.New("credentials cannot be null or empty") } user, err := p.userExists(username, "") if err != nil { providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) return user, "", err } return checkUserAndPubKey(&user, pubKey, isSSHCert) } func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { admin, err := p.adminExists(username) if err != nil { providerLog(logger.LevelWarn, "error authenticating admin %q: %v", username, err) return admin, err } err = admin.checkUserAndPass(password, ip) return admin, err } func (p *MemoryProvider) updateAPIKeyLastUse(keyID string) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } apiKey, err := p.apiKeyExistsInternal(keyID) if err != nil { return err } apiKey.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.apiKeys[apiKey.KeyID] = apiKey return nil } func (p *MemoryProvider) 
getAdminSignature(username string) (string, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return "", errMemoryProviderClosed } admin, err := p.adminExistsInternal(username) if err != nil { return "", err } return strconv.FormatInt(admin.UpdatedAt, 10), nil } func (p *MemoryProvider) getUserSignature(username string) (string, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return "", errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { return "", err } return strconv.FormatInt(user.UpdatedAt, 10), nil } func (p *MemoryProvider) setUpdatedAt(username string) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return } user, err := p.userExistsInternal(username) if err != nil { return } user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.users[user.Username] = user setLastUserUpdate() } func (p *MemoryProvider) updateLastLogin(username string) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { return err } user.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.users[user.Username] = user return nil } func (p *MemoryProvider) updateAdminLastLogin(username string) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } admin, err := p.adminExistsInternal(username) if err != nil { return err } admin.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.admins[admin.Username] = admin return nil } func (p *MemoryProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { providerLog(logger.LevelError, "unable to update transfer quota for user %q error: %v", 
username, err) return err } if reset { user.UsedUploadDataTransfer = uploadSize user.UsedDownloadDataTransfer = downloadSize } else { user.UsedUploadDataTransfer += uploadSize user.UsedDownloadDataTransfer += downloadSize } user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %v dl increment: %v is reset? %v", username, uploadSize, downloadSize, reset) p.dbHandle.users[user.Username] = user return nil } func (p *MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { providerLog(logger.LevelError, "unable to update quota for user %q error: %v", username, err) return err } if reset { user.UsedQuotaSize = sizeAdd user.UsedQuotaFiles = filesAdd } else { user.UsedQuotaSize += sizeAdd user.UsedQuotaFiles += filesAdd } user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %v size increment: %v is reset? 
%v", username, filesAdd, sizeAdd, reset) p.dbHandle.users[user.Username] = user return nil } func (p *MemoryProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return 0, 0, 0, 0, errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { providerLog(logger.LevelError, "unable to get quota for user %q error: %v", username, err) return 0, 0, 0, 0, err } return user.UsedQuotaFiles, user.UsedQuotaSize, user.UsedUploadDataTransfer, user.UsedDownloadDataTransfer, err } func (p *MemoryProvider) addUser(user *User) error { err := ValidateUser(user) if err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } _, err = p.userExistsInternal(user.Username) if err == nil { return util.NewI18nError( fmt.Errorf("%w: username %v already exists", ErrDuplicatedKey, user.Username), util.I18nErrorDuplicatedUsername, ) } user.ID = p.getNextID() user.LastQuotaUpdate = 0 user.UsedQuotaSize = 0 user.UsedQuotaFiles = 0 user.UsedUploadDataTransfer = 0 user.UsedDownloadDataTransfer = 0 user.LastLogin = 0 user.FirstUpload = 0 user.FirstDownload = 0 user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) if err := p.addUserToRole(user.Username, user.Role); err != nil { return err } var mappedGroups []string for idx := range user.Groups { if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil { // try to remove group mapping for _, g := range mappedGroups { p.removeUserFromGroupMapping(user.Username, g) } return err } mappedGroups = append(mappedGroups, user.Groups[idx].Name) } var mappedFolders []string for idx := range user.VirtualFolders { if err = p.addUserToFolderMapping(user.Username, user.VirtualFolders[idx].Name); err != nil { // try to remove folder mapping for _, f := range mappedFolders { 
p.removeRelationFromFolderMapping(f, user.Username, "") } return err } mappedFolders = append(mappedFolders, user.VirtualFolders[idx].Name) } p.dbHandle.users[user.Username] = user.getACopy() p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username) sort.Strings(p.dbHandle.usernames) return nil } func (p *MemoryProvider) updateUser(user *User) error { //nolint:gocyclo err := ValidateUser(user) if err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } u, err := p.userExistsInternal(user.Username) if err != nil { return err } p.removeUserFromRole(u.Username, u.Role) if err := p.addUserToRole(user.Username, user.Role); err != nil { // try ro add old role if errRollback := p.addUserToRole(u.Username, u.Role); errRollback != nil { providerLog(logger.LevelError, "unable to rollback old role %q for user %q, error: %v", u.Role, u.Username, errRollback) } return err } for idx := range u.Groups { p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name) } for idx := range user.Groups { if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil { // try to add old mapping for _, g := range u.Groups { if errRollback := p.addUserToGroupMapping(user.Username, g.Name); errRollback != nil { providerLog(logger.LevelError, "unable to rollback old group mapping %q for user %q, error: %v", g.Name, user.Username, errRollback) } } return err } } for _, oldFolder := range u.VirtualFolders { p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "") } for idx := range user.VirtualFolders { if err = p.addUserToFolderMapping(user.Username, user.VirtualFolders[idx].Name); err != nil { // try to add old mapping for _, f := range u.VirtualFolders { if errRollback := p.addUserToFolderMapping(user.Username, f.Name); errRollback != nil { providerLog(logger.LevelError, "unable to rollback old folder mapping %q for user %q, error: %v", f.Name, user.Username, errRollback) } } return 
err
		}
	}
	user.LastQuotaUpdate = u.LastQuotaUpdate
	user.UsedQuotaSize = u.UsedQuotaSize
	user.UsedQuotaFiles = u.UsedQuotaFiles
	user.UsedUploadDataTransfer = u.UsedUploadDataTransfer
	user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer
	user.LastLogin = u.LastLogin
	user.FirstDownload = u.FirstDownload
	user.FirstUpload = u.FirstUpload
	user.CreatedAt = u.CreatedAt
	user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	user.ID = u.ID
	// pre-login and external auth hook will use the passed *user so save a copy
	p.dbHandle.users[user.Username] = user.getACopy()
	setLastUserUpdate()
	return nil
}

// deleteUser removes the user and every relation pointing to it: role and
// group memberships, folder mappings, API keys and shares.
func (p *MemoryProvider) deleteUser(user User, _ bool) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	u, err := p.userExistsInternal(user.Username)
	if err != nil {
		return err
	}
	p.removeUserFromRole(u.Username, u.Role)
	for _, oldFolder := range u.VirtualFolders {
		p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "")
	}
	for idx := range u.Groups {
		p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name)
	}
	delete(p.dbHandle.users, user.Username)
	// this could be more efficient
	p.dbHandle.usernames = make([]string, 0, len(p.dbHandle.users))
	for username := range p.dbHandle.users {
		p.dbHandle.usernames = append(p.dbHandle.usernames, username)
	}
	sort.Strings(p.dbHandle.usernames)
	p.deleteAPIKeysWithUser(user.Username)
	p.deleteSharesWithUser(user.Username)
	return nil
}

// updateUserPassword replaces the stored password of username and refreshes UpdatedAt.
func (p *MemoryProvider) updateUserPassword(username, password string) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	user, err := p.userExistsInternal(username)
	if err != nil {
		return err
	}
	user.Password = password
	user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	p.dbHandle.users[username] = user
	return nil
}

// dumpUsers returns a copy of every user, sorted by username, with the
// virtual folders resolved.
func (p *MemoryProvider) dumpUsers() ([]User, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	users := make([]User, 0, len(p.dbHandle.usernames))
	var err error
	if p.dbHandle.isClosed {
		return users, errMemoryProviderClosed
	}
	for _, username := range p.dbHandle.usernames {
		u := p.dbHandle.users[username]
		user := u.getACopy()
		p.addVirtualFoldersToUser(&user)
		users = append(users, user)
	}
	return users, err
}

// dumpFolders returns a copy of every stored virtual folder (map iteration order).
func (p *MemoryProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	folders := make([]vfs.BaseVirtualFolder, 0, len(p.dbHandle.vfoldersNames))
	if p.dbHandle.isClosed {
		return folders, errMemoryProviderClosed
	}
	for _, f := range p.dbHandle.vfolders {
		// return a copy, as dumpUsers/dumpGroups do, so callers cannot alias
		// the slices held by the stored value
		folders = append(folders, f.GetACopy())
	}
	return folders, nil
}

// getRecentlyUpdatedUsers returns the users with UpdatedAt >= after, with
// virtual folders and group settings applied.
func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
	// fast path that avoids taking the lock when no user was updated after "after"
	if getLastUserUpdate() < after {
		return nil, nil
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	users := make([]User, 0, 10)
	for _, username := range p.dbHandle.usernames {
		u := p.dbHandle.users[username]
		if u.UpdatedAt < after {
			continue
		}
		user := u.getACopy()
		p.addVirtualFoldersToUser(&user)
		if len(user.Groups) > 0 {
			groupMapping := make(map[string]Group)
			for idx := range user.Groups {
				group, err := p.groupExistsInternal(user.Groups[idx].Name)
				if err != nil {
					// missing groups are skipped
					continue
				}
				groupMapping[group.Name] = group
			}
			user.applyGroupSettings(groupMapping)
		}
		user.SetEmptySecretsIfNil()
		users = append(users, user)
	}
	return users, nil
}

// getUsersForQuotaCheck returns the users whose names are keys of toFetch.
// The map value tells whether virtual folders must be resolved too.
func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	users := make([]User, 0, 30)
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return users, errMemoryProviderClosed
	}
	for _, username := range p.dbHandle.usernames {
		needFolders, ok := toFetch[username]
		if !ok {
			continue
		}
		u := p.dbHandle.users[username]
		user := u.getACopy()
		if needFolders {
			p.addVirtualFoldersToUser(&user)
		}
		if len(user.Groups) > 0 {
			groupMapping := make(map[string]Group)
			for idx := range user.Groups {
				group, err := p.groupExistsInternal(user.Groups[idx].Name)
				if err != nil {
					continue
				}
				groupMapping[group.Name] = group
			}
			user.applyGroupSettings(groupMapping)
		}
		user.SetEmptySecretsIfNil()
		user.PrepareForRendering()
		users = append(users, user)
	}
	return users, nil
}

// getUsers returns at most limit users with the given role, skipping the
// first offset matching users, sorted by username in the requested order.
func (p *MemoryProvider) getUsers(limit int, offset int, order, role string) ([]User, error) {
	users := make([]User, 0, limit)
	var err error
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return users, errMemoryProviderClosed
	}
	if limit <= 0 {
		return users, err
	}
	itNum := 0
	if order == OrderASC {
		for _, username := range p.dbHandle.usernames {
			u := p.dbHandle.users[username]
			// filter by role before paginating: offset must count matching
			// users only, as it does for the SQL providers and getShares
			if !u.hasRole(role) {
				continue
			}
			itNum++
			if itNum <= offset {
				continue
			}
			user := u.getACopy()
			p.addVirtualFoldersToUser(&user)
			user.PrepareForRendering()
			users = append(users, user)
			if len(users) >= limit {
				break
			}
		}
	} else {
		for i := len(p.dbHandle.usernames) - 1; i >= 0; i-- {
			username := p.dbHandle.usernames[i]
			u := p.dbHandle.users[username]
			if !u.hasRole(role) {
				continue
			}
			itNum++
			if itNum <= offset {
				continue
			}
			user := u.getACopy()
			p.addVirtualFoldersToUser(&user)
			user.PrepareForRendering()
			users = append(users, user)
			if len(users) >= limit {
				break
			}
		}
	}
	return users, err
}

// userExists returns a copy of the named user with virtual folders resolved.
// A user outside the requested role is reported as not found.
func (p *MemoryProvider) userExists(username, role string) (User, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return User{}, errMemoryProviderClosed
	}
	user, err := p.userExistsInternal(username)
	if err != nil {
		return user, err
	}
	if !user.hasRole(role) {
		return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
	}
	p.addVirtualFoldersToUser(&user)
	return user, nil
}

// userExistsInternal returns a copy of the stored user. The caller must hold the lock.
func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
	if val, ok := p.dbHandle.users[username]; ok {
		return val.getACopy(), nil
	}
	return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
}

// groupExistsInternal returns a copy of the stored group. The caller must hold the lock.
func (p *MemoryProvider) groupExistsInternal(name string) (Group, error) {
	if val, ok := p.dbHandle.groups[name]; ok {
		return val.getACopy(), nil
	}
	return Group{}, util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name))
}

// actionExistsInternal returns a copy of the stored event action. The caller must hold the lock.
func (p *MemoryProvider) actionExistsInternal(name string) (BaseEventAction, error) {
	if val, ok := p.dbHandle.actions[name]; ok {
		return val.getACopy(), nil
	}
	return BaseEventAction{}, util.NewRecordNotFoundError(fmt.Sprintf("event action %q does not exist", name))
}

// ruleExistsInternal returns a copy of the stored event rule. The caller must hold the lock.
func (p *MemoryProvider) ruleExistsInternal(name string) (EventRule, error) {
	if val, ok := p.dbHandle.rules[name]; ok {
		return val.getACopy(), nil
	}
	return EventRule{}, util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", name))
}

// roleExistsInternal returns a copy of the stored role. The caller must hold the lock.
func (p *MemoryProvider) roleExistsInternal(name string) (Role, error) {
	if val, ok := p.dbHandle.roles[name]; ok {
		return val.getACopy(), nil
	}
	return Role{}, util.NewRecordNotFoundError(fmt.Sprintf("role %q does not exist", name))
}

// ipListEntryExistsInternal returns a copy of the stored IP list entry with
// the same key as entry. The caller must hold the lock.
func (p *MemoryProvider) ipListEntryExistsInternal(entry *IPListEntry) (IPListEntry, error) {
	if val, ok := p.dbHandle.ipListEntries[entry.getKey()]; ok {
		return val.getACopy(), nil
	}
	return IPListEntry{}, util.NewRecordNotFoundError(fmt.Sprintf("IP list entry %q does not exist", entry.getName()))
}

// addAdmin validates and stores a new admin and links it to its role and groups.
// If any link fails, the links already created are undone.
func (p *MemoryProvider) addAdmin(admin *Admin) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	err := admin.validate()
	if err != nil {
		return err
	}
	_, err = p.adminExistsInternal(admin.Username)
	if err == nil {
		return util.NewI18nError(
			fmt.Errorf("%w: admin %q already exists", ErrDuplicatedKey, admin.Username),
			util.I18nErrorDuplicatedUsername,
		)
	}
	admin.ID = p.getNextAdminID()
	admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	admin.LastLogin = 0
	if err := p.addAdminToRole(admin.Username, admin.Role); err != nil {
		return err
	}
	var mappedGroups []string
	for idx := range admin.Groups {
		if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil {
			// try to remove group mapping
			for _, g := range mappedGroups {
				p.removeAdminFromGroupMapping(admin.Username, g)
			}
			// the admin is not stored, so the role must not reference it either
			p.removeAdminFromRole(admin.Username, admin.Role)
			return err
		}
		mappedGroups = append(mappedGroups, admin.Groups[idx].Name)
	}
	p.dbHandle.admins[admin.Username] = admin.getACopy()
	p.dbHandle.adminsUsernames = append(p.dbHandle.adminsUsernames, admin.Username)
	sort.Strings(p.dbHandle.adminsUsernames)
	return nil
}

// updateAdmin validates and replaces an existing admin, moving its role and
// group links. ID, CreatedAt and LastLogin are kept from the stored admin.
// If linking fails, the previous links are restored as far as possible.
func (p *MemoryProvider) updateAdmin(admin *Admin) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	err := admin.validate()
	if err != nil {
		return err
	}
	a, err := p.adminExistsInternal(admin.Username)
	if err != nil {
		return err
	}
	p.removeAdminFromRole(a.Username, a.Role)
	if err := p.addAdminToRole(admin.Username, admin.Role); err != nil {
		// try to add old role
		if errRollback := p.addAdminToRole(a.Username, a.Role); errRollback != nil {
			providerLog(logger.LevelError, "unable to rollback old role %q for admin %q, error: %v",
				a.Role, a.Username, errRollback)
		}
		return err
	}
	for idx := range a.Groups {
		p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name)
	}
	for idx := range admin.Groups {
		if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil {
			// undo the new group mappings added before the failure
			for i := 0; i < idx; i++ {
				p.removeAdminFromGroupMapping(admin.Username, admin.Groups[i].Name)
			}
			// try to add old mapping
			for _, oldGroup := range a.Groups {
				if errRollback := p.addAdminToGroupMapping(a.Username, oldGroup.Name); errRollback != nil {
					providerLog(logger.LevelError, "unable to rollback old group mapping %q for admin %q, error: %v",
						oldGroup.Name, a.Username, errRollback)
				}
			}
			// the update is aborted, so restore the previous role too
			p.removeAdminFromRole(admin.Username, admin.Role)
			if errRollback := p.addAdminToRole(a.Username, a.Role); errRollback != nil {
				providerLog(logger.LevelError, "unable to rollback old role %q for admin %q, error: %v",
					a.Role, a.Username, errRollback)
			}
			return err
		}
	}
	admin.ID = a.ID
	admin.CreatedAt = a.CreatedAt
	admin.LastLogin = a.LastLogin
	admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	p.dbHandle.admins[admin.Username] = admin.getACopy()
	return nil
}

// deleteAdmin removes the admin, its role and group links, and its API keys.
func (p *MemoryProvider) deleteAdmin(admin Admin) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	a, err := p.adminExistsInternal(admin.Username)
	if err != nil {
		return err
	}
	p.removeAdminFromRole(a.Username, a.Role)
	for idx := range a.Groups {
		p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name)
	}
	delete(p.dbHandle.admins, admin.Username)
	// this could be more efficient
	p.dbHandle.adminsUsernames = make([]string, 0, len(p.dbHandle.admins))
	for username := range p.dbHandle.admins {
		p.dbHandle.adminsUsernames = append(p.dbHandle.adminsUsernames, username)
	}
	sort.Strings(p.dbHandle.adminsUsernames)
	p.deleteAPIKeysWithAdmin(admin.Username)
	return nil
}

// adminExists returns a copy of the named admin.
func (p *MemoryProvider) adminExists(username string) (Admin, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return Admin{}, errMemoryProviderClosed
	}
	return p.adminExistsInternal(username)
}

// adminExistsInternal returns a copy of the stored admin. The caller must hold the lock.
func (p *MemoryProvider) adminExistsInternal(username string) (Admin, error) {
	if val, ok := p.dbHandle.admins[username]; ok {
		return val.getACopy(), nil
	}
	return Admin{}, util.NewRecordNotFoundError(fmt.Sprintf("admin %q does not exist", username))
}

// dumpAdmins returns a copy of every stored admin (map iteration order).
func (p *MemoryProvider) dumpAdmins() ([]Admin, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	admins := make([]Admin, 0, len(p.dbHandle.admins))
	if p.dbHandle.isClosed {
		return admins, errMemoryProviderClosed
	}
	for _, admin := range p.dbHandle.admins {
		// copy so callers cannot alias the stored value's slices/pointers
		admins = append(admins, admin.getACopy())
	}
	return admins, nil
}

// getAdmins returns at most limit admins after skipping offset, sorted by
// username in the requested order, with confidential data hidden.
func (p *MemoryProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) {
	admins := make([]Admin, 0, limit)
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return admins, errMemoryProviderClosed
	}
	if limit <= 0 {
		return admins, nil
	}
	itNum := 0
	if order == OrderASC {
		for _, username := range p.dbHandle.adminsUsernames {
			itNum++
			if itNum <= offset {
				continue
			}
			a := p.dbHandle.admins[username]
			admin := a.getACopy()
			admin.HideConfidentialData()
			admins = append(admins, admin)
			if len(admins) >= limit {
				break
			}
		}
	} else {
		for i := len(p.dbHandle.adminsUsernames) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			username := p.dbHandle.adminsUsernames[i]
			a := p.dbHandle.admins[username]
			admin := a.getACopy()
			admin.HideConfidentialData()
			admins = append(admins, admin)
			if len(admins) >= limit {
				break
			}
		}
	}
	return admins, nil
}

// updateFolderQuota adds filesAdd/sizeAdd to the folder's used quota, or
// sets those values when reset is true.
func (p *MemoryProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(name)
	if err != nil {
		providerLog(logger.LevelError, "unable to update quota for folder %q error: %v", name, err)
		return err
	}
	if reset {
		folder.UsedQuotaSize = sizeAdd
		folder.UsedQuotaFiles = filesAdd
	} else {
		folder.UsedQuotaSize += sizeAdd
		folder.UsedQuotaFiles += filesAdd
	}
	folder.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
	p.dbHandle.vfolders[name] = folder
	return nil
}

// getGroups returns at most limit groups after skipping offset, sorted by
// name in the requested order, with virtual folders resolved.
func (p *MemoryProvider) getGroups(limit, offset int, order string, _ bool) ([]Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	if limit <= 0 {
		return nil, nil
	}
	groups := make([]Group, 0, limit)
	itNum := 0
	if order == OrderASC {
		for _, name := range p.dbHandle.groupnames {
			itNum++
			if itNum <= offset {
				continue
			}
			g := p.dbHandle.groups[name]
			group := g.getACopy()
			p.addVirtualFoldersToGroup(&group)
			group.PrepareForRendering()
			groups = append(groups, group)
			if len(groups) >= limit {
				break
			}
		}
	} else {
		for i := len(p.dbHandle.groupnames) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			name := p.dbHandle.groupnames[i]
			g := p.dbHandle.groups[name]
			group := g.getACopy()
			p.addVirtualFoldersToGroup(&group)
			group.PrepareForRendering()
			groups = append(groups, group)
			if len(groups) >= limit {
				break
			}
		}
	}
	return groups, nil
}

// getGroupsWithNames returns copies of the existing groups among names.
// Unknown names are ignored.
func (p *MemoryProvider) getGroupsWithNames(names []string) ([]Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	groups := make([]Group, 0, len(names))
	for _, name := range names {
		if val, ok := p.dbHandle.groups[name]; ok {
			group := val.getACopy()
			p.addVirtualFoldersToGroup(&group)
			groups = append(groups, group)
		}
	}
	return groups, nil
}

// getUsersInGroups returns the usernames referenced by the named groups.
// A user in several groups appears once per group.
func (p *MemoryProvider) getUsersInGroups(names []string) ([]string, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	var users []string
	for _, name := range names {
		if val, ok := p.dbHandle.groups[name]; ok {
			group := val.getACopy()
			users = append(users, group.Users...)
		}
	}
	return users, nil
}

// groupExists returns a copy of the named group with virtual folders resolved.
func (p *MemoryProvider) groupExists(name string) (Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return Group{}, errMemoryProviderClosed
	}
	group, err := p.groupExistsInternal(name)
	if err != nil {
		return group, err
	}
	p.addVirtualFoldersToGroup(&group)
	return group, nil
}

// addGroup validates and stores a new group and links it to its virtual folders.
func (p *MemoryProvider) addGroup(group *Group) error {
	if err := group.validate(); err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err := p.groupExistsInternal(group.Name)
	if err == nil {
		return util.NewI18nError(
			fmt.Errorf("%w: group %q already exists", ErrDuplicatedKey, group.Name),
			util.I18nErrorDuplicatedUsername,
		)
	}
	group.ID = p.getNextGroupID()
	group.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	// membership is tracked by the provider, never taken from the input
	group.Users = nil
	group.Admins = nil
	var mappedFolders []string
	for idx := range group.VirtualFolders {
		if err = p.addGroupToFolderMapping(group.Name, group.VirtualFolders[idx].Name); err != nil {
			// try to remove folder mapping
			for _, f := range mappedFolders {
				p.removeRelationFromFolderMapping(f, "", group.Name)
			}
			return err
		}
		mappedFolders = append(mappedFolders, group.VirtualFolders[idx].Name)
	}
	p.dbHandle.groups[group.Name] = group.getACopy()
	p.dbHandle.groupnames = append(p.dbHandle.groupnames, group.Name)
	sort.Strings(p.dbHandle.groupnames)
	return nil
}

// updateGroup validates and replaces an existing group, moving its folder
// links. ID, CreatedAt, Users and Admins are kept from the stored group.
func (p *MemoryProvider) updateGroup(group *Group) error {
	if err := group.validate(); err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	g, err := p.groupExistsInternal(group.Name)
	if err != nil {
		return err
	}
	for _, oldFolder := range g.VirtualFolders {
		p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name)
	}
	for idx := range group.VirtualFolders {
		if err = p.addGroupToFolderMapping(group.Name, group.VirtualFolders[idx].Name); err != nil {
			// undo the new folder mappings added before the failure
			for i := 0; i < idx; i++ {
				p.removeRelationFromFolderMapping(group.VirtualFolders[i].Name, "", group.Name)
			}
			// try to add old mapping
			for _, f := range g.VirtualFolders {
				if errRollback := p.addGroupToFolderMapping(group.Name, f.Name); errRollback != nil {
					providerLog(logger.LevelError, "unable to rollback old folder mapping %q for group %q, error: %v",
						f.Name, group.Name, errRollback)
				}
			}
			return err
		}
	}
	group.CreatedAt = g.CreatedAt
	group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	group.ID = g.ID
	group.Users = g.Users
	group.Admins = g.Admins
	p.dbHandle.groups[group.Name] = group.getACopy()
	return nil
}

// deleteGroup removes a group that no user references, together with its
// folder links and admin memberships.
func (p *MemoryProvider) deleteGroup(group Group) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	g, err := p.groupExistsInternal(group.Name)
	if err != nil {
		return err
	}
	if len(g.Users) > 0 {
		return util.NewValidationError(fmt.Sprintf("the group %q is referenced, it cannot be removed", group.Name))
	}
	for _, oldFolder := range g.VirtualFolders {
		p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name)
	}
	for _, a := range g.Admins {
		p.removeGroupFromAdminMapping(g.Name, a)
	}
	delete(p.dbHandle.groups, group.Name)
	// this could be more efficient
	p.dbHandle.groupnames = make([]string, 0, len(p.dbHandle.groups))
	for name := range p.dbHandle.groups {
		p.dbHandle.groupnames = append(p.dbHandle.groupnames, name)
	}
	sort.Strings(p.dbHandle.groupnames)
	return nil
}

// dumpGroups returns a copy of every group, sorted by name, with virtual folders resolved.
func (p *MemoryProvider) dumpGroups() ([]Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	groups := make([]Group, 0, len(p.dbHandle.groups))
	var err error
	if p.dbHandle.isClosed {
		return groups, errMemoryProviderClosed
	}
	for _, name := range p.dbHandle.groupnames {
		g := p.dbHandle.groups[name]
		group := g.getACopy()
		p.addVirtualFoldersToGroup(&group)
		groups = append(groups, group)
	}
	return groups, err
}

// getUsedFolderQuota returns the used files and size of the named folder.
func (p *MemoryProvider) getUsedFolderQuota(name string) (int, int64, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return 0, 0, errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(name)
	if err != nil {
		providerLog(logger.LevelError, "unable to get quota for folder %q error: %v", name, err)
		return 0, 0, err
	}
	return folder.UsedQuotaFiles, folder.UsedQuotaSize, err
}

// addVirtualFoldersToGroup fills each group folder with a copy of the stored
// base folder. Folders that no longer exist are dropped.
func (p *MemoryProvider) addVirtualFoldersToGroup(group *Group) {
	if len(group.VirtualFolders) > 0 {
		var folders []vfs.VirtualFolder
		for idx := range group.VirtualFolders {
			folder := &group.VirtualFolders[idx]
			baseFolder, err := p.folderExistsInternal(folder.Name)
			if err != nil {
				continue
			}
			folder.BaseVirtualFolder = baseFolder.GetACopy()
			folders = append(folders, *folder)
		}
		group.VirtualFolders = folders
	}
}

// addActionsToRule fills each rule action with the stored base action.
// Actions that no longer exist are dropped.
func (p *MemoryProvider) addActionsToRule(rule *EventRule) {
	var actions []EventAction
	for idx := range rule.Actions {
		action := &rule.Actions[idx]
		baseAction, err := p.actionExistsInternal(action.Name)
		if err != nil {
			continue
		}
		baseAction.Options.SetEmptySecretsIfNil()
		action.BaseEventAction = baseAction
		actions = append(actions, *action)
	}
	rule.Actions = actions
}

// addRuleToActionMapping records that ruleName uses actionName.
func (p *MemoryProvider) addRuleToActionMapping(ruleName, actionName string) error {
	a, err := p.actionExistsInternal(actionName)
	if err != nil {
		return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName))
	}
	if !slices.Contains(a.Rules, ruleName) {
		a.Rules = append(a.Rules, ruleName)
		p.dbHandle.actions[actionName] = a
	}
	return nil
}

// removeRuleFromActionMapping drops ruleName from the rules using actionName.
func (p *MemoryProvider) removeRuleFromActionMapping(ruleName, actionName string) {
	a, err := p.actionExistsInternal(actionName)
	if err != nil {
		providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName)
		return
	}
	if slices.Contains(a.Rules, ruleName) {
		var rules []string
		for _, r := range a.Rules {
			if r != ruleName {
				rules = append(rules, r)
			}
		}
		a.Rules = rules
		p.dbHandle.actions[actionName] = a
	}
}

// addAdminToGroupMapping adds username to the admins of groupname.
func (p *MemoryProvider) addAdminToGroupMapping(username, groupname string) error {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return err
	}
	if !slices.Contains(g.Admins, username) {
		g.Admins = append(g.Admins, username)
		p.dbHandle.groups[groupname] = g
	}
	return nil
}

// removeAdminFromGroupMapping removes username from the admins of groupname, if the group exists.
func (p *MemoryProvider) removeAdminFromGroupMapping(username, groupname string) {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return
	}
	var admins []string
	for _, a := range g.Admins {
		if a != username {
			admins = append(admins, a)
		}
	}
	g.Admins = admins
	p.dbHandle.groups[groupname] = g
}

// removeGroupFromAdminMapping removes groupname from the groups of admin username.
func (p *MemoryProvider) removeGroupFromAdminMapping(groupname, username string) {
	admin, err := p.adminExistsInternal(username)
	if err != nil {
		// the admin does not exist so there is no associated group
		return
	}
	var newGroups []AdminGroupMapping
	for _, g := range admin.Groups {
		if g.Name != groupname {
			newGroups = append(newGroups, g)
		}
	}
	admin.Groups = newGroups
	p.dbHandle.admins[admin.Username] = admin
}

// addUserToGroupMapping adds username to the users of groupname.
func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return err
	}
	if !slices.Contains(g.Users, username) {
		g.Users = append(g.Users, username)
		p.dbHandle.groups[groupname] = g
	}
	return nil
}

// removeUserFromGroupMapping removes username from the users of groupname, if the group exists.
func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return
	}
	var users []string
	for _, u := range g.Users {
		if u != username {
			users = append(users, u)
		}
	}
	g.Users = users
	p.dbHandle.groups[groupname] = g
}

// addAdminToRole adds username to the admins of role. An empty role is a no-op.
func (p *MemoryProvider) addAdminToRole(username, role string) error {
	if role == "" {
		return nil
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
	}
	if !slices.Contains(r.Admins, username) {
		r.Admins = append(r.Admins, username)
		p.dbHandle.roles[role] = r
	}
	return nil
}

// removeAdminFromRole removes username from the admins of role. An empty role is a no-op.
func (p *MemoryProvider) removeAdminFromRole(username, role string) {
	if role == "" {
		return
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		providerLog(logger.LevelWarn, "role %q does not exist, cannot remove admin %q", role, username)
		return
	}
	var admins []string
	for _, a := range r.Admins {
		if a != username {
			admins = append(admins, a)
		}
	}
	r.Admins = admins
	p.dbHandle.roles[role] = r
}

// addUserToRole adds username to the users of role. An empty role is a no-op.
func (p *MemoryProvider) addUserToRole(username, role string) error {
	if role == "" {
		return nil
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
	}
	if !slices.Contains(r.Users, username) {
		r.Users = append(r.Users, username)
		p.dbHandle.roles[role] = r
	}
	return nil
}

// removeUserFromRole removes username from the users of role. An empty role is a no-op.
func (p *MemoryProvider) removeUserFromRole(username, role string) {
	if role == "" {
		return
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		providerLog(logger.LevelWarn, "role %q does not exist, cannot remove user %q", role, username)
		return
	}
	var users []string
	for _, u := range r.Users {
		if u != username {
			users = append(users, u)
		}
	}
	r.Users = users
	p.dbHandle.roles[role] = r
}

// addUserToFolderMapping records that username uses foldername.
func (p *MemoryProvider) addUserToFolderMapping(username, foldername string) error {
	f, err := p.folderExistsInternal(foldername)
	if err != nil {
		return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
	}
	if !slices.Contains(f.Users, username) {
		f.Users = append(f.Users, username)
		p.dbHandle.vfolders[foldername] = f
	}
	return nil
}

// addGroupToFolderMapping records that group name uses foldername.
func (p *MemoryProvider) addGroupToFolderMapping(name, foldername string) error {
	f, err := p.folderExistsInternal(foldername)
	if err != nil {
		return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
	}
	if !slices.Contains(f.Groups, name) {
		f.Groups = append(f.Groups, name)
		p.dbHandle.vfolders[foldername] = f
	}
	return nil
}

// addVirtualFoldersToUser fills each user folder with a copy of the stored
// base folder. Folders that no longer exist are dropped.
func (p *MemoryProvider) addVirtualFoldersToUser(user *User) {
	if len(user.VirtualFolders) > 0 {
		var folders []vfs.VirtualFolder
		for idx := range user.VirtualFolders {
			folder := &user.VirtualFolders[idx]
			baseFolder, err := p.folderExistsInternal(folder.Name)
			if err != nil {
				continue
			}
			folder.BaseVirtualFolder = baseFolder.GetACopy()
			folders = append(folders, *folder)
		}
		user.VirtualFolders = folders
	}
}

// removeRelationFromFolderMapping removes username and/or groupname (non-empty
// values only) from the users/groups referencing folderName.
func (p *MemoryProvider) removeRelationFromFolderMapping(folderName, username, groupname string) {
	folder, err := p.folderExistsInternal(folderName)
	if err != nil {
		return
	}
	if username != "" {
		var usernames []string
		for _, user := range folder.Users {
			if user != username {
				usernames = append(usernames, user)
			}
		}
		folder.Users = usernames
	}
	if groupname != "" {
		var groups []string
		for _, group := range folder.Groups {
			if group != groupname {
				groups = append(groups, group)
			}
		}
		folder.Groups = groups
	}
	p.dbHandle.vfolders[folder.Name] = folder
}

// folderExistsInternal returns the stored folder value. The copy is shallow:
// its slices are shared with the stored value. The caller must hold the lock.
func (p *MemoryProvider) folderExistsInternal(name string) (vfs.BaseVirtualFolder, error) {
	if val, ok := p.dbHandle.vfolders[name]; ok {
		return val, nil
	}
	return vfs.BaseVirtualFolder{}, util.NewRecordNotFoundError(fmt.Sprintf("folder %q does not exist", name))
}

// getFolders returns at most limit folders after skipping offset, sorted by
// name in the requested order.
func (p *MemoryProvider) getFolders(limit, offset int, order string, _ bool) ([]vfs.BaseVirtualFolder, error) {
	folders := make([]vfs.BaseVirtualFolder, 0, limit)
	var err error
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return folders, errMemoryProviderClosed
	}
	if limit <= 0 {
		return folders, err
	}
	itNum := 0
	if order == OrderASC {
		for _, name := range p.dbHandle.vfoldersNames {
			itNum++
			if itNum <= offset {
				continue
			}
			f := p.dbHandle.vfolders[name]
			folder := f.GetACopy()
			folder.PrepareForRendering()
			folders = append(folders, folder)
			if len(folders) >= limit {
				break
			}
		}
	} else {
		for i := len(p.dbHandle.vfoldersNames) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			name := p.dbHandle.vfoldersNames[i]
			f := p.dbHandle.vfolders[name]
			folder := f.GetACopy()
			folder.PrepareForRendering()
			folders = append(folders, folder)
			if len(folders) >= limit {
				break
			}
		}
	}
	return folders, err
}

// getFolderByName returns a copy of the named folder.
func (p *MemoryProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return vfs.BaseVirtualFolder{}, errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(name)
	if err != nil {
		return vfs.BaseVirtualFolder{}, err
	}
	return folder.GetACopy(), nil
}

// addFolder validates and stores a new folder with no user/group relations.
func (p *MemoryProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
	err := ValidateFolder(folder)
	if err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err = p.folderExistsInternal(folder.Name)
	if err == nil {
		return util.NewI18nError(
			fmt.Errorf("%w: folder %q already exists", ErrDuplicatedKey, folder.Name),
			util.I18nErrorDuplicatedUsername,
		)
	}
	folder.ID = p.getNextFolderID()
	folder.Users = nil
	folder.Groups = nil
	p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
	p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, folder.Name)
	sort.Strings(p.dbHandle.vfoldersNames)
	return nil
}

// updateFolder validates and replaces an existing folder. ID, quota usage
// and relations are kept from the stored folder. The embedded copy inside
// each related user is refreshed too.
func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	err := ValidateFolder(folder)
	if err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	f, err := p.folderExistsInternal(folder.Name)
	if err != nil {
		return err
	}
	folder.ID = f.ID
	folder.LastQuotaUpdate = f.LastQuotaUpdate
	folder.UsedQuotaFiles = f.UsedQuotaFiles
	folder.UsedQuotaSize = f.UsedQuotaSize
	folder.Users = f.Users
	folder.Groups = f.Groups
	p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
	// now update the related users
	for _, username := range folder.Users {
		user, err := p.userExistsInternal(username)
		if err == nil {
			var folders []vfs.VirtualFolder
			for idx := range user.VirtualFolders {
				userFolder := &user.VirtualFolders[idx]
				if folder.Name == userFolder.Name {
					userFolder.BaseVirtualFolder = folder.GetACopy()
				}
				folders = append(folders, *userFolder)
			}
			user.VirtualFolders = folders
			p.dbHandle.users[user.Username] = user
		}
	}
	return nil
}

// deleteFolder removes the folder and detaches it from every related user and group.
func (p *MemoryProvider) deleteFolder(f vfs.BaseVirtualFolder) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(f.Name)
	if err != nil {
		return err
	}
	for _, username := range folder.Users {
		user, err := p.userExistsInternal(username)
		if err == nil {
			var folders []vfs.VirtualFolder
			for idx := range user.VirtualFolders {
				userFolder := &user.VirtualFolders[idx]
				if folder.Name != userFolder.Name {
					folders = append(folders, *userFolder)
				}
			}
			user.VirtualFolders = folders
			p.dbHandle.users[user.Username] = user
		}
	}
	for _, groupname := range folder.Groups {
		group, err := p.groupExistsInternal(groupname)
		if err == nil {
			var folders []vfs.VirtualFolder
			for idx := range group.VirtualFolders {
				groupFolder := &group.VirtualFolders[idx]
				if folder.Name != groupFolder.Name {
					folders = append(folders, *groupFolder)
				}
			}
			group.VirtualFolders = folders
			p.dbHandle.groups[group.Name] = group
		}
	}
	delete(p.dbHandle.vfolders, folder.Name)
	p.dbHandle.vfoldersNames = []string{}
	for name := range p.dbHandle.vfolders {
		p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, name)
	}
	sort.Strings(p.dbHandle.vfoldersNames)
	return nil
}

// apiKeyExistsInternal returns a copy of the stored API key. The caller must hold the lock.
func (p *MemoryProvider) apiKeyExistsInternal(keyID string) (APIKey, error) {
	if val, ok := p.dbHandle.apiKeys[keyID]; ok {
		return val.getACopy(), nil
	}
	return APIKey{}, util.NewRecordNotFoundError(fmt.Sprintf("API key %q does not exist", keyID))
}

// apiKeyExists returns a copy of the API key with the given ID.
func (p *MemoryProvider) apiKeyExists(keyID string) (APIKey, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return APIKey{}, errMemoryProviderClosed
	}
	return p.apiKeyExistsInternal(keyID)
}

// addAPIKey validates and stores a new API key. The referenced user/admin
// must exist.
func (p *MemoryProvider) addAPIKey(apiKey *APIKey) error {
	err := apiKey.validate()
	if err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err = p.apiKeyExistsInternal(apiKey.KeyID)
	if err == nil {
		return fmt.Errorf("API key %q already exists", apiKey.KeyID)
	}
	if apiKey.User != "" {
		if _, err := p.userExistsInternal(apiKey.User); err != nil {
			return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User)
		}
	}
	if apiKey.Admin != "" {
		if _, err := p.adminExistsInternal(apiKey.Admin); err != nil {
			return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin)
		}
	}
	apiKey.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	apiKey.LastUseAt = 0
	p.dbHandle.apiKeys[apiKey.KeyID] = apiKey.getACopy()
	p.dbHandle.apiKeysIDs = append(p.dbHandle.apiKeysIDs, apiKey.KeyID)
	sort.Strings(p.dbHandle.apiKeysIDs)
	return nil
}

// updateAPIKey validates and replaces an existing API key. ID, KeyID, Key,
// CreatedAt and LastUseAt are kept from the stored key.
func (p *MemoryProvider) updateAPIKey(apiKey *APIKey) error {
	err := apiKey.validate()
	if err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	k, err := p.apiKeyExistsInternal(apiKey.KeyID)
	if err != nil {
		return err
	}
	if apiKey.User != "" {
		if _, err := p.userExistsInternal(apiKey.User); err != nil {
			return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User)
		}
	}
	if apiKey.Admin != "" {
		if _, err := p.adminExistsInternal(apiKey.Admin); err != nil {
			return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin)
		}
	}
	apiKey.ID = k.ID
	apiKey.KeyID = k.KeyID
	apiKey.Key = k.Key
	apiKey.CreatedAt = k.CreatedAt
	apiKey.LastUseAt = k.LastUseAt
	apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	p.dbHandle.apiKeys[apiKey.KeyID] = apiKey.getACopy()
	return nil
}

// deleteAPIKey removes the API key with apiKey.KeyID.
func (p *MemoryProvider) deleteAPIKey(apiKey APIKey) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err := p.apiKeyExistsInternal(apiKey.KeyID)
	if err != nil {
		return err
	}
	delete(p.dbHandle.apiKeys, apiKey.KeyID)
	p.updateAPIKeysOrdering()
	return nil
}

// getAPIKeys returns at most limit API keys after skipping offset, sorted by
// key ID (ascending unless order is OrderDESC), with confidential data hidden.
func (p *MemoryProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) {
	apiKeys := make([]APIKey, 0, limit)
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return apiKeys, errMemoryProviderClosed
	}
	if limit <= 0 {
		return apiKeys, nil
	}
	itNum := 0
	if order == OrderDESC {
		for i := len(p.dbHandle.apiKeysIDs) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			keyID := p.dbHandle.apiKeysIDs[i]
			k := p.dbHandle.apiKeys[keyID]
			apiKey := k.getACopy()
			apiKey.HideConfidentialData()
			apiKeys = append(apiKeys, apiKey)
			if len(apiKeys) >= limit {
				break
			}
		}
	} else {
		for _, keyID := range p.dbHandle.apiKeysIDs {
			itNum++
			if itNum <= offset {
				continue
			}
			k := p.dbHandle.apiKeys[keyID]
			apiKey := k.getACopy()
			apiKey.HideConfidentialData()
			apiKeys = append(apiKeys, apiKey)
			if len(apiKeys) >= limit {
				break
			}
		}
	}
	return apiKeys, nil
}

// dumpAPIKeys returns a copy of every stored API key (map iteration order).
func (p *MemoryProvider) dumpAPIKeys() ([]APIKey, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	apiKeys := make([]APIKey, 0, len(p.dbHandle.apiKeys))
	if p.dbHandle.isClosed {
		return apiKeys, errMemoryProviderClosed
	}
	for _, k := range p.dbHandle.apiKeys {
		// copy so callers cannot alias the stored value
		apiKeys = append(apiKeys, k.getACopy())
	}
	return apiKeys, nil
}

// deleteAPIKeysWithUser removes every API key bound to username. The caller must hold the lock.
func (p *MemoryProvider) deleteAPIKeysWithUser(username string) {
	found := false
	// deleting map entries while ranging over the map is allowed in Go
	for k, v := range p.dbHandle.apiKeys {
		if v.User == username {
			delete(p.dbHandle.apiKeys, k)
			found = true
		}
	}
	if found {
		p.updateAPIKeysOrdering()
	}
}

// deleteAPIKeysWithAdmin removes every API key bound to admin username. The caller must hold the lock.
func (p *MemoryProvider) deleteAPIKeysWithAdmin(username string) {
	found := false
	for k, v := range p.dbHandle.apiKeys {
		if v.Admin == username {
			delete(p.dbHandle.apiKeys, k)
			found = true
		}
	}
	if found {
		p.updateAPIKeysOrdering()
	}
}

// deleteSharesWithUser removes every share owned by username. The caller must hold the lock.
func (p *MemoryProvider) deleteSharesWithUser(username string) {
	found := false
	for k, v := range p.dbHandle.shares {
		if v.Username == username {
			delete(p.dbHandle.shares, k)
			found = true
		}
	}
	if found {
		p.updateSharesOrdering()
	}
}

// updateAPIKeysOrdering rebuilds the sorted key ID index from the API keys map.
func (p *MemoryProvider) updateAPIKeysOrdering() {
	// this could be more efficient
	p.dbHandle.apiKeysIDs = make([]string, 0, len(p.dbHandle.apiKeys))
	for keyID := range p.dbHandle.apiKeys {
		p.dbHandle.apiKeysIDs = append(p.dbHandle.apiKeysIDs, keyID)
	}
	sort.Strings(p.dbHandle.apiKeysIDs)
}

// updateSharesOrdering rebuilds the sorted share ID index from the shares map.
func (p *MemoryProvider) updateSharesOrdering() {
	// this could be more efficient
	p.dbHandle.sharesIDs = make([]string, 0, len(p.dbHandle.shares))
	for shareID := range p.dbHandle.shares {
		p.dbHandle.sharesIDs = append(p.dbHandle.sharesIDs, shareID)
	}
	sort.Strings(p.dbHandle.sharesIDs)
}

// shareExistsInternal returns a copy of the stored share. If username is not
// empty, a share owned by someone else is reported as not found. The caller
// must hold the lock.
func (p *MemoryProvider) shareExistsInternal(shareID, username string) (Share, error) {
	if val, ok := p.dbHandle.shares[shareID]; ok {
		if username != "" && val.Username != username {
			return Share{}, util.NewRecordNotFoundError(fmt.Sprintf("Share %q does not exist", shareID))
		}
		return val.getACopy(), nil
	}
	return Share{}, util.NewRecordNotFoundError(fmt.Sprintf("Share %q does not exist", shareID))
}

// shareExists returns a copy of the share, optionally restricted to owner username.
func (p *MemoryProvider) shareExists(shareID, username string) (Share, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return Share{}, errMemoryProviderClosed
	}
	return p.shareExistsInternal(shareID, username)
}

// addShare validates and stores a new share for an existing user.
// Timestamps and usage counters are reset unless the share is being restored.
func (p *MemoryProvider) addShare(share *Share) error {
	err := share.validate()
	if err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	// ShareID is the map key, so check for duplicates whatever the owner.
	// Otherwise another user's share would be silently overwritten and the
	// ID appended twice to sharesIDs.
	_, err = p.shareExistsInternal(share.ShareID, "")
	if err == nil {
		return fmt.Errorf("share %q already exists", share.ShareID)
	}
	if _, err := p.userExistsInternal(share.Username); err != nil {
		return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username))
	}
	if !share.IsRestore {
		share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		share.UpdatedAt = share.CreatedAt
		share.LastUseAt = 0
		share.UsedTokens = 0
	}
	if share.CreatedAt == 0 {
		share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	}
	if share.UpdatedAt == 0 {
		share.UpdatedAt = share.CreatedAt
	}
	p.dbHandle.shares[share.ShareID] = share.getACopy()
	p.dbHandle.sharesIDs = append(p.dbHandle.sharesIDs, share.ShareID)
	sort.Strings(p.dbHandle.sharesIDs)
	return nil
}

// updateShare validates and replaces a share owned by share.Username.
// Usage counters and CreatedAt are kept unless the share is being restored.
func (p *MemoryProvider) updateShare(share *Share) error {
	err := share.validate()
	if err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	s, err := p.shareExistsInternal(share.ShareID, share.Username)
	if err != nil {
		return err
	}
	if _, err := p.userExistsInternal(share.Username); err != nil {
		return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username))
	}
	share.ID = s.ID
	share.ShareID = s.ShareID
	if !share.IsRestore {
		share.UsedTokens = s.UsedTokens
		share.CreatedAt = s.CreatedAt
		share.LastUseAt = s.LastUseAt
		share.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	}
	if share.CreatedAt == 0 {
		share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	}
	if share.UpdatedAt == 0 {
		share.UpdatedAt = share.CreatedAt
	}
	p.dbHandle.shares[share.ShareID] = share.getACopy()
	return nil
}

// deleteShare removes a share owned by share.Username.
func (p *MemoryProvider) deleteShare(share Share) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err := p.shareExistsInternal(share.ShareID, share.Username)
	if err != nil {
		return err
	}
	delete(p.dbHandle.shares, share.ShareID)
	p.updateSharesOrdering()
	return nil
}

// getShares returns at most limit shares owned by username after skipping
// the first offset matching ones, sorted by share ID (ascending unless order
// is OrderDESC), with confidential data hidden.
func (p *MemoryProvider) getShares(limit int, offset int, order, username string) ([]Share, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return []Share{}, errMemoryProviderClosed
	}
	if limit <= 0 {
		return []Share{}, nil
	}
	shares := make([]Share, 0, limit)
	itNum := 0
	if order == OrderDESC {
		for i := len(p.dbHandle.sharesIDs) - 1; i >= 0; i-- {
			shareID := p.dbHandle.sharesIDs[i]
			s := p.dbHandle.shares[shareID]
			if s.Username != username {
				continue
			}
			itNum++
			if itNum <= offset {
				continue
			}
			share := s.getACopy()
			share.HideConfidentialData()
			shares = append(shares, share)
			if len(shares) >= limit {
				break
			}
		}
	} else {
		for _, shareID := range p.dbHandle.sharesIDs {
			s := p.dbHandle.shares[shareID]
			if s.Username != username {
				continue
			}
			itNum++
			if itNum <= offset {
				continue
			}
			share := s.getACopy()
			share.HideConfidentialData()
			shares = append(shares, share)
			if len(shares) >= limit {
				break
			}
		}
	}
	return shares, nil
}

// dumpShares returns a copy of every stored share (map iteration order).
func (p *MemoryProvider) dumpShares() ([]Share, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	shares := make([]Share, 0, len(p.dbHandle.shares))
	if p.dbHandle.isClosed {
		return shares, errMemoryProviderClosed
	}
	for _, s := range p.dbHandle.shares {
		// copy so callers cannot alias the stored value
		shares = append(shares, s.getACopy())
	}
	return shares, nil
}

// updateShareLastUse sets LastUseAt to now and adds numTokens to UsedTokens.
func (p *MemoryProvider) updateShareLastUse(shareID string, numTokens int) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	share, err := p.shareExistsInternal(shareID, "")
	if err != nil {
		return err
	}
	share.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now())
	share.UsedTokens += numTokens
	p.dbHandle.shares[share.ShareID] = share
	return nil
}

// getDefenderHosts is not supported by the memory provider.
func (p *MemoryProvider) getDefenderHosts(_ int64, _ int) ([]DefenderEntry, error) {
	return nil, ErrNotImplemented
}

// getDefenderHostByIP is not supported by the memory provider.
func (p *MemoryProvider) getDefenderHostByIP(_ string, _ int64) (DefenderEntry, error) {
	return DefenderEntry{}, ErrNotImplemented
}

// isDefenderHostBanned is not supported by the memory provider.
func (p *MemoryProvider) isDefenderHostBanned(_ string) (DefenderEntry, error) {
	return DefenderEntry{}, ErrNotImplemented
}

// updateDefenderBanTime is not supported by the memory provider.
func (p *MemoryProvider) updateDefenderBanTime(_ string, _ int) error {
	return ErrNotImplemented
}

// deleteDefenderHost is not supported by the memory provider.
func (p *MemoryProvider) deleteDefenderHost(_ string) error {
	return ErrNotImplemented
}

// addDefenderEvent is not supported by the memory provider.
func (p *MemoryProvider) addDefenderEvent(_ string, _ int) error {
	return ErrNotImplemented
}

// setDefenderBanTime is not supported by the memory provider.
func (p *MemoryProvider) setDefenderBanTime(_ string, _ int64) error {
	return ErrNotImplemented
}

// cleanupDefender is not supported by the memory provider.
func (p *MemoryProvider) cleanupDefender(_ int64) error {
	return ErrNotImplemented
}
func (p *MemoryProvider) addActiveTransfer(_ ActiveTransfer) error { return ErrNotImplemented } func (p *MemoryProvider) updateActiveTransferSizes(_, _, _ int64, _ string) error { return ErrNotImplemented } func (p *MemoryProvider) removeActiveTransfer(_ int64, _ string) error { return ErrNotImplemented } func (p *MemoryProvider) cleanupActiveTransfers(_ time.Time) error { return ErrNotImplemented } func (p *MemoryProvider) getActiveTransfers(_ time.Time) ([]ActiveTransfer, error) { return nil, ErrNotImplemented } func (p *MemoryProvider) addSharedSession(_ Session) error { return ErrNotImplemented } func (p *MemoryProvider) deleteSharedSession(_ string, _ SessionType) error { return ErrNotImplemented } func (p *MemoryProvider) getSharedSession(_ string, _ SessionType) (Session, error) { return Session{}, ErrNotImplemented } func (p *MemoryProvider) cleanupSharedSessions(_ SessionType, _ int64) error { return ErrNotImplemented } func (p *MemoryProvider) getEventActions(limit, offset int, order string, _ bool) ([]BaseEventAction, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } if limit <= 0 { return nil, nil } actions := make([]BaseEventAction, 0, limit) itNum := 0 if order == OrderASC { for _, name := range p.dbHandle.actionsNames { itNum++ if itNum <= offset { continue } a := p.dbHandle.actions[name] action := a.getACopy() action.PrepareForRendering() actions = append(actions, action) if len(actions) >= limit { break } } } else { for i := len(p.dbHandle.actionsNames) - 1; i >= 0; i-- { itNum++ if itNum <= offset { continue } name := p.dbHandle.actionsNames[i] a := p.dbHandle.actions[name] action := a.getACopy() action.PrepareForRendering() actions = append(actions, action) if len(actions) >= limit { break } } } return actions, nil } func (p *MemoryProvider) dumpEventActions() ([]BaseEventAction, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, 
errMemoryProviderClosed } actions := make([]BaseEventAction, 0, len(p.dbHandle.actions)) for _, name := range p.dbHandle.actionsNames { a := p.dbHandle.actions[name] action := a.getACopy() actions = append(actions, action) } return actions, nil } func (p *MemoryProvider) eventActionExists(name string) (BaseEventAction, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return BaseEventAction{}, errMemoryProviderClosed } return p.actionExistsInternal(name) } func (p *MemoryProvider) addEventAction(action *BaseEventAction) error { err := action.validate() if err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } _, err = p.actionExistsInternal(action.Name) if err == nil { return util.NewI18nError( fmt.Errorf("%w: event action %q already exists", ErrDuplicatedKey, action.Name), util.I18nErrorDuplicatedName, ) } action.ID = p.getNextActionID() action.Rules = nil p.dbHandle.actions[action.Name] = action.getACopy() p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, action.Name) sort.Strings(p.dbHandle.actionsNames) return nil } func (p *MemoryProvider) updateEventAction(action *BaseEventAction) error { err := action.validate() if err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldAction, err := p.actionExistsInternal(action.Name) if err != nil { return fmt.Errorf("event action %s does not exist", action.Name) } action.ID = oldAction.ID action.Name = oldAction.Name action.Rules = nil if len(oldAction.Rules) > 0 { var relatedRules []string for _, ruleName := range oldAction.Rules { rule, err := p.ruleExistsInternal(ruleName) if err == nil { relatedRules = append(relatedRules, ruleName) rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.rules[ruleName] = rule setLastRuleUpdate() } } action.Rules = relatedRules } p.dbHandle.actions[action.Name] = action.getACopy() return nil 
} func (p *MemoryProvider) deleteEventAction(action BaseEventAction) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldAction, err := p.actionExistsInternal(action.Name) if err != nil { return fmt.Errorf("event action %s does not exist", action.Name) } if len(oldAction.Rules) > 0 { return util.NewValidationError(fmt.Sprintf("action %s is referenced, it cannot be removed", oldAction.Name)) } delete(p.dbHandle.actions, action.Name) // this could be more efficient p.dbHandle.actionsNames = make([]string, 0, len(p.dbHandle.actions)) for name := range p.dbHandle.actions { p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, name) } sort.Strings(p.dbHandle.actionsNames) return nil } func (p *MemoryProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } if limit <= 0 { return nil, nil } itNum := 0 rules := make([]EventRule, 0, limit) if order == OrderASC { for _, name := range p.dbHandle.rulesNames { itNum++ if itNum <= offset { continue } r := p.dbHandle.rules[name] rule := r.getACopy() p.addActionsToRule(&rule) rule.PrepareForRendering() rules = append(rules, rule) if len(rules) >= limit { break } } } else { for i := len(p.dbHandle.rulesNames) - 1; i >= 0; i-- { itNum++ if itNum <= offset { continue } name := p.dbHandle.rulesNames[i] r := p.dbHandle.rules[name] rule := r.getACopy() p.addActionsToRule(&rule) rule.PrepareForRendering() rules = append(rules, rule) if len(rules) >= limit { break } } } return rules, nil } func (p *MemoryProvider) dumpEventRules() ([]EventRule, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } rules := make([]EventRule, 0, len(p.dbHandle.rules)) for _, name := range p.dbHandle.rulesNames { r := p.dbHandle.rules[name] rule := r.getACopy() p.addActionsToRule(&rule) rules = 
append(rules, rule) } return rules, nil } func (p *MemoryProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { if getLastRuleUpdate() < after { return nil, nil } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } rules := make([]EventRule, 0, 10) for _, name := range p.dbHandle.rulesNames { r := p.dbHandle.rules[name] if r.UpdatedAt < after { continue } rule := r.getACopy() p.addActionsToRule(&rule) rules = append(rules, rule) } return rules, nil } func (p *MemoryProvider) eventRuleExists(name string) (EventRule, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return EventRule{}, errMemoryProviderClosed } rule, err := p.ruleExistsInternal(name) if err != nil { return rule, err } p.addActionsToRule(&rule) return rule, nil } func (p *MemoryProvider) addEventRule(rule *EventRule) error { if err := rule.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } _, err := p.ruleExistsInternal(rule.Name) if err == nil { return util.NewI18nError( fmt.Errorf("%w: event rule %q already exists", ErrDuplicatedKey, rule.Name), util.I18nErrorDuplicatedName, ) } rule.ID = p.getNextRuleID() rule.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) rule.UpdatedAt = rule.CreatedAt var mappedActions []string for idx := range rule.Actions { if err := p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil { // try to remove action mapping for _, a := range mappedActions { p.removeRuleFromActionMapping(rule.Name, a) } return err } mappedActions = append(mappedActions, rule.Actions[idx].Name) } sort.Slice(rule.Actions, func(i, j int) bool { return rule.Actions[i].Order < rule.Actions[j].Order }) p.dbHandle.rules[rule.Name] = rule.getACopy() p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, rule.Name) sort.Strings(p.dbHandle.rulesNames) setLastRuleUpdate() return nil } func (p 
*MemoryProvider) updateEventRule(rule *EventRule) error { if err := rule.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldRule, err := p.ruleExistsInternal(rule.Name) if err != nil { return err } for idx := range oldRule.Actions { p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name) } for idx := range rule.Actions { if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil { // try to add old mapping for _, oldAction := range oldRule.Actions { if errRollback := p.addRuleToActionMapping(oldRule.Name, oldAction.Name); errRollback != nil { providerLog(logger.LevelError, "unable to rollback old action mapping %q for rule %q, error: %v", oldAction.Name, oldRule.Name, errRollback) } } return err } } rule.ID = oldRule.ID rule.CreatedAt = oldRule.CreatedAt rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) sort.Slice(rule.Actions, func(i, j int) bool { return rule.Actions[i].Order < rule.Actions[j].Order }) p.dbHandle.rules[rule.Name] = rule.getACopy() setLastRuleUpdate() return nil } func (p *MemoryProvider) deleteEventRule(rule EventRule, _ bool) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldRule, err := p.ruleExistsInternal(rule.Name) if err != nil { return err } if len(oldRule.Actions) > 0 { for idx := range oldRule.Actions { p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name) } } delete(p.dbHandle.rules, rule.Name) p.dbHandle.rulesNames = make([]string, 0, len(p.dbHandle.rules)) for name := range p.dbHandle.rules { p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, name) } sort.Strings(p.dbHandle.rulesNames) setLastRuleUpdate() return nil } func (*MemoryProvider) getTaskByName(_ string) (Task, error) { return Task{}, ErrNotImplemented } func (*MemoryProvider) addTask(_ string) error { return ErrNotImplemented } func (*MemoryProvider) updateTask(_ 
string, _ int64) error { return ErrNotImplemented } func (*MemoryProvider) updateTaskTimestamp(_ string) error { return ErrNotImplemented } func (*MemoryProvider) addNode() error { return ErrNotImplemented } func (*MemoryProvider) getNodeByName(_ string) (Node, error) { return Node{}, ErrNotImplemented } func (*MemoryProvider) getNodes() ([]Node, error) { return nil, ErrNotImplemented } func (*MemoryProvider) updateNodeTimestamp() error { return ErrNotImplemented } func (*MemoryProvider) cleanupNodes() error { return ErrNotImplemented } func (p *MemoryProvider) roleExists(name string) (Role, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return Role{}, errMemoryProviderClosed } role, err := p.roleExistsInternal(name) if err != nil { return role, err } return role, nil } func (p *MemoryProvider) addRole(role *Role) error { if err := role.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } _, err := p.roleExistsInternal(role.Name) if err == nil { return util.NewI18nError( fmt.Errorf("%w: role %q already exists", ErrDuplicatedKey, role.Name), util.I18nErrorDuplicatedName, ) } role.ID = p.getNextRoleID() role.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) role.Users = nil role.Admins = nil p.dbHandle.roles[role.Name] = role.getACopy() p.dbHandle.roleNames = append(p.dbHandle.roleNames, role.Name) sort.Strings(p.dbHandle.roleNames) return nil } func (p *MemoryProvider) updateRole(role *Role) error { if err := role.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldRole, err := p.roleExistsInternal(role.Name) if err != nil { return err } role.ID = oldRole.ID role.CreatedAt = oldRole.CreatedAt role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) role.Users = oldRole.Users role.Admins = oldRole.Admins 
p.dbHandle.roles[role.Name] = role.getACopy() return nil } func (p *MemoryProvider) deleteRole(role Role) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldRole, err := p.roleExistsInternal(role.Name) if err != nil { return err } if len(oldRole.Admins) > 0 { return util.NewValidationError(fmt.Sprintf("the role %q is referenced, it cannot be removed", oldRole.Name)) } for _, username := range oldRole.Users { user, err := p.userExistsInternal(username) if err != nil { continue } if user.Role == role.Name { user.Role = "" p.dbHandle.users[username] = user } else { providerLog(logger.LevelError, "user %q does not have the expected role %q, actual %q", username, role.Name, user.Role) } } delete(p.dbHandle.roles, role.Name) p.dbHandle.roleNames = make([]string, 0, len(p.dbHandle.roles)) for name := range p.dbHandle.roles { p.dbHandle.roleNames = append(p.dbHandle.roleNames, name) } sort.Strings(p.dbHandle.roleNames) return nil } func (p *MemoryProvider) getRoles(limit int, offset int, order string, _ bool) ([]Role, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } if limit <= 0 { return nil, nil } roles := make([]Role, 0, 10) itNum := 0 if order == OrderASC { for _, name := range p.dbHandle.roleNames { itNum++ if itNum <= offset { continue } r := p.dbHandle.roles[name] role := r.getACopy() roles = append(roles, role) if len(roles) >= limit { break } } } else { for i := len(p.dbHandle.roleNames) - 1; i >= 0; i-- { itNum++ if itNum <= offset { continue } name := p.dbHandle.roleNames[i] r := p.dbHandle.roles[name] role := r.getACopy() roles = append(roles, role) if len(roles) >= limit { break } } } return roles, nil } func (p *MemoryProvider) dumpRoles() ([]Role, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } roles := make([]Role, 0, len(p.dbHandle.roles)) for _, name := range 
p.dbHandle.roleNames { r := p.dbHandle.roles[name] roles = append(roles, r.getACopy()) } return roles, nil } func (p *MemoryProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return IPListEntry{}, errMemoryProviderClosed } entry, err := p.ipListEntryExistsInternal(&IPListEntry{IPOrNet: ipOrNet, Type: listType}) if err != nil { return entry, err } entry.PrepareForRendering() return entry, nil } func (p *MemoryProvider) addIPListEntry(entry *IPListEntry) error { if err := entry.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } _, err := p.ipListEntryExistsInternal(entry) if err == nil { return util.NewI18nError( fmt.Errorf("%w: entry %q already exists", ErrDuplicatedKey, entry.IPOrNet), util.I18nErrorDuplicatedIPNet, ) } entry.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.ipListEntries[entry.getKey()] = entry.getACopy() p.dbHandle.ipListEntriesKeys = append(p.dbHandle.ipListEntriesKeys, entry.getKey()) sort.Strings(p.dbHandle.ipListEntriesKeys) return nil } func (p *MemoryProvider) updateIPListEntry(entry *IPListEntry) error { if err := entry.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } oldEntry, err := p.ipListEntryExistsInternal(entry) if err != nil { return err } entry.CreatedAt = oldEntry.CreatedAt entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.ipListEntries[entry.getKey()] = entry.getACopy() return nil } func (p *MemoryProvider) deleteIPListEntry(entry IPListEntry, _ bool) error { if err := entry.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } _, err := p.ipListEntryExistsInternal(&entry) if 
err != nil { return err } delete(p.dbHandle.ipListEntries, entry.getKey()) p.dbHandle.ipListEntriesKeys = make([]string, 0, len(p.dbHandle.ipListEntries)) for k := range p.dbHandle.ipListEntries { p.dbHandle.ipListEntriesKeys = append(p.dbHandle.ipListEntriesKeys, k) } sort.Strings(p.dbHandle.ipListEntriesKeys) return nil } func (p *MemoryProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } entries := make([]IPListEntry, 0, 15) if order == OrderASC { for _, k := range p.dbHandle.ipListEntriesKeys { e := p.dbHandle.ipListEntries[k] if e.Type == listType && e.satisfySearchConstraints(filter, from, order) { entry := e.getACopy() entry.PrepareForRendering() entries = append(entries, entry) if limit > 0 && len(entries) >= limit { break } } } } else { for i := len(p.dbHandle.ipListEntriesKeys) - 1; i >= 0; i-- { e := p.dbHandle.ipListEntries[p.dbHandle.ipListEntriesKeys[i]] if e.Type == listType && e.satisfySearchConstraints(filter, from, order) { entry := e.getACopy() entry.PrepareForRendering() entries = append(entries, entry) if limit > 0 && len(entries) >= limit { break } } } } return entries, nil } func (p *MemoryProvider) getRecentlyUpdatedIPListEntries(_ int64) ([]IPListEntry, error) { return nil, ErrNotImplemented } func (p *MemoryProvider) dumpIPListEntries() ([]IPListEntry, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } if count := len(p.dbHandle.ipListEntriesKeys); count > ipListMemoryLimit { providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count) return nil, nil } entries := make([]IPListEntry, 0, len(p.dbHandle.ipListEntries)) for _, k := range p.dbHandle.ipListEntriesKeys { e := p.dbHandle.ipListEntries[k] entry := e.getACopy() entry.PrepareForRendering() entries = append(entries, entry) } 
return entries, nil } func (p *MemoryProvider) countIPListEntries(listType IPListType) (int64, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return 0, errMemoryProviderClosed } if listType == 0 { return int64(len(p.dbHandle.ipListEntriesKeys)), nil } var count int64 for _, k := range p.dbHandle.ipListEntriesKeys { e := p.dbHandle.ipListEntries[k] if e.Type == listType { count++ } } return count, nil } func (p *MemoryProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return nil, errMemoryProviderClosed } entries := make([]IPListEntry, 0, 3) ipAddr, err := netip.ParseAddr(ip) if err != nil { return entries, fmt.Errorf("invalid ip address %s", ip) } var netType int var ipBytes []byte if ipAddr.Is4() || ipAddr.Is4In6() { netType = ipTypeV4 as4 := ipAddr.As4() ipBytes = as4[:] } else { netType = ipTypeV6 as16 := ipAddr.As16() ipBytes = as16[:] } for _, k := range p.dbHandle.ipListEntriesKeys { e := p.dbHandle.ipListEntries[k] if e.Type == listType && e.IPType == netType && bytes.Compare(ipBytes, e.First) >= 0 && bytes.Compare(ipBytes, e.Last) <= 0 { entry := e.getACopy() entry.PrepareForRendering() entries = append(entries, entry) } } return entries, nil } func (p *MemoryProvider) getConfigs() (Configs, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return Configs{}, errMemoryProviderClosed } return p.dbHandle.configs.getACopy(), nil } func (p *MemoryProvider) setConfigs(configs *Configs) error { if err := configs.validate(); err != nil { return err } p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } p.dbHandle.configs = configs.getACopy() return nil } func (p *MemoryProvider) setFirstDownloadTimestamp(username string) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } user, err := 
p.userExistsInternal(username) if err != nil { return err } if user.FirstDownload > 0 { return util.NewGenericError(fmt.Sprintf("first download already set to %s", util.GetTimeFromMsecSinceEpoch(user.FirstDownload))) } user.FirstDownload = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.users[user.Username] = user return nil } func (p *MemoryProvider) setFirstUploadTimestamp(username string) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { return errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { return err } if user.FirstUpload > 0 { return util.NewGenericError(fmt.Sprintf("first upload already set to %s", util.GetTimeFromMsecSinceEpoch(user.FirstUpload))) } user.FirstUpload = util.GetTimeAsMsSinceEpoch(time.Now()) p.dbHandle.users[user.Username] = user return nil } func (p *MemoryProvider) getNextID() int64 { nextID := int64(1) for _, v := range p.dbHandle.users { if v.ID >= nextID { nextID = v.ID + 1 } } return nextID } func (p *MemoryProvider) getNextFolderID() int64 { nextID := int64(1) for _, v := range p.dbHandle.vfolders { if v.ID >= nextID { nextID = v.ID + 1 } } return nextID } func (p *MemoryProvider) getNextAdminID() int64 { nextID := int64(1) for _, a := range p.dbHandle.admins { if a.ID >= nextID { nextID = a.ID + 1 } } return nextID } func (p *MemoryProvider) getNextGroupID() int64 { nextID := int64(1) for _, g := range p.dbHandle.groups { if g.ID >= nextID { nextID = g.ID + 1 } } return nextID } func (p *MemoryProvider) getNextActionID() int64 { nextID := int64(1) for _, a := range p.dbHandle.actions { if a.ID >= nextID { nextID = a.ID + 1 } } return nextID } func (p *MemoryProvider) getNextRuleID() int64 { nextID := int64(1) for _, r := range p.dbHandle.rules { if r.ID >= nextID { nextID = r.ID + 1 } } return nextID } func (p *MemoryProvider) getNextRoleID() int64 { nextID := int64(1) for _, r := range p.dbHandle.roles { if r.ID >= nextID { nextID = r.ID + 1 } } return nextID } func 
(p *MemoryProvider) clear() { p.dbHandle.Lock() defer p.dbHandle.Unlock() p.dbHandle.usernames = []string{} p.dbHandle.users = make(map[string]User) p.dbHandle.groupnames = []string{} p.dbHandle.groups = map[string]Group{} p.dbHandle.vfoldersNames = []string{} p.dbHandle.vfolders = make(map[string]vfs.BaseVirtualFolder) p.dbHandle.admins = make(map[string]Admin) p.dbHandle.adminsUsernames = []string{} p.dbHandle.apiKeys = make(map[string]APIKey) p.dbHandle.apiKeysIDs = []string{} p.dbHandle.shares = make(map[string]Share) p.dbHandle.sharesIDs = []string{} p.dbHandle.actions = map[string]BaseEventAction{} p.dbHandle.actionsNames = []string{} p.dbHandle.rules = map[string]EventRule{} p.dbHandle.rulesNames = []string{} p.dbHandle.roles = map[string]Role{} p.dbHandle.roleNames = []string{} p.dbHandle.ipListEntries = map[string]IPListEntry{} p.dbHandle.ipListEntriesKeys = []string{} p.dbHandle.configs = Configs{} } func (p *MemoryProvider) reloadConfig() error { if p.dbHandle.configFile == "" { providerLog(logger.LevelDebug, "no dump configuration file defined") return nil } providerLog(logger.LevelDebug, "loading dump from file: %q", p.dbHandle.configFile) fi, err := os.Stat(p.dbHandle.configFile) if err != nil { providerLog(logger.LevelError, "error loading dump: %v", err) return err } if fi.Size() == 0 { err = errors.New("dump configuration file is invalid, its size must be > 0") providerLog(logger.LevelError, "error loading dump: %v", err) return err } if fi.Size() > 20971520 { err = errors.New("dump configuration file is invalid, its size must be <= 20971520 bytes") providerLog(logger.LevelError, "error loading dump: %v", err) return err } content, err := os.ReadFile(p.dbHandle.configFile) if err != nil { providerLog(logger.LevelError, "error loading dump: %v", err) return err } dump, err := ParseDumpData(content) if err != nil { providerLog(logger.LevelError, "error loading dump: %v", err) return err } return p.restoreDump(&dump) } func (p *MemoryProvider) 
restoreDump(dump *BackupData) error { p.clear() if err := p.restoreConfigs(dump); err != nil { return err } if err := p.restoreIPListEntries(dump); err != nil { return err } if err := p.restoreRoles(dump); err != nil { return err } if err := p.restoreFolders(dump); err != nil { return err } if err := p.restoreGroups(dump); err != nil { return err } if err := p.restoreUsers(dump); err != nil { return err } if err := p.restoreAdmins(dump); err != nil { return err } if err := p.restoreAPIKeys(dump); err != nil { return err } if err := p.restoreShares(dump); err != nil { return err } if err := p.restoreEventActions(dump); err != nil { return err } if err := p.restoreEventRules(dump); err != nil { return err } providerLog(logger.LevelDebug, "config loaded from file: %q", p.dbHandle.configFile) return nil } func (p *MemoryProvider) restoreEventActions(dump *BackupData) error { for idx := range dump.EventActions { action := dump.EventActions[idx] a, err := p.eventActionExists(action.Name) if err == nil { action.ID = a.ID err = UpdateEventAction(&action, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating event action %q: %v", action.Name, err) return err } } else { err = AddEventAction(&action, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding event action %q: %v", action.Name, err) return err } } } return nil } func (p *MemoryProvider) restoreEventRules(dump *BackupData) error { for idx := range dump.EventRules { rule := dump.EventRules[idx] r, err := p.eventRuleExists(rule.Name) if dump.Version < 15 { rule.Status = 1 } if err == nil { rule.ID = r.ID err = UpdateEventRule(&rule, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating event rule %q: %v", rule.Name, err) return err } } else { err = AddEventRule(&rule, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding event rule %q: %v", rule.Name, err) return err } 
} } return nil } func (p *MemoryProvider) restoreShares(dump *BackupData) error { for idx := range dump.Shares { share := dump.Shares[idx] s, err := p.shareExists(share.ShareID, "") share.IsRestore = true if err == nil { share.ID = s.ID err = UpdateShare(&share, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating share %q: %v", share.ShareID, err) return err } } else { err = AddShare(&share, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding share %q: %v", share.ShareID, err) return err } } } return nil } func (p *MemoryProvider) restoreAPIKeys(dump *BackupData) error { for idx := range dump.APIKeys { apiKey := dump.APIKeys[idx] if apiKey.Key == "" { return fmt.Errorf("cannot restore an empty API key: %+v", apiKey) } k, err := p.apiKeyExists(apiKey.KeyID) if err == nil { apiKey.ID = k.ID err = UpdateAPIKey(&apiKey, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating API key %q: %v", apiKey.KeyID, err) return err } } else { err = AddAPIKey(&apiKey, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding API key %q: %v", apiKey.KeyID, err) return err } } } return nil } func (p *MemoryProvider) restoreAdmins(dump *BackupData) error { for idx := range dump.Admins { admin := dump.Admins[idx] admin.Username = config.convertName(admin.Username) a, err := p.adminExists(admin.Username) if err == nil { admin.ID = a.ID err = UpdateAdmin(&admin, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating admin %q: %v", admin.Username, err) return err } } else { err = AddAdmin(&admin, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding admin %q: %v", admin.Username, err) return err } } } return nil } func (p *MemoryProvider) restoreConfigs(dump *BackupData) error { if dump.Configs != nil && dump.Configs.UpdatedAt > 0 { return UpdateConfigs(dump.Configs, 
ActionExecutorSystem, "", "") } return nil } func (p *MemoryProvider) restoreIPListEntries(dump *BackupData) error { for idx := range dump.IPLists { entry := dump.IPLists[idx] _, err := p.ipListEntryExists(entry.IPOrNet, entry.Type) if err == nil { err = UpdateIPListEntry(&entry, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating IP list entry %q: %v", entry.getName(), err) return err } } else { err = AddIPListEntry(&entry, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding IP list entry %q: %v", entry.getName(), err) return err } } } return nil } func (p *MemoryProvider) restoreRoles(dump *BackupData) error { for idx := range dump.Roles { role := dump.Roles[idx] role.Name = config.convertName(role.Name) r, err := p.roleExists(role.Name) if err == nil { role.ID = r.ID err = UpdateRole(&role, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating role %q: %v", role.Name, err) return err } } else { role.Admins = nil role.Users = nil err = AddRole(&role, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding role %q: %v", role.Name, err) return err } } } return nil } func (p *MemoryProvider) restoreGroups(dump *BackupData) error { for idx := range dump.Groups { group := dump.Groups[idx] group.Name = config.convertName(group.Name) g, err := p.groupExists(group.Name) if err == nil { group.ID = g.ID err = UpdateGroup(&group, g.Users, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating group %q: %v", group.Name, err) return err } } else { group.Users = nil err = AddGroup(&group, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding group %q: %v", group.Name, err) return err } } } return nil } func (p *MemoryProvider) restoreFolders(dump *BackupData) error { for idx := range dump.Folders { folder := dump.Folders[idx] folder.Name = 
config.convertName(folder.Name) f, err := p.getFolderByName(folder.Name) if err == nil { folder.ID = f.ID err = UpdateFolder(&folder, f.Users, f.Groups, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating folder %q: %v", folder.Name, err) return err } } else { folder.Users = nil err = AddFolder(&folder, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding folder %q: %v", folder.Name, err) return err } } } return nil } func (p *MemoryProvider) restoreUsers(dump *BackupData) error { for idx := range dump.Users { user := dump.Users[idx] user.Username = config.convertName(user.Username) u, err := p.userExists(user.Username, "") if err == nil { user.ID = u.ID err = UpdateUser(&user, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error updating user %q: %v", user.Username, err) return err } } else { err = AddUser(&user, ActionExecutorSystem, "", "") if err != nil { providerLog(logger.LevelError, "error adding user %q: %v", user.Username, err) return err } } } return nil } // initializeDatabase does nothing, no initilization is needed for memory provider func (p *MemoryProvider) initializeDatabase() error { return ErrNoInitRequired } func (p *MemoryProvider) migrateDatabase() error { return ErrNoInitRequired } func (p *MemoryProvider) revertDatabase(_ int) error { return errors.New("memory provider does not store data, revert not possible") } func (p *MemoryProvider) resetDatabase() error { return errors.New("memory provider does not store data, reset not possible") } ================================================ FILE: internal/dataprovider/mysql.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !nomysql package dataprovider import ( "context" "crypto/tls" "crypto/x509" "database/sql" "errors" "fmt" "os" "path/filepath" "strings" "time" "github.com/go-sql-driver/mysql" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( mysqlResetSQL = "DROP TABLE IF EXISTS `{{api_keys}}` CASCADE;" + "DROP TABLE IF EXISTS `{{users_folders_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{users_groups_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{admins_groups_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{groups_folders_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{shares_groups_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{admins}}` CASCADE;" + "DROP TABLE IF EXISTS `{{folders}}` CASCADE;" + "DROP TABLE IF EXISTS `{{shares}}` CASCADE;" + "DROP TABLE IF EXISTS `{{users}}` CASCADE;" + "DROP TABLE IF EXISTS `{{groups}}` CASCADE;" + "DROP TABLE IF EXISTS `{{defender_events}}` CASCADE;" + "DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" + "DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" + "DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" + "DROP TABLE IF EXISTS `{{rules_actions_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{events_actions}}` CASCADE;" + "DROP TABLE IF EXISTS `{{events_rules}}` CASCADE;" + "DROP TABLE IF EXISTS `{{tasks}}` CASCADE;" + "DROP TABLE IF EXISTS `{{nodes}}` CASCADE;" + "DROP TABLE IF EXISTS `{{roles}}` CASCADE;" + "DROP TABLE IF EXISTS `{{ip_lists}}` CASCADE;" + "DROP TABLE IF EXISTS `{{configs}}` CASCADE;" + 
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;" mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" + "CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + "`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " + "`permissions` longtext NOT NULL, `filters` longtext NULL, `additional_info` longtext NULL, `last_login` bigint NOT NULL, " + "`role_id` integer NULL, `created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" + "CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " + "`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " + "`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " + "`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" + "CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" + "CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`date_time` bigint NOT NULL, `score` integer NOT NULL, `host_id` bigint NOT NULL);" + "ALTER TABLE `{{defender_events}}` ADD CONSTRAINT `{{prefix}}defender_events_host_id_fk_defender_hosts_id` " + "FOREIGN KEY (`host_id`) REFERENCES `{{defender_hosts}}` (`id`) ON DELETE CASCADE;" + "CREATE TABLE `{{folders}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " + "`description` varchar(512) NULL, `path` longtext NULL, `used_quota_size` bigint NOT NULL, " + "`used_quota_files` integer NOT NULL, `last_quota_update` bigint NOT NULL, `filesystem` longtext 
NULL);" + "CREATE TABLE `{{groups}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " + "`updated_at` bigint NOT NULL, `user_settings` longtext NULL);" + "CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL, `type` integer NOT NULL, `data` longtext NOT NULL, " + "`timestamp` bigint NOT NULL, PRIMARY KEY (`key`, `type`));" + "CREATE TABLE `{{users}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + "`status` integer NOT NULL, `expiration_date` bigint NOT NULL, `description` varchar(512) NULL, `password` longtext NULL, " + "`public_keys` longtext NULL, `home_dir` longtext NOT NULL, `uid` bigint NOT NULL, `gid` bigint NOT NULL, " + "`max_sessions` integer NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, " + "`permissions` longtext NOT NULL, `used_quota_size` bigint NOT NULL, `used_quota_files` integer NOT NULL, " + "`last_quota_update` bigint NOT NULL, `upload_bandwidth` integer NOT NULL, `download_bandwidth` integer NOT NULL, " + "`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " + "`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL, " + "`upload_data_transfer` integer NOT NULL, `download_data_transfer` integer NOT NULL, " + "`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` bigint NOT NULL, " + "`used_download_data_transfer` bigint NOT NULL, `deleted_at` bigint NOT NULL, `first_download` bigint NOT NULL, " + "`first_upload` bigint NOT NULL, `last_password_change` bigint NOT NULL, `role_id` integer NULL);" + "CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " + "`virtual_path` longtext NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT 
NULL, `sort_order` integer NOT NULL);" + "CREATE TABLE `{{users_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`user_id` integer NOT NULL, `group_id` integer NOT NULL, `group_type` integer NOT NULL, `sort_order` integer NOT NULL);" + "CREATE TABLE `{{users_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `virtual_path` longtext NOT NULL, " + "`quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `folder_id` integer NOT NULL, `user_id` integer NOT NULL, `sort_order` integer NOT NULL);" + "ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_folder_mapping` " + "UNIQUE (`user_id`, `folder_id`);" + "ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_user_id_fk_users_id` " + "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_folder_id_fk_folders_id` " + "FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" + "CREATE INDEX `{{prefix}}users_folders_mapping_sort_order_idx` ON `{{users_folders_mapping}}` (`sort_order`);" + "ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_group_mapping` UNIQUE (`user_id`, `group_id`);" + "ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_group_folder_mapping` UNIQUE (`group_id`, `folder_id`);" + "ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_group_id_fk_groups_id` " + "FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE NO ACTION;" + "ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_user_id_fk_users_id` " + "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE; " + "CREATE INDEX `{{prefix}}users_groups_mapping_sort_order_idx` ON `{{users_groups_mapping}}` (`sort_order`);" + "ALTER TABLE `{{groups_folders_mapping}}` ADD 
CONSTRAINT `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id` " + "FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_group_id_fk_groups_id` " + "FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" + "CREATE INDEX `{{prefix}}groups_folders_mapping_sort_order_idx` ON `{{groups_folders_mapping}}` (`sort_order`); " + "CREATE TABLE `{{shares}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`share_id` varchar(60) NOT NULL UNIQUE, `name` varchar(255) NOT NULL, `description` varchar(512) NULL, " + "`scope` integer NOT NULL, `paths` longtext NOT NULL, `created_at` bigint NOT NULL, " + "`updated_at` bigint NOT NULL, `last_use_at` bigint NOT NULL, `expires_at` bigint NOT NULL, " + "`password` longtext NULL, `max_tokens` integer NOT NULL, `used_tokens` integer NOT NULL, " + "`allow_from` longtext NULL, `options` longtext NULL, `user_id` integer NOT NULL);" + "ALTER TABLE `{{shares}}` ADD CONSTRAINT `{{prefix}}shares_user_id_fk_users_id` " + "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" + "CREATE TABLE `{{api_keys}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL, `key_id` varchar(50) NOT NULL UNIQUE," + "`api_key` varchar(255) NOT NULL UNIQUE, `scope` integer NOT NULL, `created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `last_use_at` bigint NOT NULL, " + "`expires_at` bigint NOT NULL, `description` longtext NULL, `admin_id` integer NULL, `user_id` integer NULL);" + "ALTER TABLE `{{api_keys}}` ADD CONSTRAINT `{{prefix}}api_keys_admin_id_fk_admins_id` FOREIGN KEY (`admin_id`) REFERENCES `{{admins}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{api_keys}}` ADD CONSTRAINT `{{prefix}}api_keys_user_id_fk_users_id` FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" + "CREATE TABLE `{{events_rules}}` (`id` integer AUTO_INCREMENT NOT 
NULL PRIMARY KEY, " + "`name` varchar(255) NOT NULL UNIQUE, `status` integer NOT NULL, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " + "`updated_at` bigint NOT NULL, `trigger` integer NOT NULL, `conditions` longtext NOT NULL, `deleted_at` bigint NOT NULL);" + "CREATE TABLE `{{events_actions}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `type` integer NOT NULL, " + "`options` longtext NOT NULL);" + "CREATE TABLE `{{rules_actions_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`rule_id` integer NOT NULL, `action_id` integer NOT NULL, `order` integer NOT NULL, `options` longtext NOT NULL);" + "CREATE TABLE `{{tasks}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " + "`updated_at` bigint NOT NULL, `version` bigint NOT NULL);" + "ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}unique_rule_action_mapping` UNIQUE (`rule_id`, `action_id`);" + "ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id` " + "FOREIGN KEY (`rule_id`) REFERENCES `{{events_rules}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id` " + "FOREIGN KEY (`action_id`) REFERENCES `{{events_actions}}` (`id`) ON DELETE NO ACTION;" + "CREATE TABLE `{{admins_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + " `admin_id` integer NOT NULL, `group_id` integer NOT NULL, `options` longtext NOT NULL, `sort_order` integer NOT NULL);" + "ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_admin_group_mapping` " + "UNIQUE (`admin_id`, `group_id`);" + "ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_admin_id_fk_admins_id` " + "FOREIGN KEY (`admin_id`) REFERENCES `{{admins}}` (`id`) ON DELETE CASCADE;" 
+ "ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_group_id_fk_groups_id` " + "FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" + "CREATE INDEX `{{prefix}}admins_groups_mapping_sort_order_idx` ON `{{admins_groups_mapping}}` (`sort_order`); " + "CREATE TABLE `{{nodes}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "`name` varchar(255) NOT NULL UNIQUE, `data` longtext NOT NULL, `created_at` bigint NOT NULL, " + "`updated_at` bigint NOT NULL);" + "CREATE TABLE `{{roles}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " + "`description` varchar(512) NULL, `created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" + "ALTER TABLE `{{admins}}` ADD CONSTRAINT `{{prefix}}admins_role_id_fk_roles_id` FOREIGN KEY (`role_id`) " + "REFERENCES `{{roles}}`(`id`) ON DELETE NO ACTION;" + "ALTER TABLE `{{users}}` ADD CONSTRAINT `{{prefix}}users_role_id_fk_roles_id` FOREIGN KEY (`role_id`) " + "REFERENCES `{{roles}}`(`id`) ON DELETE SET NULL;" + "CREATE TABLE `{{ip_lists}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `type` integer NOT NULL, " + "`ipornet` varchar(50) NOT NULL, `mode` integer NOT NULL, `description` varchar(512) NULL, " + "`first` VARBINARY(16) NOT NULL, `last` VARBINARY(16) NOT NULL, `ip_type` integer NOT NULL, `protocols` integer NOT NULL, " + "`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `deleted_at` bigint NOT NULL);" + "ALTER TABLE `{{ip_lists}}` ADD CONSTRAINT `{{prefix}}unique_ipornet_type_mapping` UNIQUE (`type`, `ipornet`);" + "CREATE TABLE `{{configs}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `configs` longtext NOT NULL);" + "INSERT INTO {{configs}} (configs) VALUES ('{}');" + "CREATE INDEX `{{prefix}}users_updated_at_idx` ON `{{users}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}users_deleted_at_idx` ON `{{users}}` (`deleted_at`);" + "CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON 
`{{defender_hosts}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" + "CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);" + "CREATE INDEX `{{prefix}}active_transfers_connection_id_idx` ON `{{active_transfers}}` (`connection_id`);" + "CREATE INDEX `{{prefix}}active_transfers_transfer_id_idx` ON `{{active_transfers}}` (`transfer_id`);" + "CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" + "CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" + "CREATE INDEX `{{prefix}}events_rules_updated_at_idx` ON `{{events_rules}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}events_rules_deleted_at_idx` ON `{{events_rules}}` (`deleted_at`);" + "CREATE INDEX `{{prefix}}events_rules_trigger_idx` ON `{{events_rules}}` (`trigger`);" + "CREATE INDEX `{{prefix}}rules_actions_mapping_order_idx` ON `{{rules_actions_mapping}}` (`order`);" + "CREATE INDEX `{{prefix}}ip_lists_type_idx` ON `{{ip_lists}}` (`type`);" + "CREATE INDEX `{{prefix}}ip_lists_ipornet_idx` ON `{{ip_lists}}` (`ipornet`);" + "CREATE INDEX `{{prefix}}ip_lists_ip_type_idx` ON `{{ip_lists}}` (`ip_type`);" + "CREATE INDEX `{{prefix}}ip_lists_updated_at_idx` ON `{{ip_lists}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}ip_lists_deleted_at_idx` ON `{{ip_lists}}` (`deleted_at`);" + "CREATE INDEX `{{prefix}}ip_lists_first_last_idx` ON `{{ip_lists}}` (`first`, `last`);" + "INSERT INTO {{schema_version}} (version) VALUES (33);" mysqlV34SQL = "CREATE TABLE `{{shares_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY," + "`share_id` integer NOT NULL, `group_id` integer NOT NULL, `permissions` integer NOT NULL," + "`sort_order` integer NOT NULL," + "CONSTRAINT `{{prefix}}unique_share_group_mapping` UNIQUE (`share_id`, 
`group_id`)," + "CONSTRAINT `{{prefix}}shares_groups_mapping_share_id_fk` FOREIGN KEY (`share_id`) REFERENCES `{{shares}}` (`id`) ON DELETE CASCADE," + "CONSTRAINT `{{prefix}}shares_groups_mapping_group_id_fk` FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE); " + "CREATE INDEX `{{prefix}}shares_groups_mapping_sort_order_idx` ON `{{shares_groups_mapping}}` (`sort_order`); " + "CREATE INDEX `{{prefix}}shares_groups_mapping_share_id_idx` ON `{{shares_groups_mapping}}` (`share_id`); " + "CREATE INDEX `{{prefix}}shares_groups_mapping_group_id_idx` ON `{{shares_groups_mapping}}` (`group_id`);" mysqlV34DownSQL = "DROP TABLE IF EXISTS `{{shares_groups_mapping}}`;" ) // MySQLProvider defines the auth provider for MySQL/MariaDB database type MySQLProvider struct { dbHandle *sql.DB } func init() { version.AddFeature("+mysql") } func initializeMySQLProvider() error { connString, err := getMySQLConnectionString(false) if err != nil { return err } redactedConnString, err := getMySQLConnectionString(true) if err != nil { return err } dbHandle, err := sql.Open("mysql", connString) if err != nil { providerLog(logger.LevelError, "error creating mysql database handler, connection string: %q, error: %v", redactedConnString, err) return err } providerLog(logger.LevelDebug, "mysql database handle created, connection string: %q, pool size: %v", redactedConnString, config.PoolSize) dbHandle.SetMaxOpenConns(config.PoolSize) if config.PoolSize > 0 { dbHandle.SetMaxIdleConns(config.PoolSize) } else { dbHandle.SetMaxIdleConns(2) } dbHandle.SetConnMaxLifetime(240 * time.Second) dbHandle.SetConnMaxIdleTime(120 * time.Second) provider = &MySQLProvider{dbHandle: dbHandle} ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return dbHandle.PingContext(ctx) } func getMySQLConnectionString(redactedPwd bool) (string, error) { var connectionString string if config.ConnectionString == "" { password := config.Password if redactedPwd 
&& password != "" { password = "[redacted]" } sslMode := getSSLMode() if sslMode == "custom" && !redactedPwd { if err := registerMySQLCustomTLSConfig(); err != nil { return "", err } } connectionString = fmt.Sprintf("%s:%s@tcp([%s]:%d)/%s?collation=utf8mb4_unicode_ci&interpolateParams=true&timeout=10s&parseTime=true&clientFoundRows=true&tls=%s&writeTimeout=60s&readTimeout=60s", config.Username, password, config.Host, config.Port, config.Name, sslMode) } else { connectionString = config.ConnectionString } return connectionString, nil } func registerMySQLCustomTLSConfig() error { tlsConfig := &tls.Config{} if config.RootCert != "" { rootCAs, err := x509.SystemCertPool() if err != nil { rootCAs = x509.NewCertPool() } rootCrt, err := os.ReadFile(config.RootCert) if err != nil { return fmt.Errorf("unable to load root certificate %q: %v", config.RootCert, err) } if !rootCAs.AppendCertsFromPEM(rootCrt) { return fmt.Errorf("unable to parse root certificate %q", config.RootCert) } tlsConfig.RootCAs = rootCAs } if config.ClientCert != "" && config.ClientKey != "" { clientCert := make([]tls.Certificate, 0, 1) tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey) if err != nil { return fmt.Errorf("unable to load key pair %q, %q: %v", config.ClientCert, config.ClientKey, err) } clientCert = append(clientCert, tlsCert) tlsConfig.Certificates = clientCert } if config.SSLMode == 2 || config.SSLMode == 3 { tlsConfig.InsecureSkipVerify = true } if !filepath.IsAbs(config.Host) && !config.DisableSNI { tlsConfig.ServerName = config.Host } providerLog(logger.LevelInfo, "registering custom TLS config, root cert %q, client cert %q, client key %q, disable SNI? 
%v", config.RootCert, config.ClientCert, config.ClientKey, config.DisableSNI) if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil { return fmt.Errorf("unable to register tls config: %v", err) } return nil } func (p *MySQLProvider) checkAvailability() error { return sqlCommonCheckAvailability(p.dbHandle) } func (p *MySQLProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle) } func (p *MySQLProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { return sqlCommonValidateUserAndTLSCertificate(username, protocol, tlsCert, p.dbHandle) } func (p *MySQLProvider) validateUserAndPubKey(username string, publicKey []byte, isSSHCert bool) (User, string, error) { return sqlCommonValidateUserAndPubKey(username, publicKey, isSSHCert, p.dbHandle) } func (p *MySQLProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) } func (p *MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) } func (p *MySQLProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { return sqlCommonGetUsedQuota(username, p.dbHandle) } func (p *MySQLProvider) getAdminSignature(username string) (string, error) { return sqlCommonGetAdminSignature(username, p.dbHandle) } func (p *MySQLProvider) getUserSignature(username string) (string, error) { return sqlCommonGetUserSignature(username, p.dbHandle) } func (p *MySQLProvider) setUpdatedAt(username string) { sqlCommonSetUpdatedAt(username, p.dbHandle) } func (p *MySQLProvider) updateLastLogin(username string) error { return sqlCommonUpdateLastLogin(username, p.dbHandle) } func (p *MySQLProvider) updateAdminLastLogin(username string) 
error { return sqlCommonUpdateAdminLastLogin(username, p.dbHandle) } func (p *MySQLProvider) userExists(username, role string) (User, error) { return sqlCommonGetUserByUsername(username, role, p.dbHandle) } func (p *MySQLProvider) addUser(user *User) error { return p.normalizeError(sqlCommonAddUser(user, p.dbHandle), fieldUsername) } func (p *MySQLProvider) updateUser(user *User) error { return p.normalizeError(sqlCommonUpdateUser(user, p.dbHandle), -1) } func (p *MySQLProvider) deleteUser(user User, softDelete bool) error { return sqlCommonDeleteUser(user, softDelete, p.dbHandle) } func (p *MySQLProvider) updateUserPassword(username, password string) error { return sqlCommonUpdateUserPassword(username, password, p.dbHandle) } func (p *MySQLProvider) dumpUsers() ([]User, error) { return sqlCommonDumpUsers(p.dbHandle) } func (p *MySQLProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { return sqlCommonGetRecentlyUpdatedUsers(after, p.dbHandle) } func (p *MySQLProvider) getUsers(limit int, offset int, order, role string) ([]User, error) { return sqlCommonGetUsers(limit, offset, order, role, p.dbHandle) } func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle) } func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { return sqlCommonDumpFolders(p.dbHandle) } func (p *MySQLProvider) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) { return sqlCommonGetFolders(limit, offset, order, minimal, p.dbHandle) } func (p *MySQLProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonGetFolderByName(ctx, name, p.dbHandle) } func (p *MySQLProvider) addFolder(folder *vfs.BaseVirtualFolder) error { return p.normalizeError(sqlCommonAddFolder(folder, p.dbHandle), fieldName) } func (p *MySQLProvider) 
updateFolder(folder *vfs.BaseVirtualFolder) error { return sqlCommonUpdateFolder(folder, p.dbHandle) } func (p *MySQLProvider) deleteFolder(folder vfs.BaseVirtualFolder) error { return sqlCommonDeleteFolder(folder, p.dbHandle) } func (p *MySQLProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateFolderQuota(name, filesAdd, sizeAdd, reset, p.dbHandle) } func (p *MySQLProvider) getUsedFolderQuota(name string) (int, int64, error) { return sqlCommonGetFolderUsedQuota(name, p.dbHandle) } func (p *MySQLProvider) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) { return sqlCommonGetGroups(limit, offset, order, minimal, p.dbHandle) } func (p *MySQLProvider) getGroupsWithNames(names []string) ([]Group, error) { return sqlCommonGetGroupsWithNames(names, p.dbHandle) } func (p *MySQLProvider) getUsersInGroups(names []string) ([]string, error) { return sqlCommonGetUsersInGroups(names, p.dbHandle) } func (p *MySQLProvider) groupExists(name string) (Group, error) { return sqlCommonGetGroupByName(name, p.dbHandle) } func (p *MySQLProvider) addGroup(group *Group) error { return p.normalizeError(sqlCommonAddGroup(group, p.dbHandle), fieldName) } func (p *MySQLProvider) updateGroup(group *Group) error { return sqlCommonUpdateGroup(group, p.dbHandle) } func (p *MySQLProvider) deleteGroup(group Group) error { return sqlCommonDeleteGroup(group, p.dbHandle) } func (p *MySQLProvider) dumpGroups() ([]Group, error) { return sqlCommonDumpGroups(p.dbHandle) } func (p *MySQLProvider) adminExists(username string) (Admin, error) { return sqlCommonGetAdminByUsername(username, p.dbHandle) } func (p *MySQLProvider) addAdmin(admin *Admin) error { return p.normalizeError(sqlCommonAddAdmin(admin, p.dbHandle), fieldUsername) } func (p *MySQLProvider) updateAdmin(admin *Admin) error { return p.normalizeError(sqlCommonUpdateAdmin(admin, p.dbHandle), -1) } func (p *MySQLProvider) deleteAdmin(admin Admin) error { return 
sqlCommonDeleteAdmin(admin, p.dbHandle) } func (p *MySQLProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) { return sqlCommonGetAdmins(limit, offset, order, p.dbHandle) } func (p *MySQLProvider) dumpAdmins() ([]Admin, error) { return sqlCommonDumpAdmins(p.dbHandle) } func (p *MySQLProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { return sqlCommonValidateAdminAndPass(username, password, ip, p.dbHandle) } func (p *MySQLProvider) apiKeyExists(keyID string) (APIKey, error) { return sqlCommonGetAPIKeyByID(keyID, p.dbHandle) } func (p *MySQLProvider) addAPIKey(apiKey *APIKey) error { return p.normalizeError(sqlCommonAddAPIKey(apiKey, p.dbHandle), -1) } func (p *MySQLProvider) updateAPIKey(apiKey *APIKey) error { return p.normalizeError(sqlCommonUpdateAPIKey(apiKey, p.dbHandle), -1) } func (p *MySQLProvider) deleteAPIKey(apiKey APIKey) error { return sqlCommonDeleteAPIKey(apiKey, p.dbHandle) } func (p *MySQLProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) { return sqlCommonGetAPIKeys(limit, offset, order, p.dbHandle) } func (p *MySQLProvider) dumpAPIKeys() ([]APIKey, error) { return sqlCommonDumpAPIKeys(p.dbHandle) } func (p *MySQLProvider) updateAPIKeyLastUse(keyID string) error { return sqlCommonUpdateAPIKeyLastUse(keyID, p.dbHandle) } func (p *MySQLProvider) shareExists(shareID, username string) (Share, error) { return sqlCommonGetShareByID(shareID, username, p.dbHandle) } func (p *MySQLProvider) addShare(share *Share) error { return p.normalizeError(sqlCommonAddShare(share, p.dbHandle), fieldName) } func (p *MySQLProvider) updateShare(share *Share) error { return p.normalizeError(sqlCommonUpdateShare(share, p.dbHandle), -1) } func (p *MySQLProvider) deleteShare(share Share) error { return sqlCommonDeleteShare(share, p.dbHandle) } func (p *MySQLProvider) getShares(limit int, offset int, order, username string) ([]Share, error) { return sqlCommonGetShares(limit, offset, order, username, 
p.dbHandle) } func (p *MySQLProvider) dumpShares() ([]Share, error) { return sqlCommonDumpShares(p.dbHandle) } func (p *MySQLProvider) updateShareLastUse(shareID string, numTokens int) error { return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle) } func (p *MySQLProvider) getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) { return sqlCommonGetDefenderHosts(from, limit, p.dbHandle) } func (p *MySQLProvider) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) { return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle) } func (p *MySQLProvider) isDefenderHostBanned(ip string) (DefenderEntry, error) { return sqlCommonIsDefenderHostBanned(ip, p.dbHandle) } func (p *MySQLProvider) updateDefenderBanTime(ip string, minutes int) error { return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle) } func (p *MySQLProvider) deleteDefenderHost(ip string) error { return sqlCommonDeleteDefenderHost(ip, p.dbHandle) } func (p *MySQLProvider) addDefenderEvent(ip string, score int) error { return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle) } func (p *MySQLProvider) setDefenderBanTime(ip string, banTime int64) error { return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle) } func (p *MySQLProvider) cleanupDefender(from int64) error { return sqlCommonDefenderCleanup(from, p.dbHandle) } func (p *MySQLProvider) addActiveTransfer(transfer ActiveTransfer) error { return sqlCommonAddActiveTransfer(transfer, p.dbHandle) } func (p *MySQLProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle) } func (p *MySQLProvider) removeActiveTransfer(transferID int64, connectionID string) error { return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle) } func (p *MySQLProvider) cleanupActiveTransfers(before time.Time) error { return sqlCommonCleanupActiveTransfers(before, p.dbHandle) } 
func (p *MySQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { return sqlCommonGetActiveTransfers(from, p.dbHandle) } func (p *MySQLProvider) addSharedSession(session Session) error { return sqlCommonAddSession(session, p.dbHandle) } func (p *MySQLProvider) deleteSharedSession(key string, sessionType SessionType) error { return sqlCommonDeleteSession(key, sessionType, p.dbHandle) } func (p *MySQLProvider) getSharedSession(key string, sessionType SessionType) (Session, error) { return sqlCommonGetSession(key, sessionType, p.dbHandle) } func (p *MySQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { return sqlCommonCleanupSessions(sessionType, before, p.dbHandle) } func (p *MySQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) { return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle) } func (p *MySQLProvider) dumpEventActions() ([]BaseEventAction, error) { return sqlCommonDumpEventActions(p.dbHandle) } func (p *MySQLProvider) eventActionExists(name string) (BaseEventAction, error) { return sqlCommonGetEventActionByName(name, p.dbHandle) } func (p *MySQLProvider) addEventAction(action *BaseEventAction) error { return p.normalizeError(sqlCommonAddEventAction(action, p.dbHandle), fieldName) } func (p *MySQLProvider) updateEventAction(action *BaseEventAction) error { return sqlCommonUpdateEventAction(action, p.dbHandle) } func (p *MySQLProvider) deleteEventAction(action BaseEventAction) error { return sqlCommonDeleteEventAction(action, p.dbHandle) } func (p *MySQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { return sqlCommonGetEventRules(limit, offset, order, p.dbHandle) } func (p *MySQLProvider) dumpEventRules() ([]EventRule, error) { return sqlCommonDumpEventRules(p.dbHandle) } func (p *MySQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { return sqlCommonGetRecentlyUpdatedRules(after, 
p.dbHandle) } func (p *MySQLProvider) eventRuleExists(name string) (EventRule, error) { return sqlCommonGetEventRuleByName(name, p.dbHandle) } func (p *MySQLProvider) addEventRule(rule *EventRule) error { return p.normalizeError(sqlCommonAddEventRule(rule, p.dbHandle), fieldName) } func (p *MySQLProvider) updateEventRule(rule *EventRule) error { return sqlCommonUpdateEventRule(rule, p.dbHandle) } func (p *MySQLProvider) deleteEventRule(rule EventRule, softDelete bool) error { return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle) } func (p *MySQLProvider) getTaskByName(name string) (Task, error) { return sqlCommonGetTaskByName(name, p.dbHandle) } func (p *MySQLProvider) addTask(name string) error { return sqlCommonAddTask(name, p.dbHandle) } func (p *MySQLProvider) updateTask(name string, version int64) error { return sqlCommonUpdateTask(name, version, p.dbHandle) } func (p *MySQLProvider) updateTaskTimestamp(name string) error { return sqlCommonUpdateTaskTimestamp(name, p.dbHandle) } func (p *MySQLProvider) addNode() error { return sqlCommonAddNode(p.dbHandle) } func (p *MySQLProvider) getNodeByName(name string) (Node, error) { return sqlCommonGetNodeByName(name, p.dbHandle) } func (p *MySQLProvider) getNodes() ([]Node, error) { return sqlCommonGetNodes(p.dbHandle) } func (p *MySQLProvider) updateNodeTimestamp() error { return sqlCommonUpdateNodeTimestamp(p.dbHandle) } func (p *MySQLProvider) cleanupNodes() error { return sqlCommonCleanupNodes(p.dbHandle) } func (p *MySQLProvider) roleExists(name string) (Role, error) { return sqlCommonGetRoleByName(name, p.dbHandle) } func (p *MySQLProvider) addRole(role *Role) error { return p.normalizeError(sqlCommonAddRole(role, p.dbHandle), fieldName) } func (p *MySQLProvider) updateRole(role *Role) error { return sqlCommonUpdateRole(role, p.dbHandle) } func (p *MySQLProvider) deleteRole(role Role) error { return sqlCommonDeleteRole(role, p.dbHandle) } func (p *MySQLProvider) getRoles(limit int, offset int, order 
string, minimal bool) ([]Role, error) { return sqlCommonGetRoles(limit, offset, order, minimal, p.dbHandle) } func (p *MySQLProvider) dumpRoles() ([]Role, error) { return sqlCommonDumpRoles(p.dbHandle) } func (p *MySQLProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { return sqlCommonGetIPListEntry(ipOrNet, listType, p.dbHandle) } func (p *MySQLProvider) addIPListEntry(entry *IPListEntry) error { return p.normalizeError(sqlCommonAddIPListEntry(entry, p.dbHandle), fieldIPNet) } func (p *MySQLProvider) updateIPListEntry(entry *IPListEntry) error { return sqlCommonUpdateIPListEntry(entry, p.dbHandle) } func (p *MySQLProvider) deleteIPListEntry(entry IPListEntry, softDelete bool) error { return sqlCommonDeleteIPListEntry(entry, softDelete, p.dbHandle) } func (p *MySQLProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { return sqlCommonGetIPListEntries(listType, filter, from, order, limit, p.dbHandle) } func (p *MySQLProvider) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) { return sqlCommonGetRecentlyUpdatedIPListEntries(after, p.dbHandle) } func (p *MySQLProvider) dumpIPListEntries() ([]IPListEntry, error) { return sqlCommonDumpIPListEntries(p.dbHandle) } func (p *MySQLProvider) countIPListEntries(listType IPListType) (int64, error) { return sqlCommonCountIPListEntries(listType, p.dbHandle) } func (p *MySQLProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { return sqlCommonGetListEntriesForIP(ip, listType, p.dbHandle) } func (p *MySQLProvider) getConfigs() (Configs, error) { return sqlCommonGetConfigs(p.dbHandle) } func (p *MySQLProvider) setConfigs(configs *Configs) error { return sqlCommonSetConfigs(configs, p.dbHandle) } func (p *MySQLProvider) setFirstDownloadTimestamp(username string) error { return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle) } func (p *MySQLProvider) setFirstUploadTimestamp(username 
string) error { return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle) } func (p *MySQLProvider) close() error { return p.dbHandle.Close() } func (p *MySQLProvider) reloadConfig() error { return nil } // initializeDatabase creates the initial database structure func (p *MySQLProvider) initializeDatabase() error { dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false) if err == nil && dbVersion.Version > 0 { return ErrNoInitRequired } if errors.Is(err, sql.ErrNoRows) { return errSchemaVersionEmpty } logger.InfoToConsole("creating initial database schema, version 33") providerLog(logger.LevelInfo, "creating initial database schema, version 33") initialSQL := sqlReplaceAll(mysqlInitialSQL) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 33, true) } func (p *MySQLProvider) migrateDatabase() error { dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err } switch version := dbVersion.Version; { case version == sqlDatabaseVersion: providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version) return ErrNoInitRequired case version < 33: err = errSchemaVersionTooOld(version) providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err case version == 33: return updateMySQLDatabaseFromV33(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, sqlDatabaseVersion) logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version, sqlDatabaseVersion) return nil } return fmt.Errorf("database schema version not handled: %d", version) } } func (p *MySQLProvider) revertDatabase(targetVersion int) error { dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err } if dbVersion.Version == targetVersion { return errors.New("current version match target version, nothing to do") } 
switch dbVersion.Version { case 34: return downgradeMySQLDatabaseFromV34(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } } func (p *MySQLProvider) resetDatabase() error { sql := sqlReplaceAll(mysqlResetSQL) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0, false) } func (p *MySQLProvider) normalizeError(err error, fieldType int) error { if err == nil { return nil } var mysqlErr *mysql.MySQLError if errors.As(err, &mysqlErr) { switch mysqlErr.Number { case 1062: var message string switch fieldType { case fieldUsername: message = util.I18nErrorDuplicatedUsername case fieldIPNet: message = util.I18nErrorDuplicatedIPNet default: message = util.I18nErrorDuplicatedName } return util.NewI18nError( fmt.Errorf("%w: %s", ErrDuplicatedKey, err.Error()), message, ) case 1452: return fmt.Errorf("%w: %s", ErrForeignKeyViolated, err.Error()) } } return err } func updateMySQLDatabaseFromV33(dbHandle *sql.DB) error { return updateMySQLDatabaseFrom33To34(dbHandle) } func downgradeMySQLDatabaseFromV34(dbHandle *sql.DB) error { return downgradeMySQLDatabaseFrom34To33(dbHandle) } func updateMySQLDatabaseFrom33To34(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 33 -> 34") providerLog(logger.LevelInfo, "updating database schema version: 33 -> 34") sql := strings.ReplaceAll(mysqlV34SQL, "{{prefix}}", config.SQLTablesPrefix) sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 34, true) } func downgradeMySQLDatabaseFrom34To33(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database schema version: 34 -> 33") providerLog(logger.LevelInfo, "downgrading database schema version: 34 -> 33") sql := strings.ReplaceAll(mysqlV34DownSQL, 
"{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 33, false) } ================================================ FILE: internal/dataprovider/mysql_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build nomysql package dataprovider import ( "errors" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-mysql") } func initializeMySQLProvider() error { return errors.New("MySQL disabled at build time") } ================================================ FILE: internal/dataprovider/node.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider import ( "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "strings" "time" "github.com/go-jose/go-jose/v4" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // Supported protocols for connecting to other nodes const ( NodeProtoHTTP = "http" NodeProtoHTTPS = "https" ) const ( // NodeTokenHeader defines the header to use for the node auth token NodeTokenHeader = "X-SFTPGO-Node" nodeTokenAudience = "node" ) var ( // current node currentNode *Node errNoClusterNodes = errors.New("no cluster node defined") activeNodeTimeDiff = -2 * time.Minute nodeReqTimeout = 8 * time.Second ) // NodeConfig defines the node configuration type NodeConfig struct { Host string `json:"host" mapstructure:"host"` Port int `json:"port" mapstructure:"port"` Proto string `json:"proto" mapstructure:"proto"` } func (n *NodeConfig) validate() error { currentNode = nil if config.IsShared != 1 { return nil } if n.Host == "" { return nil } currentNode = &Node{ Data: NodeData{ Host: n.Host, Port: n.Port, Proto: n.Proto, }, } return provider.addNode() } // NodeData defines the details to connect to a cluster node type NodeData struct { Host string `json:"host"` Port int `json:"port"` Proto string `json:"proto"` Key *kms.Secret `json:"api_key"` } func (n *NodeData) validate() error { if n.Host == "" { return util.NewValidationError("node host is mandatory") } if n.Port < 0 || n.Port > 65535 { return util.NewValidationError(fmt.Sprintf("invalid node port: %d", n.Port)) } if n.Proto != NodeProtoHTTP && n.Proto != NodeProtoHTTPS { return util.NewValidationError(fmt.Sprintf("invalid node proto: %s", n.Proto)) } n.Key = kms.NewPlainSecret(util.GenerateOpaqueString()) n.Key.SetAdditionalData(n.Host) if err := n.Key.Encrypt(); err != nil { return 
fmt.Errorf("unable to encrypt node key: %w", err) } return nil } func (n *NodeData) getNodeName() string { h := sha256.New() var b bytes.Buffer fmt.Fprintf(&b, "%s:%d", n.Host, n.Port) h.Write(b.Bytes()) return hex.EncodeToString(h.Sum(nil)) } // Node defines a cluster node type Node struct { Name string `json:"name"` Data NodeData `json:"data"` CreatedAt int64 `json:"created_at"` UpdatedAt int64 `json:"updated_at"` } func (n *Node) validate() error { if n.Name == "" { n.Name = n.Data.getNodeName() } return n.Data.validate() } func (n *Node) authenticate(token string) (*jwt.Claims, error) { if err := n.Data.Key.TryDecrypt(); err != nil { providerLog(logger.LevelError, "unable to decrypt node key: %v", err) return nil, err } if token == "" { return nil, ErrInvalidCredentials } claims, err := jwt.VerifyTokenWithKey(token, []jose.SignatureAlgorithm{jose.HS256}, []byte(n.Data.Key.GetPayload())) if err != nil { return nil, fmt.Errorf("unable to parse and validate token: %v", err) } if claims.Username == "" { return nil, errors.New("no admin username associated with node token") } if !claims.Audience.Contains(nodeTokenAudience) { return nil, errors.New("invalid node token audience") } return claims, nil } // getBaseURL returns the base URL for this node func (n *Node) getBaseURL() string { var sb strings.Builder sb.WriteString(n.Data.Proto) sb.WriteString("://") sb.WriteString(n.Data.Host) if n.Data.Port > 0 { sb.WriteString(":") sb.WriteString(strconv.Itoa(n.Data.Port)) } return sb.String() } // generateAuthToken generates a new auth token func (n *Node) generateAuthToken(username, role string, permissions []string) (string, error) { if err := n.Data.Key.TryDecrypt(); err != nil { return "", fmt.Errorf("unable to decrypt node key: %w", err) } signer, err := jwt.NewSigner(jose.HS256, []byte(n.Data.Key.GetPayload())) if err != nil { return "", fmt.Errorf("unable to create signer: %w", err) } claims := &jwt.Claims{ Username: username, Role: role, Permissions: permissions, 
} claims.Audience = []string{nodeTokenAudience} claims.SetExpiry(time.Now().Add(1 * time.Minute)) payload, err := signer.Sign(claims) if err != nil { return "", fmt.Errorf("unable to sign authentication token: %w", err) } return payload, nil } func (n *Node) prepareRequest(ctx context.Context, username, role, relativeURL, method string, permissions []string, body io.Reader, ) (*http.Request, error) { url := fmt.Sprintf("%s%s", n.getBaseURL(), relativeURL) req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err } token, err := n.generateAuthToken(username, role, permissions) if err != nil { return nil, err } req.Header.Set(NodeTokenHeader, fmt.Sprintf("Bearer %s", token)) return req, nil } // SendGetRequest sends an HTTP GET request to this node. // The responseHolder must be a pointer func (n *Node) SendGetRequest(username, role, relativeURL string, permissions []string, responseHolder any) error { ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout) defer cancel() req, err := n.prepareRequest(ctx, username, role, relativeURL, http.MethodGet, permissions, nil) if err != nil { return err } client := httpclient.GetHTTPClient() defer client.CloseIdleConnections() resp, err := client.Do(req) if err != nil { return fmt.Errorf("unable to send HTTP GET to node %s: %w", n.Name, err) } defer resp.Body.Close() if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent { return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10485760)) if err != nil { return fmt.Errorf("unable to read response body: %w", err) } err = json.Unmarshal(respBody, responseHolder) if err != nil { return errors.New("unable to decode response as json") } return nil } // SendDeleteRequest sends an HTTP DELETE request to this node func (n *Node) SendDeleteRequest(username, role, relativeURL string, permissions []string) error { ctx, cancel := 
context.WithTimeout(context.Background(), nodeReqTimeout) defer cancel() req, err := n.prepareRequest(ctx, username, role, relativeURL, http.MethodDelete, permissions, nil) if err != nil { return err } client := httpclient.GetHTTPClient() defer client.CloseIdleConnections() resp, err := client.Do(req) if err != nil { return fmt.Errorf("unable to send HTTP DELETE to node %s: %w", n.Name, err) } defer resp.Body.Close() if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent { return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } return nil } // AuthenticateNodeToken check the validity of the provided token func AuthenticateNodeToken(token string) (*jwt.Claims, error) { if currentNode == nil { return nil, errNoClusterNodes } return currentNode.authenticate(token) } // GetNodeName returns the node name or an empty string func GetNodeName() string { if currentNode == nil { return "" } return currentNode.Name } ================================================ FILE: internal/dataprovider/pgsql.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !nopgsql package dataprovider import ( "context" "crypto/x509" "database/sql" "errors" "fmt" "net" "slices" "strconv" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/stdlib" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( pgsqlResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}" CASCADE; DROP TABLE IF EXISTS "{{users_folders_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{users_groups_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{admins_groups_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{groups_folders_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{shares_groups_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{admins}}" CASCADE; DROP TABLE IF EXISTS "{{folders}}" CASCADE; DROP TABLE IF EXISTS "{{shares}}" CASCADE; DROP TABLE IF EXISTS "{{users}}" CASCADE; DROP TABLE IF EXISTS "{{groups}}" CASCADE; DROP TABLE IF EXISTS "{{defender_events}}" CASCADE; DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE; DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE; DROP TABLE IF EXISTS "{{shared_sessions}}" CASCADE; DROP TABLE IF EXISTS "{{rules_actions_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{events_actions}}" CASCADE; DROP TABLE IF EXISTS "{{events_rules}}" CASCADE; DROP TABLE IF EXISTS "{{tasks}}" CASCADE; DROP TABLE IF EXISTS "{{nodes}}" CASCADE; DROP TABLE IF EXISTS "{{roles}}" CASCADE; DROP TABLE IF EXISTS "{{ip_lists}}" CASCADE; DROP TABLE IF EXISTS "{{configs}}" CASCADE; DROP TABLE IF EXISTS "{{schema_version}}" CASCADE; ` pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "version" integer NOT NULL); CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "username" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) 
NULL, "status" integer NOT NULL, "permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL, "role_id" integer NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{active_transfers}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "connection_id" varchar(100) NOT NULL, "transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL, "folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL, "current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{defender_hosts}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "ip" varchar(50) NOT NULL UNIQUE, "ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{defender_events}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "date_time" bigint NOT NULL, "score" integer NOT NULL, "host_id" bigint NOT NULL); ALTER TABLE "{{defender_events}}" ADD CONSTRAINT "{{prefix}}defender_events_host_id_fk_defender_hosts_id" FOREIGN KEY ("host_id") REFERENCES "{{defender_hosts}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE TABLE "{{folders}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "filesystem" text NULL); CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL); CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL, "type" integer NOT NULL, "data" text NOT NULL, "timestamp" bigint NOT NULL, PRIMARY KEY 
("key", "type")); CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "username" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, "expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, "public_keys" text NULL, "home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL, "upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, "used_upload_data_transfer" bigint NOT NULL, "used_download_data_transfer" bigint NOT NULL, "deleted_at" bigint NOT NULL, "first_download" bigint NOT NULL, "first_upload" bigint NOT NULL, "last_password_change" bigint NOT NULL, "role_id" integer NULL); CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "group_id" integer NOT NULL, "folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL); CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "user_id" integer NOT NULL, "group_id" integer NOT NULL, "group_type" integer NOT NULL, "sort_order" integer NOT NULL); CREATE TABLE "{{users_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL, "folder_id" 
integer NOT NULL, "user_id" integer NOT NULL); ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"); ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_folder_id_fk_folders_id" FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_user_id_fk_users_id" FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE TABLE "{{shares}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "share_id" varchar(60) NOT NULL UNIQUE, "name" varchar(255) NOT NULL, "description" varchar(512) NULL, "scope" integer NOT NULL, "paths" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL, "password" text NULL, "max_tokens" integer NOT NULL, "used_tokens" integer NOT NULL, "allow_from" text NULL, "options" text NULL, "user_id" integer NOT NULL); ALTER TABLE "{{shares}}" ADD CONSTRAINT "{{prefix}}shares_user_id_fk_users_id" FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE TABLE "{{api_keys}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL, "key_id" varchar(50) NOT NULL UNIQUE, "api_key" varchar(255) NOT NULL UNIQUE, "scope" integer NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL,"expires_at" bigint NOT NULL, "description" text NULL, "admin_id" integer NULL, "user_id" integer NULL); ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_admin_id_fk_admins_id" FOREIGN KEY ("admin_id") REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; ALTER TABLE "{{api_keys}}" ADD CONSTRAINT 
"{{prefix}}api_keys_user_id_fk_users_id" FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"); ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"); CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id"); ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}users_groups_mapping_group_id_fk_groups_id" FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE NO ACTION; CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id"); ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}users_groups_mapping_user_id_fk_users_id" FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE INDEX "{{prefix}}users_groups_mapping_sort_order_idx" ON "{{users_groups_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id"); ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_folder_id_fk_folders_id" FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id"); ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_group_id_fk_groups_id" FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE INDEX "{{prefix}}groups_folders_mapping_sort_order_idx" ON "{{groups_folders_mapping}}" ("sort_order"); CREATE TABLE "{{events_rules}}" ("id" integer NOT NULL PRIMARY KEY GENERATED 
ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "trigger" integer NOT NULL, "conditions" text NOT NULL, "deleted_at" bigint NOT NULL); CREATE TABLE "{{events_actions}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL); CREATE TABLE "{{rules_actions_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "rule_id" integer NOT NULL, "action_id" integer NOT NULL, "order" integer NOT NULL, "options" text NOT NULL); CREATE TABLE "{{tasks}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "updated_at" bigint NOT NULL, "version" bigint NOT NULL); ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id"); ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id" FOREIGN KEY ("rule_id") REFERENCES "{{events_rules}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id" FOREIGN KEY ("action_id") REFERENCES "{{events_actions}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE NO ACTION; CREATE TABLE "{{admins_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "admin_id" integer NOT NULL, "group_id" integer NOT NULL, "options" text NOT NULL, "sort_order" integer NOT NULL); ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id"); ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_admin_id_fk_admins_id" FOREIGN KEY ("admin_id") REFERENCES "{{admins}}" ("id") 
MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_group_id_fk_groups_id" FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; CREATE TABLE "{{nodes}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "data" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{roles}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); ALTER TABLE "{{admins}}" ADD CONSTRAINT "{{prefix}}admins_role_id_fk_roles_id" FOREIGN KEY ("role_id") REFERENCES "{{roles}}" ("id") ON DELETE NO ACTION; ALTER TABLE "{{users}}" ADD CONSTRAINT "{{prefix}}users_role_id_fk_roles_id" FOREIGN KEY ("role_id") REFERENCES "{{roles}}" ("id") ON DELETE SET NULL; CREATE TABLE "{{ip_lists}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "type" integer NOT NULL, "ipornet" varchar(50) NOT NULL, "mode" integer NOT NULL, "description" varchar(512) NULL, "first" inet NOT NULL, "last" inet NOT NULL, "ip_type" integer NOT NULL, "protocols" integer NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "deleted_at" bigint NOT NULL); ALTER TABLE "{{ip_lists}}" ADD CONSTRAINT "{{prefix}}unique_ipornet_type_mapping" UNIQUE ("type", "ipornet"); CREATE TABLE "{{configs}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "configs" text NOT NULL); INSERT INTO {{configs}} (configs) VALUES ('{}'); CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id"); CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id"); CREATE INDEX "{{prefix}}users_folders_mapping_sort_order_idx" ON "{{users_folders_mapping}}" 
("sort_order"); CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id"); CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id"); CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at"); CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at"); CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id"); CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at"); CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time"); CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at"); CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at"); CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger"); CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id"); CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id"); CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order"); CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id"); CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON 
"{{admins_groups_mapping}}" ("group_id"); CREATE INDEX "{{prefix}}admins_groups_mapping_sort_order_idx" ON "{{admins_groups_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}admins_role_id_idx" ON "{{admins}}" ("role_id"); CREATE INDEX "{{prefix}}users_role_id_idx" ON "{{users}}" ("role_id"); CREATE INDEX "{{prefix}}ip_lists_type_idx" ON "{{ip_lists}}" ("type"); CREATE INDEX "{{prefix}}ip_lists_ipornet_idx" ON "{{ip_lists}}" ("ipornet"); CREATE INDEX "{{prefix}}ip_lists_updated_at_idx" ON "{{ip_lists}}" ("updated_at"); CREATE INDEX "{{prefix}}ip_lists_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at"); CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last"); INSERT INTO {{schema_version}} (version) VALUES (33); ` // not supported in CockroachDB ipListsLikeIndex = `CREATE INDEX "{{prefix}}ip_lists_ipornet_like_idx" ON "{{ip_lists}}" ("ipornet" varchar_pattern_ops);` pgsqlV34SQL = `CREATE TABLE "{{shares_groups_mapping}}" ( "id" integer NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, "share_id" integer NOT NULL, "group_id" integer NOT NULL, "permissions" integer NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_share_group_mapping" UNIQUE ("share_id", "group_id"), CONSTRAINT "{{prefix}}shares_groups_mapping_share_id_fk" FOREIGN KEY ("share_id") REFERENCES "{{shares}}"("id") ON DELETE CASCADE, CONSTRAINT "{{prefix}}shares_groups_mapping_group_id_fk" FOREIGN KEY ("group_id") REFERENCES "{{groups}}"("id") ON DELETE CASCADE); CREATE INDEX "{{prefix}}shares_groups_mapping_sort_order_idx" ON "{{shares_groups_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}shares_groups_mapping_share_id_idx" ON "{{shares_groups_mapping}}" ("share_id"); CREATE INDEX "{{prefix}}shares_groups_mapping_group_id_idx" ON "{{shares_groups_mapping}}" ("group_id"); ` pgsqlV34DownSQL = `DROP TABLE IF EXISTS "{{shares_groups_mapping}}";` ) var ( pgSQLTargetSessionAttrs = []string{"any", "read-write", "read-only", "primary", "standby", 
"prefer-standby"}
)

// PGSQLProvider defines the auth provider for PostgreSQL database
type PGSQLProvider struct {
	dbHandle *sql.DB
}

// init registers "+pgsql" among the features reported by the version package.
func init() {
	version.AddFeature("+pgsql")
}

// initializePGSQLProvider opens the PostgreSQL (or CockroachDB) database
// handle, configures the connection pool, installs it as the package-level
// provider and finally pings the database.
//
// When TargetSessionAttrs is "any" the handle is created through pgx directly so
// that multiple configured hosts are tried in random order; otherwise the
// standard database/sql "pgx" driver is used.
func initializePGSQLProvider() error {
	var dbHandle *sql.DB
	if config.TargetSessionAttrs == "any" {
		pgxConfig, err := pgx.ParseConfig(getPGSQLConnectionString(false))
		if err != nil {
			providerLog(logger.LevelError, "error parsing postgres configuration, connection string: %q, error: %v",
				getPGSQLConnectionString(true), err)
			return err
		}
		dbHandle = stdlib.OpenDB(*pgxConfig, stdlib.OptionBeforeConnect(stdlib.RandomizeHostOrderFunc))
	} else {
		var err error
		dbHandle, err = sql.Open("pgx", getPGSQLConnectionString(false))
		if err != nil {
			providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
				getPGSQLConnectionString(true), err)
			return err
		}
	}
	providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
		getPGSQLConnectionString(true), config.PoolSize)
	dbHandle.SetMaxOpenConns(config.PoolSize)
	if config.PoolSize > 0 {
		dbHandle.SetMaxIdleConns(config.PoolSize)
	} else {
		dbHandle.SetMaxIdleConns(2)
	}
	dbHandle.SetConnMaxLifetime(240 * time.Second)
	dbHandle.SetConnMaxIdleTime(120 * time.Second)
	provider = &PGSQLProvider{dbHandle: dbHandle}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return dbHandle.PingContext(ctx)
}

// getPGSQLHostsAndPorts splits the comma separated configHost into parallel,
// comma separated host and port lists. Entries without an explicit port get
// configPort, or 5432 when configPort is 0. Blank entries are skipped.
func getPGSQLHostsAndPorts(configHost string, configPort int) (string, string) {
	var hosts, ports []string
	defaultPort := strconv.Itoa(configPort)
	if defaultPort == "0" {
		defaultPort = "5432"
	}

	for hostport := range strings.SplitSeq(configHost, ",") {
		hostport = strings.TrimSpace(hostport)
		if hostport == "" {
			continue
		}
		host, port, err := net.SplitHostPort(hostport)
		if err == nil {
			hosts = append(hosts, host)
			ports = append(ports, port)
		} else {
			// no port in this entry (or not host:port syntax), keep it as is
			hosts = append(hosts, hostport)
			ports = append(ports, defaultPort)
		}
	}

	return strings.Join(hosts, ","), strings.Join(ports, ",")
}

// escapePGSQLConnValue escapes a value that is placed between single quotes in
// a keyword/value connection string. Inside quoted values backslashes and
// single quotes must be preceded by a backslash; without this a password such
// as "it's" terminates the value early and the connection string is rejected
// or parsed differently than intended.
func escapePGSQLConnValue(value string) string {
	return strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(value)
}

// getPGSQLConnectionString returns the connection string for the configured
// database. A non-empty config.ConnectionString is returned unchanged;
// otherwise a keyword/value string is built from the individual settings.
// When redactedPwd is true a non-empty password is replaced by "[redacted]" so
// the result is safe to log.
func getPGSQLConnectionString(redactedPwd bool) string {
	if config.ConnectionString != "" {
		return config.ConnectionString
	}

	password := config.Password
	if redactedPwd && password != "" {
		password = "[redacted]"
	}
	host, port := getPGSQLHostsAndPorts(config.Host, config.Port)
	connectionString := fmt.Sprintf("host='%s' port='%s' dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
		escapePGSQLConnValue(host), escapePGSQLConnValue(port), escapePGSQLConnValue(config.Name),
		escapePGSQLConnValue(config.Username), escapePGSQLConnValue(password), getSSLMode())
	if config.RootCert != "" {
		connectionString += fmt.Sprintf(" sslrootcert='%s'", escapePGSQLConnValue(config.RootCert))
	}
	if config.ClientCert != "" && config.ClientKey != "" {
		connectionString += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escapePGSQLConnValue(config.ClientCert),
			escapePGSQLConnValue(config.ClientKey))
	}
	if config.DisableSNI {
		connectionString += " sslsni=0"
	}
	// only values from the allow-list are passed through, so no escaping is needed
	if slices.Contains(pgSQLTargetSessionAttrs, config.TargetSessionAttrs) {
		connectionString += fmt.Sprintf(" target_session_attrs='%s'", config.TargetSessionAttrs)
	}

	return connectionString
}

// The methods below are what make PGSQLProvider assignable to the package-level
// provider variable. They delegate to the shared sqlCommon* helpers; methods
// whose errors may carry unique/foreign key violations are wrapped with
// normalizeError so callers get the package's ErrDuplicatedKey /
// ErrForeignKeyViolated sentinels.

func (p *PGSQLProvider) checkAvailability() error {
	return sqlCommonCheckAvailability(p.dbHandle)
}

func (p *PGSQLProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
	return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle)
}

func (p *PGSQLProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) {
	return sqlCommonValidateUserAndTLSCertificate(username, protocol, tlsCert, p.dbHandle)
}

func (p *PGSQLProvider) validateUserAndPubKey(username string, publicKey []byte, isSSHCert bool) (User, string, error) {
	return sqlCommonValidateUserAndPubKey(username, publicKey, isSSHCert, p.dbHandle)
}

func (p *PGSQLProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error {
	return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle)
}

func (p *PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
	return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}

func (p *PGSQLProvider) getUsedQuota(username string) (int, int64, int64, int64, error) {
	return sqlCommonGetUsedQuota(username, p.dbHandle)
}

func (p *PGSQLProvider) getAdminSignature(username string) (string, error) {
	return sqlCommonGetAdminSignature(username, p.dbHandle)
}

func (p *PGSQLProvider) getUserSignature(username string) (string, error) {
	return sqlCommonGetUserSignature(username, p.dbHandle)
}

func (p *PGSQLProvider) setUpdatedAt(username string) {
	sqlCommonSetUpdatedAt(username, p.dbHandle)
}

func (p *PGSQLProvider) updateLastLogin(username string) error {
	return sqlCommonUpdateLastLogin(username, p.dbHandle)
}

func (p *PGSQLProvider) updateAdminLastLogin(username string) error {
	return sqlCommonUpdateAdminLastLogin(username, p.dbHandle)
}

func (p *PGSQLProvider) userExists(username, role string) (User, error) {
	return sqlCommonGetUserByUsername(username, role, p.dbHandle)
}

func (p *PGSQLProvider) addUser(user *User) error {
	return p.normalizeError(sqlCommonAddUser(user, p.dbHandle), fieldUsername)
}

func (p *PGSQLProvider) updateUser(user *User) error {
	return p.normalizeError(sqlCommonUpdateUser(user, p.dbHandle), -1)
}

func (p *PGSQLProvider) deleteUser(user User, softDelete bool) error {
	return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
}

func (p *PGSQLProvider) updateUserPassword(username, password string) error {
	return sqlCommonUpdateUserPassword(username, password, p.dbHandle)
}

func (p *PGSQLProvider) dumpUsers() ([]User, error) {
	return sqlCommonDumpUsers(p.dbHandle)
}

func (p *PGSQLProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
	return sqlCommonGetRecentlyUpdatedUsers(after, p.dbHandle)
}

func (p *PGSQLProvider) getUsers(limit int, offset int, order, role string) ([]User, error) {
	return sqlCommonGetUsers(limit, offset, order, role, p.dbHandle)
}

func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}

func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
	return sqlCommonDumpFolders(p.dbHandle)
}

func (p *PGSQLProvider) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) {
	return sqlCommonGetFolders(limit, offset, order, minimal, p.dbHandle)
}

// getFolderByName looks up a folder, bounded by defaultSQLQueryTimeout.
func (p *PGSQLProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()
	return sqlCommonGetFolderByName(ctx, name, p.dbHandle)
}

func (p *PGSQLProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
	return p.normalizeError(sqlCommonAddFolder(folder, p.dbHandle), fieldName)
}

func (p *PGSQLProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	return sqlCommonUpdateFolder(folder, p.dbHandle)
}

func (p *PGSQLProvider) deleteFolder(folder vfs.BaseVirtualFolder) error {
	return sqlCommonDeleteFolder(folder, p.dbHandle)
}

func (p *PGSQLProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error {
	return sqlCommonUpdateFolderQuota(name, filesAdd, sizeAdd, reset, p.dbHandle)
}

func (p *PGSQLProvider) getUsedFolderQuota(name string) (int, int64, error) {
	return sqlCommonGetFolderUsedQuota(name, p.dbHandle)
}

func (p *PGSQLProvider) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) {
	return sqlCommonGetGroups(limit, offset, order, minimal, p.dbHandle)
}

func (p *PGSQLProvider) getGroupsWithNames(names []string) ([]Group, error) {
	return sqlCommonGetGroupsWithNames(names, p.dbHandle)
}

func (p *PGSQLProvider) getUsersInGroups(names []string) ([]string, error) {
	return sqlCommonGetUsersInGroups(names, p.dbHandle)
}

func (p *PGSQLProvider) groupExists(name string) (Group, error) {
	return sqlCommonGetGroupByName(name, p.dbHandle)
}

func (p *PGSQLProvider) addGroup(group *Group) error {
	return p.normalizeError(sqlCommonAddGroup(group, p.dbHandle), fieldName)
}

func (p *PGSQLProvider) updateGroup(group *Group) error {
	return sqlCommonUpdateGroup(group, p.dbHandle)
}

func (p *PGSQLProvider) deleteGroup(group Group) error {
	return sqlCommonDeleteGroup(group, p.dbHandle)
}

func (p *PGSQLProvider) dumpGroups() ([]Group, error) {
	return sqlCommonDumpGroups(p.dbHandle)
}

func (p *PGSQLProvider) adminExists(username string) (Admin, error) {
	return sqlCommonGetAdminByUsername(username, p.dbHandle)
}

func (p *PGSQLProvider) addAdmin(admin *Admin) error {
	return p.normalizeError(sqlCommonAddAdmin(admin, p.dbHandle), fieldUsername)
}

func (p *PGSQLProvider) updateAdmin(admin *Admin) error {
	return p.normalizeError(sqlCommonUpdateAdmin(admin, p.dbHandle), -1)
}

func (p *PGSQLProvider) deleteAdmin(admin Admin) error {
	return sqlCommonDeleteAdmin(admin, p.dbHandle)
}

func (p *PGSQLProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) {
	return sqlCommonGetAdmins(limit, offset, order, p.dbHandle)
}

func (p *PGSQLProvider) dumpAdmins() ([]Admin, error) {
	return sqlCommonDumpAdmins(p.dbHandle)
}

func (p *PGSQLProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
	return sqlCommonValidateAdminAndPass(username, password, ip, p.dbHandle)
}

func (p *PGSQLProvider) apiKeyExists(keyID string) (APIKey, error) {
	return sqlCommonGetAPIKeyByID(keyID, p.dbHandle)
}

func (p *PGSQLProvider) addAPIKey(apiKey *APIKey) error {
	return p.normalizeError(sqlCommonAddAPIKey(apiKey, p.dbHandle), -1)
}

func (p *PGSQLProvider) updateAPIKey(apiKey *APIKey) error {
	return p.normalizeError(sqlCommonUpdateAPIKey(apiKey, p.dbHandle), -1)
}

func (p *PGSQLProvider) deleteAPIKey(apiKey APIKey) error {
	return sqlCommonDeleteAPIKey(apiKey, p.dbHandle)
}

func (p *PGSQLProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) {
	return sqlCommonGetAPIKeys(limit, offset, order, p.dbHandle)
}

func (p *PGSQLProvider) dumpAPIKeys() ([]APIKey, error) {
	return sqlCommonDumpAPIKeys(p.dbHandle)
}

func (p *PGSQLProvider) updateAPIKeyLastUse(keyID string) error {
	return sqlCommonUpdateAPIKeyLastUse(keyID, p.dbHandle)
}

func (p *PGSQLProvider) shareExists(shareID, username string) (Share, error) {
	return sqlCommonGetShareByID(shareID, username, p.dbHandle)
}

func (p *PGSQLProvider) addShare(share *Share) error {
	return p.normalizeError(sqlCommonAddShare(share, p.dbHandle), fieldName)
}

func (p *PGSQLProvider) updateShare(share *Share) error {
	return p.normalizeError(sqlCommonUpdateShare(share, p.dbHandle), -1)
}

func (p *PGSQLProvider) deleteShare(share Share) error {
	return sqlCommonDeleteShare(share, p.dbHandle)
}

func (p *PGSQLProvider) getShares(limit int, offset int, order, username string) ([]Share, error) {
	return sqlCommonGetShares(limit, offset, order, username, p.dbHandle)
}

func (p *PGSQLProvider) dumpShares() ([]Share, error) {
	return sqlCommonDumpShares(p.dbHandle)
}

func (p *PGSQLProvider) updateShareLastUse(shareID string, numTokens int) error {
	return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
}

func (p *PGSQLProvider) getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) {
	return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
}

func (p *PGSQLProvider) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) {
	return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
}

func (p *PGSQLProvider) isDefenderHostBanned(ip string) (DefenderEntry, error) {
	return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
}

func (p *PGSQLProvider) updateDefenderBanTime(ip string, minutes int) error {
	return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
}

func (p *PGSQLProvider) deleteDefenderHost(ip string) error {
	return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
}

func (p *PGSQLProvider) addDefenderEvent(ip string, score int) error {
	return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
}

func (p *PGSQLProvider) setDefenderBanTime(ip string, banTime int64) error {
	return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
}

func (p *PGSQLProvider) cleanupDefender(from int64) error {
	return sqlCommonDefenderCleanup(from, p.dbHandle)
}

func (p *PGSQLProvider) addActiveTransfer(transfer ActiveTransfer) error {
	return sqlCommonAddActiveTransfer(transfer, p.dbHandle)
}

func (p *PGSQLProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error {
	return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle)
}

func (p *PGSQLProvider) removeActiveTransfer(transferID int64, connectionID string) error {
	return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle)
}

func (p *PGSQLProvider) cleanupActiveTransfers(before time.Time) error {
	return sqlCommonCleanupActiveTransfers(before, p.dbHandle)
}

func (p *PGSQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) {
	return sqlCommonGetActiveTransfers(from, p.dbHandle)
}

func (p *PGSQLProvider) addSharedSession(session Session) error {
	return sqlCommonAddSession(session, p.dbHandle)
}

func (p *PGSQLProvider) deleteSharedSession(key string, sessionType SessionType) error {
	return sqlCommonDeleteSession(key, sessionType, p.dbHandle)
}

func (p *PGSQLProvider) getSharedSession(key string, sessionType SessionType) (Session, error) {
	return sqlCommonGetSession(key, sessionType, p.dbHandle)
}

func (p *PGSQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error {
	return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}

func (p *PGSQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
	return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
}

func (p *PGSQLProvider) dumpEventActions() ([]BaseEventAction, error) {
	return sqlCommonDumpEventActions(p.dbHandle)
}

func (p *PGSQLProvider) eventActionExists(name string) (BaseEventAction, error) {
	return sqlCommonGetEventActionByName(name, p.dbHandle)
}

func (p *PGSQLProvider) addEventAction(action *BaseEventAction) error {
	return p.normalizeError(sqlCommonAddEventAction(action, p.dbHandle), fieldName)
}

func (p *PGSQLProvider) updateEventAction(action *BaseEventAction) error {
	return sqlCommonUpdateEventAction(action, p.dbHandle)
}

func (p *PGSQLProvider) deleteEventAction(action BaseEventAction) error {
	return sqlCommonDeleteEventAction(action, p.dbHandle)
}

func (p *PGSQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
	return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
}

func (p *PGSQLProvider) dumpEventRules() ([]EventRule, error) {
	return sqlCommonDumpEventRules(p.dbHandle)
}

func (p *PGSQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
	return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
}

func (p *PGSQLProvider) eventRuleExists(name string) (EventRule, error) {
	return sqlCommonGetEventRuleByName(name, p.dbHandle)
}

func (p *PGSQLProvider) addEventRule(rule *EventRule) error {
	return p.normalizeError(sqlCommonAddEventRule(rule, p.dbHandle), fieldName)
}

func (p *PGSQLProvider) updateEventRule(rule *EventRule) error {
	return sqlCommonUpdateEventRule(rule, p.dbHandle)
}

func (p *PGSQLProvider) deleteEventRule(rule EventRule, softDelete bool) error {
	return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
}

func (p *PGSQLProvider) getTaskByName(name string) (Task, error) {
	return sqlCommonGetTaskByName(name, p.dbHandle)
}

func (p *PGSQLProvider) addTask(name string) error {
	return sqlCommonAddTask(name, p.dbHandle)
}

func (p *PGSQLProvider) updateTask(name string, version int64) error {
	return sqlCommonUpdateTask(name, version, p.dbHandle)
}

func (p *PGSQLProvider) updateTaskTimestamp(name string) error {
	return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}

func (p *PGSQLProvider) addNode() error {
	return sqlCommonAddNode(p.dbHandle)
}

func (p *PGSQLProvider) getNodeByName(name string) (Node, error) {
	return sqlCommonGetNodeByName(name, p.dbHandle)
}

func (p *PGSQLProvider) getNodes() ([]Node, error) {
	return sqlCommonGetNodes(p.dbHandle)
}

func (p *PGSQLProvider) updateNodeTimestamp() error {
	return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}

func (p *PGSQLProvider) cleanupNodes() error {
	return sqlCommonCleanupNodes(p.dbHandle)
}

func (p *PGSQLProvider) roleExists(name string) (Role, error) {
	return sqlCommonGetRoleByName(name, p.dbHandle)
}

func (p *PGSQLProvider) addRole(role *Role) error {
	return p.normalizeError(sqlCommonAddRole(role, p.dbHandle), fieldName)
}

func (p *PGSQLProvider) updateRole(role *Role) error {
	return sqlCommonUpdateRole(role, p.dbHandle)
}

func (p *PGSQLProvider) deleteRole(role Role) error {
	return sqlCommonDeleteRole(role, p.dbHandle)
}

func (p *PGSQLProvider) getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) {
	return sqlCommonGetRoles(limit, offset, order, minimal, p.dbHandle)
}

func (p *PGSQLProvider) dumpRoles() ([]Role, error) {
	return sqlCommonDumpRoles(p.dbHandle)
}

func (p *PGSQLProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) {
	return sqlCommonGetIPListEntry(ipOrNet, listType, p.dbHandle)
}

func (p *PGSQLProvider) addIPListEntry(entry *IPListEntry) error {
	return p.normalizeError(sqlCommonAddIPListEntry(entry, p.dbHandle), fieldIPNet)
}

func (p *PGSQLProvider) updateIPListEntry(entry *IPListEntry) error {
	return sqlCommonUpdateIPListEntry(entry, p.dbHandle)
}

func (p *PGSQLProvider) deleteIPListEntry(entry IPListEntry, softDelete bool) error {
	return sqlCommonDeleteIPListEntry(entry, softDelete, p.dbHandle)
}

func (p *PGSQLProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) {
	return sqlCommonGetIPListEntries(listType, filter, from, order, limit, p.dbHandle)
}

func (p *PGSQLProvider) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) {
	return sqlCommonGetRecentlyUpdatedIPListEntries(after, p.dbHandle)
}

func (p *PGSQLProvider) dumpIPListEntries() ([]IPListEntry, error) {
	return sqlCommonDumpIPListEntries(p.dbHandle)
}

func (p *PGSQLProvider) countIPListEntries(listType IPListType) (int64, error) {
	return sqlCommonCountIPListEntries(listType, p.dbHandle)
}

func (p *PGSQLProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) {
	return sqlCommonGetListEntriesForIP(ip, listType, p.dbHandle)
}

func (p *PGSQLProvider) getConfigs() (Configs, error) {
	return sqlCommonGetConfigs(p.dbHandle)
}

func (p *PGSQLProvider) setConfigs(configs *Configs) error {
	return sqlCommonSetConfigs(configs, p.dbHandle)
}

func (p *PGSQLProvider) setFirstDownloadTimestamp(username string) error {
	return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}

func (p *PGSQLProvider) setFirstUploadTimestamp(username string) error {
	return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
}

// close closes the underlying database handle.
func (p *PGSQLProvider) close() error {
	return p.dbHandle.Close()
}

// reloadConfig is a no-op for this provider.
func (p *PGSQLProvider) reloadConfig() error {
	return nil
}

// initializeDatabase creates the initial database structure (schema version 33).
// It returns ErrNoInitRequired if a schema version is already stored and
// errSchemaVersionEmpty if the version table exists but has no rows.
// For CockroachDB the identity columns are rewritten to use
// unordered_unique_rowid() and the varchar_pattern_ops index is omitted.
func (p *PGSQLProvider) initializeDatabase() error {
	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false)
	if err == nil && dbVersion.Version > 0 {
		return ErrNoInitRequired
	}
	if errors.Is(err, sql.ErrNoRows) {
		return errSchemaVersionEmpty
	}
	logger.InfoToConsole("creating initial database schema, version 33")
	providerLog(logger.LevelInfo, "creating initial database schema, version 33")
	var initialSQL string
	if config.Driver == CockroachDataProviderName {
		initialSQL = sqlReplaceAll(pgsqlInitial)
		initialSQL = strings.ReplaceAll(initialSQL, "GENERATED ALWAYS AS IDENTITY", "DEFAULT unordered_unique_rowid()")
	} else {
		initialSQL = sqlReplaceAll(pgsqlInitial + ipListsLikeIndex)
	}

	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 33, true)
}

// migrateDatabase upgrades the schema to sqlDatabaseVersion. Versions older
// than 33 are rejected; a version newer than the supported one is only logged
// and nil is returned.
func (p *PGSQLProvider) migrateDatabase() error {
	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
	if err != nil {
		return err
	}

	switch version := dbVersion.Version; {
	case version == sqlDatabaseVersion:
		providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
		return ErrNoInitRequired
	case version < 33:
		err = errSchemaVersionTooOld(version)
		providerLog(logger.LevelError, "%v", err)
		logger.ErrorToConsole("%v", err)
		return err
	case version == 33:
		return updatePGSQLDatabaseFromV33(p.dbHandle)
	default:
		if version > sqlDatabaseVersion {
			providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
				sqlDatabaseVersion)
			logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version,
				sqlDatabaseVersion)
			return nil
		}
		return fmt.Errorf("database schema version not handled: %d", version)
	}
}

// revertDatabase downgrades the schema from its current version; only the
// 34 -> 33 downgrade is implemented.
func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
	if err != nil {
		return err
	}
	if dbVersion.Version == targetVersion {
		return errors.New("current version match target version, nothing to do")
	}

	switch dbVersion.Version {
	case 34:
		return downgradePGSQLDatabaseFromV34(p.dbHandle)
	default:
		return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
	}
}

// resetDatabase runs pgsqlResetSQL and records schema version 0.
func (p *PGSQLProvider) resetDatabase() error {
	resetSQL := sqlReplaceAll(pgsqlResetSQL)
	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{resetSQL}, 0, false)
}

// normalizeError maps PostgreSQL error codes to package sentinels:
// 23505 (unique_violation) becomes ErrDuplicatedKey wrapped in an i18n error
// whose message depends on fieldType, and 23503 (foreign_key_violation)
// becomes ErrForeignKeyViolated. Any other error is returned unchanged.
func (p *PGSQLProvider) normalizeError(err error, fieldType int) error {
	if err == nil {
		return nil
	}
	var pgsqlErr *pgconn.PgError
	if errors.As(err, &pgsqlErr) {
		switch pgsqlErr.Code {
		case "23505":
			var message string
			switch fieldType {
			case fieldUsername:
				message = util.I18nErrorDuplicatedUsername
			case fieldIPNet:
				message = util.I18nErrorDuplicatedIPNet
			default:
				message = util.I18nErrorDuplicatedName
			}
			return util.NewI18nError(
				fmt.Errorf("%w: %s", ErrDuplicatedKey, err.Error()),
				message,
			)
		case "23503":
			return fmt.Errorf("%w: %s", ErrForeignKeyViolated, err.Error())
		}
	}
	return err
}

func updatePGSQLDatabaseFromV33(dbHandle *sql.DB) error {
	return updatePGSQLDatabaseFrom33To34(dbHandle)
}

func downgradePGSQLDatabaseFromV34(dbHandle *sql.DB) error {
	return downgradePGSQLDatabaseFrom34To33(dbHandle)
}

// updatePGSQLDatabaseFrom33To34 creates the shares/groups mapping table and
// its indexes, then records schema version 34.
func updatePGSQLDatabaseFrom33To34(dbHandle *sql.DB) error {
	logger.InfoToConsole("updating database schema version: 33 -> 34")
	providerLog(logger.LevelInfo, "updating database schema version: 33 -> 34")

	// named upgradeSQL so the database/sql package is not shadowed
	upgradeSQL := strings.ReplaceAll(pgsqlV34SQL, "{{prefix}}", config.SQLTablesPrefix)
	upgradeSQL = strings.ReplaceAll(upgradeSQL, "{{shares}}", sqlTableShares)
	upgradeSQL = strings.ReplaceAll(upgradeSQL, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping)
	upgradeSQL = strings.ReplaceAll(upgradeSQL, "{{groups}}", sqlTableGroups)
	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{upgradeSQL}, 34, true)
}

// downgradePGSQLDatabaseFrom34To33 drops the shares/groups mapping table and
// records schema version 33.
func downgradePGSQLDatabaseFrom34To33(dbHandle *sql.DB) error {
	logger.InfoToConsole("downgrading database schema version: 34 -> 33")
	providerLog(logger.LevelInfo, "downgrading database schema version: 34 -> 33")

	downgradeSQL := strings.ReplaceAll(pgsqlV34DownSQL, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping)
	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{downgradeSQL}, 33, false)
}

================================================
FILE: internal/dataprovider/pgsql_disabled.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
//go:build nopgsql package dataprovider import ( "errors" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-pgsql") } func initializePGSQLProvider() error { return errors.New("PostgreSQL disabled at build time") } ================================================ FILE: internal/dataprovider/quotaupdater.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "sync" "time" "github.com/drakkan/sftpgo/v2/internal/logger" ) var delayedQuotaUpdater quotaUpdater func init() { delayedQuotaUpdater = newQuotaUpdater() } type quotaObject struct { size int64 files int } type transferQuotaObject struct { ulSize int64 dlSize int64 } type quotaUpdater struct { paramsMutex sync.RWMutex waitTime time.Duration sync.RWMutex pendingUserQuotaUpdates map[string]quotaObject pendingFolderQuotaUpdates map[string]quotaObject pendingTransferQuotaUpdates map[string]transferQuotaObject } func newQuotaUpdater() quotaUpdater { return quotaUpdater{ pendingUserQuotaUpdates: make(map[string]quotaObject), pendingFolderQuotaUpdates: make(map[string]quotaObject), pendingTransferQuotaUpdates: make(map[string]transferQuotaObject), } } func (q *quotaUpdater) start() { q.setWaitTime(config.DelayedQuotaUpdate) go q.loop() } func (q *quotaUpdater) loop() { waitTime := q.getWaitTime() providerLog(logger.LevelDebug, "delayed quota update loop started, 
wait time: %v", waitTime) for waitTime > 0 { // We do this with a time.Sleep instead of a time.Ticker because we don't know // how long each quota processing cycle will take, and we want to make // sure we wait the configured seconds between each iteration time.Sleep(waitTime) providerLog(logger.LevelDebug, "delayed quota update check start") q.storeUsersQuota() q.storeFoldersQuota() q.storeUsersTransferQuota() providerLog(logger.LevelDebug, "delayed quota update check end") waitTime = q.getWaitTime() } providerLog(logger.LevelDebug, "delayed quota update loop ended, wait time: %v", waitTime) } func (q *quotaUpdater) setWaitTime(secs int) { q.paramsMutex.Lock() defer q.paramsMutex.Unlock() q.waitTime = time.Duration(secs) * time.Second } func (q *quotaUpdater) getWaitTime() time.Duration { q.paramsMutex.RLock() defer q.paramsMutex.RUnlock() return q.waitTime } func (q *quotaUpdater) resetUserQuota(username string) { q.Lock() defer q.Unlock() delete(q.pendingUserQuotaUpdates, username) } func (q *quotaUpdater) updateUserQuota(username string, files int, size int64) { q.Lock() defer q.Unlock() obj := q.pendingUserQuotaUpdates[username] obj.size += size obj.files += files if obj.files == 0 && obj.size == 0 { delete(q.pendingUserQuotaUpdates, username) return } q.pendingUserQuotaUpdates[username] = obj } func (q *quotaUpdater) getUserPendingQuota(username string) (int, int64) { q.RLock() defer q.RUnlock() obj := q.pendingUserQuotaUpdates[username] return obj.files, obj.size } func (q *quotaUpdater) resetFolderQuota(name string) { q.Lock() defer q.Unlock() delete(q.pendingFolderQuotaUpdates, name) } func (q *quotaUpdater) updateFolderQuota(name string, files int, size int64) { q.Lock() defer q.Unlock() obj := q.pendingFolderQuotaUpdates[name] obj.size += size obj.files += files if obj.files == 0 && obj.size == 0 { delete(q.pendingFolderQuotaUpdates, name) return } q.pendingFolderQuotaUpdates[name] = obj } func (q *quotaUpdater) getFolderPendingQuota(name string) (int, 
int64) { q.RLock() defer q.RUnlock() obj := q.pendingFolderQuotaUpdates[name] return obj.files, obj.size } func (q *quotaUpdater) resetUserTransferQuota(username string) { q.Lock() defer q.Unlock() delete(q.pendingTransferQuotaUpdates, username) } func (q *quotaUpdater) updateUserTransferQuota(username string, ulSize, dlSize int64) { q.Lock() defer q.Unlock() obj := q.pendingTransferQuotaUpdates[username] obj.ulSize += ulSize obj.dlSize += dlSize if obj.ulSize == 0 && obj.dlSize == 0 { delete(q.pendingTransferQuotaUpdates, username) return } q.pendingTransferQuotaUpdates[username] = obj } func (q *quotaUpdater) getUserPendingTransferQuota(username string) (int64, int64) { q.RLock() defer q.RUnlock() obj := q.pendingTransferQuotaUpdates[username] return obj.ulSize, obj.dlSize } func (q *quotaUpdater) getUsernames() []string { q.RLock() defer q.RUnlock() result := make([]string, 0, len(q.pendingUserQuotaUpdates)) for username := range q.pendingUserQuotaUpdates { result = append(result, username) } return result } func (q *quotaUpdater) getFoldernames() []string { q.RLock() defer q.RUnlock() result := make([]string, 0, len(q.pendingFolderQuotaUpdates)) for name := range q.pendingFolderQuotaUpdates { result = append(result, name) } return result } func (q *quotaUpdater) getTransferQuotaUsernames() []string { q.RLock() defer q.RUnlock() result := make([]string, 0, len(q.pendingTransferQuotaUpdates)) for username := range q.pendingTransferQuotaUpdates { result = append(result, username) } return result } func (q *quotaUpdater) storeUsersQuota() { for _, username := range q.getUsernames() { files, size := q.getUserPendingQuota(username) if size != 0 || files != 0 { err := provider.updateQuota(username, files, size, false) if err != nil { providerLog(logger.LevelWarn, "unable to update quota delayed for user %q: %v", username, err) continue } q.updateUserQuota(username, -files, -size) } } } func (q *quotaUpdater) storeFoldersQuota() { for _, name := range 
q.getFoldernames() { files, size := q.getFolderPendingQuota(name) if size != 0 || files != 0 { err := provider.updateFolderQuota(name, files, size, false) if err != nil { providerLog(logger.LevelWarn, "unable to update quota delayed for folder %q: %v", name, err) continue } q.updateFolderQuota(name, -files, -size) } } } func (q *quotaUpdater) storeUsersTransferQuota() { for _, username := range q.getTransferQuotaUsernames() { ulSize, dlSize := q.getUserPendingTransferQuota(username) if ulSize != 0 || dlSize != 0 { err := provider.updateTransferQuota(username, ulSize, dlSize, false) if err != nil { providerLog(logger.LevelWarn, "unable to update transfer quota delayed for user %q: %v", username, err) continue } q.updateUserTransferQuota(username, -ulSize, -dlSize) } } } ================================================ FILE: internal/dataprovider/role.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "encoding/json" "fmt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // Role defines an SFTPGo role. 
type Role struct { // Data provider unique identifier ID int64 `json:"id"` // Role name Name string `json:"name"` // optional description Description string `json:"description,omitempty"` // Creation time as unix timestamp in milliseconds CreatedAt int64 `json:"created_at"` // last update time as unix timestamp in milliseconds UpdatedAt int64 `json:"updated_at"` // list of admins associated with this role Admins []string `json:"admins,omitempty"` // list of usernames associated with this role Users []string `json:"users,omitempty"` } // RenderAsJSON implements the renderer interface used within plugins func (r *Role) RenderAsJSON(reload bool) ([]byte, error) { if reload { role, err := provider.roleExists(r.Name) if err != nil { providerLog(logger.LevelError, "unable to reload role before rendering as json: %v", err) return nil, err } return json.Marshal(role) } return json.Marshal(r) } func (r *Role) validate() error { if r.Name == "" { return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) } if !util.IsNameValid(r.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if len(r.Name) > 255 { return util.NewValidationError("name is too long, 255 is the maximum length allowed") } if config.NamingRules&1 == 0 && !usernameRegex.MatchString(r.Name) { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", r.Name)), util.I18nErrorInvalidName, ) } return nil } func (r *Role) getACopy() Role { users := make([]string, len(r.Users)) copy(users, r.Users) admins := make([]string, len(r.Admins)) copy(admins, r.Admins) return Role{ ID: r.ID, Name: r.Name, Description: r.Description, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, Users: users, Admins: admins, } } ================================================ FILE: internal/dataprovider/scheduler.go ================================================ // Copyright (C) 2019 Nicola 
Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "fmt" "sync/atomic" "time" "github.com/robfig/cron/v3" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( scheduler *cron.Cron lastUserCacheUpdate atomic.Int64 lastIPListsCacheUpdate atomic.Int64 // used for bolt and memory providers, so we avoid iterating all users/rules // to find recently modified ones lastUserUpdate atomic.Int64 lastRuleUpdate atomic.Int64 ) func stopScheduler() { if scheduler != nil { scheduler.Stop() scheduler = nil } } func startScheduler() error { stopScheduler() scheduler = cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger)) _, err := scheduler.AddFunc("@every 55s", checkDataprovider) if err != nil { return fmt.Errorf("unable to schedule dataprovider availability check: %w", err) } err = addScheduledCacheUpdates() if err != nil { return err } if currentNode != nil { _, err = scheduler.AddFunc("@every 30m", func() { err := provider.cleanupNodes() if err != nil { providerLog(logger.LevelError, "unable to cleanup nodes: %v", err) } else { providerLog(logger.LevelDebug, "cleanup nodes ok") } }) } if err != nil { return fmt.Errorf("unable to schedule nodes cleanup: %w", err) } scheduler.Start() return nil } func addScheduledCacheUpdates() error { lastUserCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) 
lastIPListsCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) _, err := scheduler.AddFunc("@every 10m", checkCacheUpdates) if err != nil { return fmt.Errorf("unable to schedule cache updates: %w", err) } return nil } func checkDataprovider() { if currentNode != nil { err := provider.updateNodeTimestamp() if err != nil { providerLog(logger.LevelError, "unable to update node timestamp: %v", err) } else { providerLog(logger.LevelDebug, "node timestamp updated") } metric.UpdateDataProviderAvailability(err) return } err := provider.checkAvailability() if err != nil { providerLog(logger.LevelError, "check availability error: %v", err) } metric.UpdateDataProviderAvailability(err) } func checkCacheUpdates() { checkUserCache() checkIPListEntryCache() cachedUserPasswords.cleanup() cachedAdminPasswords.cleanup() cachedAPIKeys.cleanup() } func checkUserCache() { lastCheck := lastUserCacheUpdate.Load() providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastCheck)) checkTime := util.GetTimeAsMsSinceEpoch(time.Now()) if config.IsShared == 1 { lastCheck -= 5000 } users, err := provider.getRecentlyUpdatedUsers(lastCheck) if err != nil { providerLog(logger.LevelError, "unable to get recently updated users: %v", err) return } for idx := range users { user := users[idx] providerLog(logger.LevelDebug, "invalidate caches for user %q", user.Username) if user.DeletedAt > 0 { deletedAt := util.GetTimeFromMsecSinceEpoch(user.DeletedAt) if deletedAt.Add(30 * time.Minute).Before(time.Now()) { providerLog(logger.LevelDebug, "removing user %q deleted at %s", user.Username, deletedAt) go provider.deleteUser(user, false) //nolint:errcheck } webDAVUsersCache.remove(user.Username) cachedUserPasswords.Remove(user.Username) delayedQuotaUpdater.resetUserQuota(user.Username) } else { webDAVUsersCache.swap(&user, "") } } lastUserCacheUpdate.Store(checkTime) providerLog(logger.LevelDebug, "end user cache check, new update time %v", 
util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load())) } func checkIPListEntryCache() { if config.IsShared != 1 { return } hasMemoryLists := false for _, l := range inMemoryLists { if l.isInMemory.Load() { hasMemoryLists = true break } } if !hasMemoryLists { return } providerLog(logger.LevelDebug, "start IP list cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastIPListsCacheUpdate.Load())) checkTime := util.GetTimeAsMsSinceEpoch(time.Now()) entries, err := provider.getRecentlyUpdatedIPListEntries(lastIPListsCacheUpdate.Load() - 5000) if err != nil { providerLog(logger.LevelError, "unable to get recently updated IP list entries: %v", err) return } for idx := range entries { e := entries[idx] providerLog(logger.LevelDebug, "update cache for IP list entry %q", e.getName()) if e.DeletedAt > 0 { deletedAt := util.GetTimeFromMsecSinceEpoch(e.DeletedAt) if deletedAt.Add(30 * time.Minute).Before(time.Now()) { providerLog(logger.LevelDebug, "removing IP list entry %q deleted at %s", e.getName(), deletedAt) go provider.deleteIPListEntry(e, false) //nolint:errcheck } for _, l := range inMemoryLists { l.removeEntry(&e) } } else { for _, l := range inMemoryLists { l.updateEntry(&e) } } } lastIPListsCacheUpdate.Store(checkTime) providerLog(logger.LevelDebug, "end IP list entries cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastIPListsCacheUpdate.Load())) } func setLastUserUpdate() { lastUserUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) } func getLastUserUpdate() int64 { return lastUserUpdate.Load() } func setLastRuleUpdate() { lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) } func getLastRuleUpdate() int64 { return lastRuleUpdate.Load() } ================================================ FILE: internal/dataprovider/session.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero 
General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package dataprovider import ( "errors" "fmt" ) // SessionType defines the supported session types type SessionType int // Supported session types const ( SessionTypeOIDCAuth SessionType = iota + 1 SessionTypeOIDCToken SessionTypeResetCode SessionTypeOAuth2Auth SessionTypeInvalidToken SessionTypeWebTask ) // Session defines a shared session persisted in the data provider type Session struct { Key string Data any Type SessionType Timestamp int64 } func (s *Session) validate() error { if s.Key == "" { return errors.New("unable to save a session with an empty key") } if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeWebTask { return fmt.Errorf("invalid session type: %v", s.Type) } return nil } ================================================ FILE: internal/dataprovider/share.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider import ( "encoding/json" "fmt" "net" "strings" "time" "github.com/alexedwards/argon2id" passwordvalidator "github.com/wagslane/go-password-validator" "golang.org/x/crypto/bcrypt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // ShareScope defines the supported share scopes type ShareScope int // Supported share scopes const ( ShareScopeRead ShareScope = iota + 1 ShareScopeWrite ShareScopeReadWrite ) const ( redactedPassword = "[**redacted**]" ) // Share defines files and or directories shared with external users type Share struct { // Database unique identifier ID int64 `json:"-"` // Unique ID used to access this object ShareID string `json:"id"` Name string `json:"name"` Description string `json:"description,omitempty"` Scope ShareScope `json:"scope"` // Paths to files or directories, for ShareScopeWrite it must be exactly one directory Paths []string `json:"paths"` // Username who shared this object Username string `json:"username"` CreatedAt int64 `json:"created_at"` UpdatedAt int64 `json:"updated_at"` // 0 means never used LastUseAt int64 `json:"last_use_at,omitempty"` // ExpiresAt expiration date/time as unix timestamp in milliseconds, 0 means no expiration ExpiresAt int64 `json:"expires_at,omitempty"` // Optional password to protect the share Password string `json:"password"` // Limit the available access tokens, 0 means no limit MaxTokens int `json:"max_tokens,omitempty"` // Used tokens UsedTokens int `json:"used_tokens,omitempty"` // Limit the share availability to these IPs/CIDR networks AllowFrom []string `json:"allow_from,omitempty"` // set for restores, we don't have to validate the expiration date // otherwise we fail to restore existing shares and we have to insert // all the previous values with no modifications IsRestore bool `json:"-"` } // IsExpired returns true if the share is expired func (s *Share) IsExpired() bool { if s.ExpiresAt > 0 { return s.ExpiresAt < 
util.GetTimeAsMsSinceEpoch(time.Now()) } return false } // GetAllowedFromAsString returns the allowed IP as comma separated string func (s *Share) GetAllowedFromAsString() string { return strings.Join(s.AllowFrom, ",") } // IsPasswordHashed returns true if the password is hashed func (s *Share) IsPasswordHashed() bool { return util.IsStringPrefixInSlice(s.Password, hashPwdPrefixes) } func (s *Share) getACopy() Share { allowFrom := make([]string, len(s.AllowFrom)) copy(allowFrom, s.AllowFrom) return Share{ ID: s.ID, ShareID: s.ShareID, Name: s.Name, Description: s.Description, Scope: s.Scope, Paths: s.Paths, Username: s.Username, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, LastUseAt: s.LastUseAt, ExpiresAt: s.ExpiresAt, Password: s.Password, MaxTokens: s.MaxTokens, UsedTokens: s.UsedTokens, AllowFrom: allowFrom, } } // RenderAsJSON implements the renderer interface used within plugins func (s *Share) RenderAsJSON(reload bool) ([]byte, error) { if reload { share, err := provider.shareExists(s.ShareID, s.Username) if err != nil { providerLog(logger.LevelError, "unable to reload share before rendering as json: %v", err) return nil, err } share.HideConfidentialData() return json.Marshal(share) } s.HideConfidentialData() return json.Marshal(s) } // HideConfidentialData hides share confidential data func (s *Share) HideConfidentialData() { if s.Password != "" { s.Password = redactedPassword } } // HasRedactedPassword returns true if this share has a redacted password func (s *Share) HasRedactedPassword() bool { return s.Password == redactedPassword } func (s *Share) hashPassword() error { if s.Password != "" && !util.IsStringPrefixInSlice(s.Password, internalHashPwdPrefixes) { user, err := GetUserWithGroupSettings(s.Username, "") if err != nil { return util.NewGenericError(fmt.Sprintf("unable to validate user: %v", err)) } if minEntropy := user.getMinPasswordEntropy(); minEntropy > 0 { if err := passwordvalidator.Validate(s.Password, minEntropy); err != nil { return 
util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) } } if config.PasswordHashing.Algo == HashingAlgoBcrypt { hashed, err := bcrypt.GenerateFromPassword([]byte(s.Password), config.PasswordHashing.BcryptOptions.Cost) if err != nil { return err } s.Password = util.BytesToString(hashed) } else { hashed, err := argon2id.CreateHash(s.Password, argon2Params) if err != nil { return err } s.Password = hashed } } return nil } func (s *Share) validatePaths() error { var paths []string for _, p := range s.Paths { if strings.TrimSpace(p) != "" { paths = append(paths, p) } } s.Paths = paths if len(s.Paths) == 0 { return util.NewI18nError(util.NewValidationError("at least a shared path is required"), util.I18nErrorSharePathRequired) } for idx := range s.Paths { s.Paths[idx] = util.CleanPath(s.Paths[idx]) } s.Paths = util.RemoveDuplicates(s.Paths, false) if s.Scope >= ShareScopeWrite && len(s.Paths) != 1 { return util.NewI18nError(util.NewValidationError("the write share scope requires exactly one path"), util.I18nErrorShareWriteScope) } // check nested paths if len(s.Paths) > 1 { for idx := range s.Paths { for innerIdx := range s.Paths { if idx == innerIdx { continue } if s.Paths[idx] == "/" || s.Paths[innerIdx] == "/" || util.IsDirOverlapped(s.Paths[idx], s.Paths[innerIdx], true, "/") { return util.NewI18nError(util.NewGenericError("shared paths cannot be nested"), util.I18nErrorShareNestedPaths) } } } } return nil } func (s *Share) validate() error { //nolint:gocyclo if s.ShareID == "" { return util.NewValidationError("share_id is mandatory") } if s.Name == "" { return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) } if !util.IsNameValid(s.Name) { return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) } if s.Scope < ShareScopeRead || s.Scope > ShareScopeReadWrite { return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid scope: %v", s.Scope)), 
util.I18nErrorShareScope) } if err := s.validatePaths(); err != nil { return err } if s.ExpiresAt > 0 { if !s.IsRestore && s.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) { return util.NewI18nError(util.NewValidationError("expiration must be in the future"), util.I18nErrorShareExpirationPast) } } else { s.ExpiresAt = 0 } if s.MaxTokens < 0 { return util.NewI18nError(util.NewValidationError("invalid max tokens"), util.I18nErrorShareMaxTokens) } if s.Username == "" { return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) } if s.HasRedactedPassword() { return util.NewValidationError("cannot save a share with a redacted password") } if err := s.hashPassword(); err != nil { return err } s.AllowFrom = util.RemoveDuplicates(s.AllowFrom, false) for _, IPMask := range s.AllowFrom { _, _, err := net.ParseCIDR(IPMask) if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not parse allow from entry %q : %v", IPMask, err)), util.I18nErrorInvalidIPMask, ) } } return nil } // CheckCredentials verifies the share credentials if a password if set func (s *Share) CheckCredentials(password string) (bool, error) { if s.Password == "" { return true, nil } if password == "" { return false, ErrInvalidCredentials } if strings.HasPrefix(s.Password, bcryptPwdPrefix) { if err := bcrypt.CompareHashAndPassword([]byte(s.Password), []byte(password)); err != nil { return false, ErrInvalidCredentials } return true, nil } match, err := argon2id.ComparePasswordAndHash(password, s.Password) if !match || err != nil { return false, ErrInvalidCredentials } return match, err } // GetRelativePath returns the specified absolute path as relative to the share base path func (s *Share) GetRelativePath(name string) string { if len(s.Paths) == 0 { return "" } return util.CleanPath(strings.TrimPrefix(name, s.Paths[0])) } // IsUsable checks if the share is usable from the specified IP func (s *Share) IsUsable(ip string) (bool, 
error) { if s.MaxTokens > 0 && s.UsedTokens >= s.MaxTokens { return false, util.NewI18nError(util.NewRecordNotFoundError("max share usage exceeded"), util.I18nErrorShareUsage) } if s.ExpiresAt > 0 { if s.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) { return false, util.NewI18nError(util.NewRecordNotFoundError("share expired"), util.I18nErrorShareExpired) } } if len(s.AllowFrom) == 0 { return true, nil } parsedIP := net.ParseIP(ip) if parsedIP == nil { return false, util.NewI18nError(ErrLoginNotAllowedFromIP, util.I18nErrorLoginFromIPDenied) } for _, ipMask := range s.AllowFrom { _, network, err := net.ParseCIDR(ipMask) if err != nil { continue } if network.Contains(parsedIP) { return true, nil } } return false, util.NewI18nError(ErrLoginNotAllowedFromIP, util.I18nErrorLoginFromIPDenied) } ================================================ FILE: internal/dataprovider/sqlcommon.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider import ( "context" "crypto/x509" "database/sql" "encoding/json" "errors" "fmt" "net/netip" "runtime/debug" "strconv" "strings" "time" "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( sqlDatabaseVersion = 34 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) var ( errSQLFoldersAssociation = errors.New("unable to associate virtual folders to user") errSQLGroupsAssociation = errors.New("unable to associate groups to user") errSQLUsersAssociation = errors.New("unable to associate users to group") errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command") ) type sqlQuerier interface { QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } type sqlScanner interface { Scan(dest ...any) error } func sqlReplaceAll(sql string) string { sql = strings.ReplaceAll(sql, "{{schema_version}}", sqlTableSchemaVersion) sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins) sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders) sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers) sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups) sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", 
sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys) sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents) sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts) sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) sql = strings.ReplaceAll(sql, "{{shared_sessions}}", sqlTableSharedSessions) sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions) sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules) sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping) sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks) sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes) sql = strings.ReplaceAll(sql, "{{roles}}", sqlTableRoles) sql = strings.ReplaceAll(sql, "{{ip_lists}}", sqlTableIPLists) sql = strings.ReplaceAll(sql, "{{configs}}", sqlTableConfigs) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) return sql } func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() filterUser := username != "" q := getShareByIDQuery(filterUser) var row *sql.Row if filterUser { row = dbHandle.QueryRowContext(ctx, q, shareID, username) } else { row = dbHandle.QueryRowContext(ctx, q, shareID) } return getShareFromDbRow(row) } func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error { err := share.validate() if err != nil { return err } user, err := provider.userExists(share.Username, "") if err != nil { return util.NewGenericError(fmt.Sprintf("unable to validate user %q", share.Username)) } paths, err := json.Marshal(share.Paths) if err != nil { return err } var allowFrom []byte if len(share.AllowFrom) > 0 { res, err 
:= json.Marshal(share.AllowFrom) if err == nil { allowFrom = res } } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddShareQuery() usedTokens := 0 createdAt := util.GetTimeAsMsSinceEpoch(time.Now()) updatedAt := createdAt lastUseAt := int64(0) if share.IsRestore { usedTokens = share.UsedTokens if share.CreatedAt > 0 { createdAt = share.CreatedAt } if share.UpdatedAt > 0 { updatedAt = share.UpdatedAt } lastUseAt = share.LastUseAt } _, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope, paths, createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password, share.MaxTokens, usedTokens, allowFrom, user.ID) return err } func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error { err := share.validate() if err != nil { return err } paths, err := json.Marshal(share.Paths) if err != nil { return err } var allowFrom []byte if len(share.AllowFrom) > 0 { res, err := json.Marshal(share.AllowFrom) if err == nil { allowFrom = res } } user, err := provider.userExists(share.Username, "") if err != nil { return util.NewGenericError(fmt.Sprintf("unable to validate user %q", share.Username)) } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() var q string if share.IsRestore { q = getUpdateShareRestoreQuery() } else { q = getUpdateShareQuery() } var res sql.Result if share.IsRestore { if share.CreatedAt == 0 { share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) } if share.UpdatedAt == 0 { share.UpdatedAt = share.CreatedAt } res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths, share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens, share.UsedTokens, allowFrom, user.ID, share.ShareID) } else { res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths, util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, 
share.Password, share.MaxTokens, allowFrom, user.ID, share.ShareID) } if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteShareQuery() res, err := dbHandle.ExecContext(ctx, q, share.ShareID) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) { shares := make([]Share, 0, limit) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getSharesQuery(order) rows, err := dbHandle.QueryContext(ctx, q, username, limit, offset) if err != nil { return shares, err } defer rows.Close() for rows.Next() { s, err := getShareFromDbRow(rows) if err != nil { return shares, err } s.HideConfidentialData() shares = append(shares, s) } return shares, rows.Err() } func sqlCommonDumpShares(dbHandle sqlQuerier) ([]Share, error) { shares := make([]Share, 0, 30) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDumpSharesQuery() rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return shares, err } defer rows.Close() for rows.Next() { s, err := getShareFromDbRow(rows) if err != nil { return shares, err } shares = append(shares, s) } return shares, rows.Err() } func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAPIKeyByIDQuery() row := dbHandle.QueryRowContext(ctx, q, keyID) apiKey, err := getAPIKeyFromDbRow(row) if err != nil { return apiKey, err } return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle) } func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error { err := apiKey.validate() if err != nil { return err } userID, adminID, 
err := sqlCommonGetAPIKeyRelatedIDs(apiKey) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddAPIKeyQuery() _, err = dbHandle.ExecContext(ctx, q, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt, apiKey.ExpiresAt, apiKey.Description, userID, adminID) return err } func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error { err := apiKey.validate() if err != nil { return err } userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateAPIKeyQuery() res, err := dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID, apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteAPIKeyQuery() res, err := dbHandle.ExecContext(ctx, q, apiKey.KeyID) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) { apiKeys := make([]APIKey, 0, limit) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAPIKeysQuery(order) rows, err := dbHandle.QueryContext(ctx, q, limit, offset) if err != nil { return apiKeys, err } defer rows.Close() for rows.Next() { k, err := getAPIKeyFromDbRow(rows) if err != nil { return apiKeys, err } k.HideConfidentialData() apiKeys = append(apiKeys, k) } err = rows.Err() if err != nil { return apiKeys, err } apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, 
APIKeyScopeAdmin) if err != nil { return apiKeys, err } return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser) } func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) { apiKeys := make([]APIKey, 0, 30) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDumpAPIKeysQuery() rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return apiKeys, err } defer rows.Close() for rows.Next() { k, err := getAPIKeyFromDbRow(rows) if err != nil { return apiKeys, err } apiKeys = append(apiKeys, k) } err = rows.Err() if err != nil { return apiKeys, err } apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin) if err != nil { return apiKeys, err } return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser) } func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAdminByUsernameQuery() row := dbHandle.QueryRowContext(ctx, q, username) admin, err := getAdminFromDbRow(row) if err != nil { return admin, err } return getAdminWithGroups(ctx, admin, dbHandle) } func sqlCommonValidateAdminAndPass(username, password, ip string, dbHandle *sql.DB) (Admin, error) { admin, err := sqlCommonGetAdminByUsername(username, dbHandle) if err != nil { providerLog(logger.LevelWarn, "error authenticating admin %q: %v", username, err) return admin, err } err = admin.checkUserAndPass(password, ip) return admin, err } func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error { err := admin.validate() if err != nil { return err } perms, err := json.Marshal(admin.Permissions) if err != nil { return err } filters, err := json.Marshal(admin.Filters) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := 
getAddAdminQuery(admin.Role) _, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, perms, filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role) if err != nil { return err } return generateAdminGroupMapping(ctx, admin, tx) }) } func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error { err := admin.validate() if err != nil { return err } perms, err := json.Marshal(admin.Permissions) if err != nil { return err } filters, err := json.Marshal(admin.Filters) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateAdminQuery(admin.Role) _, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, perms, filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username) if err != nil { return err } return generateAdminGroupMapping(ctx, admin, tx) }) } func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteAdminQuery() res, err := dbHandle.ExecContext(ctx, q, admin.Username) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) { admins := make([]Admin, 0, limit) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAdminsQuery(order) rows, err := dbHandle.QueryContext(ctx, q, limit, offset) if err != nil { return admins, err } defer rows.Close() for rows.Next() { a, err := getAdminFromDbRow(rows) if err != nil { return admins, err } a.HideConfidentialData() admins = append(admins, a) } err = rows.Err() if err != nil { return admins, err } return 
getAdminsWithGroups(ctx, admins, dbHandle) }

// sqlCommonDumpAdmins returns every admin, with group associations, for a dump.
// Unlike sqlCommonGetAdmins it does not call HideConfidentialData.
func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
	result := make([]Admin, 0, 30)
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	rows, err := dbHandle.QueryContext(ctx, getDumpAdminsQuery())
	if err != nil {
		return result, err
	}
	defer rows.Close()

	for rows.Next() {
		admin, scanErr := getAdminFromDbRow(rows)
		if scanErr != nil {
			return result, scanErr
		}
		result = append(result, admin)
	}
	if err = rows.Err(); err != nil {
		return result, err
	}
	return getAdminsWithGroups(ctx, result, dbHandle)
}

// sqlCommonGetIPListEntry fetches a single entry; the arguments are bound in the
// order list type, then IP/network.
func sqlCommonGetIPListEntry(ipOrNet string, listType IPListType, dbHandle sqlQuerier) (IPListEntry, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return getIPListEntryFromDbRow(dbHandle.QueryRowContext(ctx, getIPListEntryQuery(), listType, ipOrNet))
}

// sqlCommonCollectIPListEntries scans every remaining row of rows into entries.
// It returns what was collected so far together with the first scan error, or
// otherwise the iteration error reported by rows.Err.
func sqlCommonCollectIPListEntries(rows *sql.Rows, entries []IPListEntry) ([]IPListEntry, error) {
	for rows.Next() {
		entry, err := getIPListEntryFromDbRow(rows)
		if err != nil {
			return entries, err
		}
		entries = append(entries, entry)
	}
	return entries, rows.Err()
}

// sqlCommonDumpIPListEntries returns all IP list entries for a dump. When the total
// count exceeds ipListMemoryLimit the lists are skipped: nil is returned with no
// error and an info message is logged.
func sqlCommonDumpIPListEntries(dbHandle *sql.DB) ([]IPListEntry, error) {
	count, err := sqlCommonCountIPListEntries(0, dbHandle)
	if err != nil {
		return nil, err
	}
	if count > ipListMemoryLimit {
		providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count)
		return nil, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	entries := make([]IPListEntry, 0, 100)
	rows, err := dbHandle.QueryContext(ctx, getDumpListEntriesQuery())
	if err != nil {
		return entries, err
	}
	defer rows.Close()

	return sqlCommonCollectIPListEntries(rows, entries)
}

// sqlCommonCountIPListEntries counts the entries of the given list type; a
// listType of 0 selects the count-all query instead.
func sqlCommonCountIPListEntries(listType IPListType, dbHandle *sql.DB) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	var (
		q    string
		args []any
	)
	switch listType {
	case 0:
		q = getCountAllIPListEntriesQuery()
	default:
		q = getCountIPListEntriesQuery()
		args = []any{listType}
	}

	var count int64
	err := dbHandle.QueryRowContext(ctx, q, args...).Scan(&count)
	return count, err
}

// sqlCommonGetIPListEntries returns a page of entries for listType. The optional
// from, filter and limit values are bound (in that order, after the list type)
// only when set; filter is suffixed with "%" (presumably a LIKE prefix match).
func sqlCommonGetIPListEntries(listType IPListType, filter, from, order string, limit int, dbHandle sqlQuerier) ([]IPListEntry, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getIPListEntriesQuery(filter, from, order, limit)
	args := []any{listType}
	if from != "" {
		args = append(args, from)
	}
	if filter != "" {
		args = append(args, filter+"%")
	}
	if limit > 0 {
		args = append(args, limit)
	}

	entries := make([]IPListEntry, 0, limit)
	rows, err := dbHandle.QueryContext(ctx, q, args...)
	if err != nil {
		return entries, err
	}
	defer rows.Close()

	return sqlCommonCollectIPListEntries(rows, entries)
}

// sqlCommonGetRecentlyUpdatedIPListEntries returns the entries selected by the
// recently-updated query for the given timestamp.
func sqlCommonGetRecentlyUpdatedIPListEntries(after int64, dbHandle sqlQuerier) ([]IPListEntry, error) {
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	q := getRecentlyUpdatedIPListQuery()
	entries := make([]IPListEntry, 0, 5)
	rows, err := dbHandle.QueryContext(ctx, q, after)
	if err != nil {
		return entries, err
	}
	defer rows.Close()

	return sqlCommonCollectIPListEntries(rows, entries)
}

// sqlCommonGetListEntriesForIP returns the entries of listType that match ip.
// PostgreSQL/CockroachDB receive the textual IP; the other drivers receive the
// address family and its raw bytes (4 bytes for IPv4 and IPv4-mapped IPv6, 16
// otherwise).
func sqlCommonGetListEntriesForIP(ip string, listType IPListType, dbHandle sqlQuerier) ([]IPListEntry, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	entries := make([]IPListEntry, 0, 2)
	var rows *sql.Rows
	switch config.Driver {
	case PGSQLDataProviderName, CockroachDataProviderName:
		var err error
		rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryPg(), listType, ip)
		if err != nil {
			return entries, err
		}
	default:
		ipAddr, err := netip.ParseAddr(ip)
		if err != nil {
			return entries, fmt.Errorf("invalid ip address %s", ip)
		}
		var (
			netType int
			ipBytes []byte
		)
		if ipAddr.Is4() || ipAddr.Is4In6() {
			v4 := ipAddr.As4()
			netType, ipBytes = ipTypeV4, v4[:]
		} else {
			v6 := ipAddr.As16()
			netType, ipBytes = ipTypeV6, v6[:]
		}
		rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryNoPg(), listType, netType, ipBytes)
		if err != nil {
			return entries, err
		}
	}
	defer rows.Close()

	return sqlCommonCollectIPListEntries(rows, entries)
}

// sqlCommonAddIPListEntry validates and inserts an IP list entry. In shared mode
// the remove-soft-deleted query for the same type/IP runs first, in the same
// transaction. PostgreSQL/CockroachDB bind the first/last addresses as strings;
// other drivers bind entry.First and entry.Last.
func sqlCommonAddIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
	if err := entry.validate(); err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getAddIPListEntryQuery()
	first := entry.getFirst()
	last := entry.getLast()
	netType := ipTypeV6
	if first.Is4() {
		netType = ipTypeV4
	}
	// insertArgs is evaluated right before each insert so the timestamps are taken
	// at execution time.
	insertArgs := func() []any {
		if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
			return []any{entry.Type, entry.IPOrNet, first.String(), last.String(), netType, entry.Protocols,
				entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
				util.GetTimeAsMsSinceEpoch(time.Now())}
		}
		return []any{entry.Type, entry.IPOrNet, entry.First, entry.Last, netType, entry.Protocols,
			entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
			util.GetTimeAsMsSinceEpoch(time.Now())}
	}

	if config.IsShared == 1 {
		return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
			if _, err := tx.ExecContext(ctx, getRemoveSoftDeletedIPListEntryQuery(), entry.Type, entry.IPOrNet); err != nil {
				return err
			}
			_, err := tx.ExecContext(ctx, q, insertArgs()...)
			return err
		})
	}
	_, err := dbHandle.ExecContext(ctx, q, insertArgs()...)
	return err
}

// sqlCommonUpdateIPListEntry validates the entry and updates its mode, protocols
// and description, identified by type and IP/network.
func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
	if err := entry.validate(); err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	res, err := dbHandle.ExecContext(ctx, getUpdateIPListEntryQuery(), entry.Mode, entry.Protocols, entry.Description,
		util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteIPListEntryQuery(softDelete) var args []any if softDelete { ts := util.GetTimeAsMsSinceEpoch(time.Now()) args = append(args, ts, ts) } args = append(args, entry.Type, entry.IPOrNet) res, err := dbHandle.ExecContext(ctx, q, args...)
if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetRoleByName(name string, dbHandle sqlQuerier) (Role, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getRoleByNameQuery() row := dbHandle.QueryRowContext(ctx, q, name) role, err := getRoleFromDbRow(row) if err != nil { return role, err } role, err = getRoleWithUsers(ctx, role, dbHandle) if err != nil { return role, err } return getRoleWithAdmins(ctx, role, dbHandle) } func sqlCommonDumpRoles(dbHandle sqlQuerier) ([]Role, error) { roles := make([]Role, 0, 10) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() q := getDumpRolesQuery() rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return roles, err } defer rows.Close() for rows.Next() { role, err := getRoleFromDbRow(rows) if err != nil { return roles, err } roles = append(roles, role) } return roles, rows.Err() } func sqlCommonGetRoles(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Role, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getRolesQuery(order, minimal) roles := make([]Role, 0, limit) rows, err := dbHandle.QueryContext(ctx, q, limit, offset) if err != nil { return roles, err } defer rows.Close() for rows.Next() { var role Role if minimal { err = rows.Scan(&role.ID, &role.Name) } else { role, err = getRoleFromDbRow(rows) } if err != nil { return roles, err } roles = append(roles, role) } err = rows.Err() if err != nil { return roles, err } if minimal { return roles, nil } roles, err = getRolesWithUsers(ctx, roles, dbHandle) if err != nil { return roles, err } return getRolesWithAdmins(ctx, roles, dbHandle) } func sqlCommonAddRole(role *Role, dbHandle *sql.DB) error { if err := role.validate(); err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := 
getAddRoleQuery() _, err := dbHandle.ExecContext(ctx, q, role.Name, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now())) return err } func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error { if err := role.validate(); err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateRoleQuery() res, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteRoleQuery() res, err := dbHandle.ExecContext(ctx, q, role.Name) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getGroupByNameQuery() row := dbHandle.QueryRowContext(ctx, q, name) group, err := getGroupFromDbRow(row) if err != nil { return group, err } group, err = getGroupWithVirtualFolders(ctx, group, dbHandle) if err != nil { return group, err } group, err = getGroupWithUsers(ctx, group, dbHandle) if err != nil { return group, err } return getGroupWithAdmins(ctx, group, dbHandle) } func sqlCommonDumpGroups(dbHandle sqlQuerier) ([]Group, error) { groups := make([]Group, 0, 50) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() q := getDumpGroupsQuery() rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return groups, err } defer rows.Close() for rows.Next() { group, err := getGroupFromDbRow(rows) if err != nil { return groups, err } groups = append(groups, group) } err = rows.Err() if err != nil { return groups, err } return getGroupsWithVirtualFolders(ctx, 
groups, dbHandle) }

// sqlCommonGetUsersInGroups returns the usernames produced by the users-in-groups
// query for the given group names. Names are processed in batches of at most
// len(sqlPlaceholders); all batches share one query context.
func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, error) {
	if len(names) == 0 {
		return nil, nil
	}
	maxNames := len(sqlPlaceholders)
	usernames := make([]string, 0, len(names))
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	for len(names) > 0 {
		if maxNames > len(names) {
			maxNames = len(names)
		}
		var err error
		// Each batch closes its own rows before the next one starts; a deferred
		// Close inside this loop would keep every result set open until return.
		usernames, err = sqlCommonGetUsersInGroupsBatch(ctx, names[:maxNames], usernames, dbHandle)
		if err != nil {
			return usernames, err
		}
		names = names[maxNames:]
	}
	return usernames, nil
}

// sqlCommonGetUsersInGroupsBatch runs the users-in-groups query for one batch of
// names and appends the scanned usernames to dst. It returns nil if the query
// itself fails, and the partially filled dst on scan/iteration errors.
func sqlCommonGetUsersInGroupsBatch(ctx context.Context, names, dst []string, dbHandle sqlQuerier) ([]string, error) {
	q := getUsersInGroupsQuery(len(names))
	args := make([]any, 0, len(names))
	for _, name := range names {
		args = append(args, name)
	}
	rows, err := dbHandle.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var username string
		if err := rows.Scan(&username); err != nil {
			return dst, err
		}
		dst = append(dst, username)
	}
	return dst, rows.Err()
}

// sqlCommonGetGroupsWithNames loads the groups with the given names, in batches of
// at most len(sqlPlaceholders), and then attaches their virtual folders. Each batch
// uses its own query timeout.
func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
	if len(names) == 0 {
		return nil, nil
	}
	maxNames := len(sqlPlaceholders)
	groups := make([]Group, 0, len(names))
	for len(names) > 0 {
		if maxNames > len(names) {
			maxNames = len(names)
		}
		var err error
		// The per-batch context and rows are released inside the helper; deferring
		// them in this loop would keep them alive until the function returns.
		groups, err = sqlCommonGetGroupsWithNamesBatch(names[:maxNames], groups, dbHandle)
		if err != nil {
			return groups, err
		}
		names = names[maxNames:]
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
}

// sqlCommonGetGroupsWithNamesBatch queries one batch of group names with a fresh
// timeout and appends the scanned groups to dst. dst, including anything appended
// so far, is returned on every error.
func sqlCommonGetGroupsWithNamesBatch(names []string, dst []Group, dbHandle sqlQuerier) ([]Group, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getGroupsWithNamesQuery(len(names))
	args := make([]any, 0, len(names))
	for _, name := range names {
		args = append(args, name)
	}
	rows, err := dbHandle.QueryContext(ctx, q, args...)
	if err != nil {
		return dst, err
	}
	defer rows.Close()

	for rows.Next() {
		group, err := getGroupFromDbRow(rows)
		if err != nil {
			return dst, err
		}
		dst = append(dst, group)
	}
	return dst, rows.Err()
}

// sqlCommonGetGroups returns a page of groups. With minimal set only the ID and the
// name are scanned; otherwise virtual folders, users and admins are attached and
// each group is prepared for rendering.
func sqlCommonGetGroups(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Group, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getGroupsQuery(order, minimal)
	groups := make([]Group, 0, limit)
	rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
	if err != nil {
		return groups, err
	}
	defer rows.Close()

	for rows.Next() {
		var group Group
		if minimal {
			err = rows.Scan(&group.ID, &group.Name)
		} else {
			group, err = getGroupFromDbRow(rows)
		}
		if err != nil {
			return groups, err
		}
		groups = append(groups, group)
	}
	if err = rows.Err(); err != nil {
		return groups, err
	}
	if minimal {
		return groups, nil
	}
	if groups, err = getGroupsWithVirtualFolders(ctx, groups, dbHandle); err != nil {
		return groups, err
	}
	if groups, err = getGroupsWithUsers(ctx, groups, dbHandle); err != nil {
		return groups, err
	}
	if groups, err = getGroupsWithAdmins(ctx, groups, dbHandle); err != nil {
		return groups, err
	}
	for idx := range groups {
		groups[idx].PrepareForRendering()
	}
	return groups, nil
}

func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error { if err := group.validate(); err != nil { return err } settings, err := json.Marshal(group.UserSettings) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getAddGroupQuery() _, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()), settings) if err != nil { return err } return generateGroupVirtualFoldersMapping(ctx, group, tx) }) }

// sqlCommonUpdateGroup validates the group, updates its description and user
// settings by name and regenerates its virtual folder mappings in one transaction.
func sqlCommonUpdateGroup(group *Group, dbHandle *sql.DB) error {
	if err := group.validate(); err != nil {
		return err
	}
	settings, err := json.Marshal(group.UserSettings)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		if _, err := tx.ExecContext(ctx, getUpdateGroupQuery(), group.Description, settings,
			util.GetTimeAsMsSinceEpoch(time.Now()), group.Name); err != nil {
			return err
		}
		return generateGroupVirtualFoldersMapping(ctx, group, tx)
	})
}

// sqlCommonDeleteGroup deletes the group with the given name.
func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	res, err := dbHandle.ExecContext(ctx, getDeleteGroupQuery(), group.Name)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonGetUserByUsername fetches a user by username, restricted to role when
// role is not empty, and completes it with virtual folders and groups.
func sqlCommonGetUserByUsername(username, role string, dbHandle sqlQuerier) (User, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUserByUsernameQuery(role)
	args := []any{username}
	if role != "" {
		args = append(args, role)
	}
	user, err := getUserFromDbRow(dbHandle.QueryRowContext(ctx, q, args...))
	if err != nil {
		return user, err
	}
	if user, err = getUserWithVirtualFolders(ctx, user, dbHandle); err != nil {
		return user, err
	}
	return getUserWithGroups(ctx, user, dbHandle)
}

// sqlCommonValidateUserAndPass loads the user (any role) and delegates the
// password/IP/protocol check to checkUserAndPass.
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
	user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
	if err != nil {
		providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err)
		return user, err
	}
	return checkUserAndPass(&user, password, ip, protocol)
}

// sqlCommonValidateUserAndTLSCertificate rejects a nil certificate up front, then
// loads the user and delegates to checkUserAndTLSCertificate.
func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *x509.Certificate, dbHandle *sql.DB) (User, error) {
	if tlsCert == nil {
		return User{}, errors.New("TLS certificate cannot be null or empty")
	}
	user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
	if err != nil {
		providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err)
		return user, err
	}
	return checkUserAndTLSCertificate(&user, protocol, tlsCert)
}

// sqlCommonValidateUserAndPubKey rejects an empty key up front, then loads the
// user and delegates to checkUserAndPubKey.
func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bool, dbHandle *sql.DB) (User, string, error) {
	if len(pubKey) == 0 {
		return User{}, "", errors.New("credentials cannot be null or empty")
	}
	user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
	if err != nil {
		providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err)
		return user, "", err
	}
	return checkUserAndPubKey(&user, pubKey, isSSHCert)
}

// sqlCommonCheckAvailability pings the database. A panic during the ping is
// recovered, logged with its stack trace and reported as a generic error.
func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
	defer func() {
		if r := recover(); r != nil {
			providerLog(logger.LevelError, "panic in check provider availability, stack trace: %s", string(debug.Stack()))
			err = errors.New("unable to check provider status")
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return dbHandle.PingContext(ctx)
}

// sqlCommonUpdateTransferQuota applies the upload/download increments (or a reset,
// depending on the query selected by reset) to the user's transfer quota.
func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateTransferQuotaQuery(reset)
	if _, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username); err != nil {
		providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
		return err
	}
	providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %d dl increment: %d is reset? %t",
		username, uploadSize, downloadSize, reset)
	return nil
}

func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateQuotaQuery(reset) _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %d size increment: %d is reset?
%t", username, filesAdd, sizeAdd, reset) } else { providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err) } return err } func sqlCommonGetAdminSignature(username string, dbHandle *sql.DB) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAdminSignatureQuery() var updatedAt int64 err := dbHandle.QueryRowContext(ctx, q, username).Scan(&updatedAt) if err != nil { return "", err } return strconv.FormatInt(updatedAt, 10), nil } func sqlCommonGetUserSignature(username string, dbHandle *sql.DB) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUserSignatureQuery() var updatedAt int64 err := dbHandle.QueryRowContext(ctx, q, username).Scan(&updatedAt) if err != nil { return "", err } return strconv.FormatInt(updatedAt, 10), nil } func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getQuotaQuery() var usedFiles int var usedSize, usedUploadSize, usedDownloadSize int64 err := dbHandle.QueryRowContext(ctx, q, username).Scan(&usedSize, &usedFiles, &usedUploadSize, &usedDownloadSize) if err != nil { providerLog(logger.LevelError, "error getting quota for user: %v, error: %v", username, err) return 0, 0, 0, 0, err } return usedFiles, usedSize, usedUploadSize, usedDownloadSize, err } func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateShareLastUseQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID) if err == nil { providerLog(logger.LevelDebug, "last use updated for shared object %q", shareID) } else { providerLog(logger.LevelWarn, "error updating last use for shared 
object %q: %v", shareID, err) } return err }

// sqlCommonUpdateAPIKeyLastUse binds the current time (ms since epoch) and keyID to
// the last-use update query; the outcome is logged and the error returned.
func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	if _, err := dbHandle.ExecContext(ctx, getUpdateAPIKeyLastUseQuery(), util.GetTimeAsMsSinceEpoch(time.Now()), keyID); err != nil {
		providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err)
		return err
	}
	providerLog(logger.LevelDebug, "last use updated for key %q", keyID)
	return nil
}

// sqlCommonUpdateAdminLastLogin records the current time as the admin's last login.
func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	if _, err := dbHandle.ExecContext(ctx, getUpdateAdminLastLoginQuery(), util.GetTimeAsMsSinceEpoch(time.Now()), username); err != nil {
		providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err)
		return err
	}
	providerLog(logger.LevelDebug, "last login updated for admin %q", username)
	return nil
}

// sqlCommonSetUpdatedAt sets updated_at for the user to the current time. It is
// best-effort: failures are only logged.
func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	if _, err := dbHandle.ExecContext(ctx, getSetUpdateAtQuery(), util.GetTimeAsMsSinceEpoch(time.Now()), username); err != nil {
		providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err)
		return
	}
	providerLog(logger.LevelDebug, "updated_at set for user %q", username)
}

// sqlCommonSetFirstDownloadTimestamp runs the first-download query with the current
// time and requires at least one affected row.
func sqlCommonSetFirstDownloadTimestamp(username string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	res, err := dbHandle.ExecContext(ctx, getSetFirstDownloadQuery(), util.GetTimeAsMsSinceEpoch(time.Now()), username)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonSetFirstUploadTimestamp runs the first-upload query with the current time
// and requires at least one affected row.
func sqlCommonSetFirstUploadTimestamp(username string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	res, err := dbHandle.ExecContext(ctx, getSetFirstUploadQuery(), util.GetTimeAsMsSinceEpoch(time.Now()), username)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonUpdateLastLogin records the current time as the user's last login.
func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	if _, err := dbHandle.ExecContext(ctx, getUpdateLastLoginQuery(), util.GetTimeAsMsSinceEpoch(time.Now()), username); err != nil {
		providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err)
		return err
	}
	providerLog(logger.LevelDebug, "last login updated for user %q", username)
	return nil
}

// sqlCommonAddUser validates the user, serializes its JSON columns and inserts it
// with folder and group mappings in one transaction. In shared mode the
// remove-soft-deleted query for the same username runs first.
func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
	if err := ValidateUser(user); err != nil {
		return err
	}
	permissions, err := user.GetPermissionsAsJSON()
	if err != nil {
		return err
	}
	publicKeys, err := user.GetPublicKeysAsJSON()
	if err != nil {
		return err
	}
	filters, err := user.GetFiltersAsJSON()
	if err != nil {
		return err
	}
	fsConfig, err := user.GetFsConfigAsJSON()
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		if config.IsShared == 1 {
			if _, err := tx.ExecContext(ctx, getRemoveSoftDeletedUserQuery(), user.Username); err != nil {
				return err
			}
		}
		_, err := tx.ExecContext(ctx, getAddUserQuery(user.Role), user.Username, user.Password, publicKeys,
			user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions,
			user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig,
			user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()),
			util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer,
			user.TotalDataTransfer, user.Role, user.LastPasswordChange)
		if err != nil {
			return err
		}
		if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
			return err
		}
		return generateUserGroupMapping(ctx, user, tx)
	})
}

// sqlCommonUpdateUserPassword stores the given password value for the user and
// requires at least one affected row.
func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	res, err := dbHandle.ExecContext(ctx, getUpdateUserPasswordQuery(), password, util.GetTimeAsMsSinceEpoch(time.Now()), username)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonUpdateUser validates the user and updates its row by username, requiring
// an affected row, then regenerates folder and group mappings in the same
// transaction.
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
	if err := ValidateUser(user); err != nil {
		return err
	}
	permissions, err := user.GetPermissionsAsJSON()
	if err != nil {
		return err
	}
	publicKeys, err := user.GetPublicKeysAsJSON()
	if err != nil {
		return err
	}
	filters, err := user.GetFiltersAsJSON()
	if err != nil {
		return err
	}
	fsConfig, err := user.GetFsConfigAsJSON()
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		res, err := tx.ExecContext(ctx, getUpdateUserQuery(user.Role), user.Password, publicKeys, user.HomeDir,
			user.UID, user.GID, user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions,
			user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig,
			user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()),
			user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role,
			user.LastPasswordChange, user.Username)
		if err != nil {
			return err
		}
		if err := sqlCommonRequireRowAffected(res); err != nil {
			return err
		}
		if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
			return err
		}
		return generateUserGroupMapping(ctx, user, tx)
	})
}

func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error { ctx, cancel :=
context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteUserQuery(softDelete) if softDelete { return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { if err := sqlCommonClearUserFolderMapping(ctx, &user, tx); err != nil { return err } if err := sqlCommonClearUserGroupMapping(ctx, &user, tx); err != nil { return err } ts := util.GetTimeAsMsSinceEpoch(time.Now()) res, err := tx.ExecContext(ctx, q, ts, ts, user.Username) if err != nil { return err } return sqlCommonRequireRowAffected(res) }) } res, err := dbHandle.ExecContext(ctx, q, user.Username) if err != nil { return err } return sqlCommonRequireRowAffected(res) }

// sqlCommonDumpUsers loads every user for a dump, including virtual folders and
// group memberships.
func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
	users := make([]User, 0, 100)
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	rows, err := dbHandle.QueryContext(ctx, getDumpUsersQuery())
	if err != nil {
		return users, err
	}
	defer rows.Close()

	for rows.Next() {
		user, scanErr := getUserFromDbRow(rows)
		if scanErr != nil {
			return users, scanErr
		}
		users = append(users, user)
	}
	if err = rows.Err(); err != nil {
		return users, err
	}
	if users, err = getUsersWithVirtualFolders(ctx, users, dbHandle); err != nil {
		return users, err
	}
	return getUsersWithGroups(ctx, users, dbHandle)
}

// sqlCommonGetRecentlyUpdatedUsers returns the users selected by the
// recently-updated query for the given timestamp, with virtual folders and groups
// attached and the settings of their groups applied.
func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) {
	users := make([]User, 0, 10)
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	rows, err := dbHandle.QueryContext(ctx, getRecentlyUpdatedUsersQuery(), after)
	if err != nil {
		return users, err
	}
	defer rows.Close()

	for rows.Next() {
		user, scanErr := getUserFromDbRow(rows)
		if scanErr != nil {
			return users, scanErr
		}
		users = append(users, user)
	}
	if err = rows.Err(); err != nil {
		return users, err
	}
	if users, err = getUsersWithVirtualFolders(ctx, users, dbHandle); err != nil {
		return users, err
	}
	if users, err = getUsersWithGroups(ctx, users, dbHandle); err != nil {
		return users, err
	}

	var groupNames []string
	for _, u := range users {
		for _, g := range u.Groups {
			groupNames = append(groupNames, g.Name)
		}
	}
	groupNames = util.RemoveDuplicates(groupNames, false)
	if len(groupNames) == 0 {
		return users, nil
	}
	groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
	if err != nil {
		return users, err
	}
	if len(groups) == 0 {
		return users, nil
	}
	groupsByName := make(map[string]Group, len(groups))
	for idx := range groups {
		groupsByName[groups[idx].Name] = groups[idx]
	}
	for idx := range users {
		ref := &users[idx]
		ref.applyGroupSettings(groupsByName)
	}
	return users, nil
}

// sqlGetMaxUsersForQuotaCheckRange returns the batch size used for quota-check
// lookups: 50, or fewer if there are not enough SQL placeholders.
func sqlGetMaxUsersForQuotaCheckRange() int {
	const defaultMax = 50
	if available := len(sqlPlaceholders); available < defaultMax {
		return available
	}
	return defaultMax
}

// sqlCommonGetUsersForQuotaCheck loads, in batches, the quota-related data of the
// users named in toFetch. Users mapped to true also get their virtual folders;
// groups and group settings are then applied to all of them.
func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
	batchSize := sqlGetMaxUsersForQuotaCheckRange()
	users := make([]User, 0, batchSize)

	pending := make([]string, 0, len(toFetch))
	for username := range toFetch {
		pending = append(pending, username)
	}
	for len(pending) > 0 {
		if batchSize > len(pending) {
			batchSize = len(pending)
		}
		batch, err := sqlCommonGetUsersRangeForQuotaCheck(pending[:batchSize], dbHandle)
		if err != nil {
			return users, err
		}
		users = append(users, batch...)
		pending = pending[batchSize:]
	}

	// Move the users that need virtual folders aside and compact the others, in
	// place, at the front of the slice.
	var usersWithFolders []User
	kept := 0
	for _, user := range users {
		if toFetch[user.Username] {
			usersWithFolders = append(usersWithFolders, user)
			continue
		}
		users[kept] = user
		kept++
	}
	users = users[:kept]

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
	if err != nil {
		return users, err
	}
	users = append(users, usersWithFolders...)
	if users, err = getUsersWithGroups(ctx, users, dbHandle); err != nil {
		return users, err
	}

	var groupNames []string
	for _, u := range users {
		for _, g := range u.Groups {
			groupNames = append(groupNames, g.Name)
		}
	}
	groupNames = util.RemoveDuplicates(groupNames, false)
	if len(groupNames) == 0 {
		return users, nil
	}
	groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
	if err != nil {
		return users, err
	}
	groupsByName := make(map[string]Group, len(groups))
	for idx := range groups {
		groupsByName[groups[idx].Name] = groups[idx]
	}
	for idx := range users {
		ref := &users[idx]
		ref.applyGroupSettings(groupsByName)
	}
	return users, nil
}

// sqlCommonGetUsersRangeForQuotaCheck loads the quota columns and the filters for
// the given usernames with a single query. Filters that fail to decode are ignored,
// leaving the user's Filters at the zero value.
func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
	users := make([]User, 0, len(usernames))
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUsersForQuotaCheckQuery(len(usernames))
	queryArgs := make([]any, len(usernames))
	for idx, username := range usernames {
		queryArgs[idx] = username
	}
	rows, err := dbHandle.QueryContext(ctx, q, queryArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var (
			user    User
			filters []byte
		)
		if err := rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer,
			&user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer,
			&user.UsedDownloadDataTransfer, &filters); err != nil {
			return users, err
		}
		var userFilters UserFilters
		if json.Unmarshal(filters, &userFilters) == nil {
			user.Filters = userFilters
		}
		users = append(users, user)
	}
	return users, rows.Err()
}

// sqlCommonAddActiveTransfer inserts an active transfer; the same timestamp is
// bound for both trailing time arguments.
func sqlCommonAddActiveTransfer(transfer ActiveTransfer, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getAddActiveTransferQuery()
	now := util.GetTimeAsMsSinceEpoch(time.Now())
	_, err := dbHandle.ExecContext(ctx, q, transfer.ID, transfer.ConnID, transfer.Type, transfer.Username,
		transfer.FolderName, transfer.IP, transfer.TruncatedSize, transfer.CurrentULSize, transfer.CurrentDLSize, now, now)
	return err
}

// sqlCommonUpdateActiveTransferSizes stores the current upload/download sizes of
// the transfer identified by connection and transfer ID.
func sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	_, err := dbHandle.ExecContext(ctx, getUpdateActiveTransferSizesQuery(), ulSize, dlSize,
		util.GetTimeAsMsSinceEpoch(time.Now()), connectionID, transferID)
	return err
}

// sqlCommonRemoveActiveTransfer removes the transfer identified by connection and
// transfer ID; no affected-row check is done.
func sqlCommonRemoveActiveTransfer(transferID int64, connectionID string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	_, err := dbHandle.ExecContext(ctx, getRemoveActiveTransferQuery(), connectionID, transferID)
	return err
}

func sqlCommonCleanupActiveTransfers(before time.Time, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getCleanupActiveTransfersQuery() _, err := dbHandle.ExecContext(ctx, q,
util.GetTimeAsMsSinceEpoch(before)) return err } func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveTransfer, error) { transfers := make([]ActiveTransfer, 0, 30) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() q := getActiveTransfersQuery() rows, err := dbHandle.QueryContext(ctx, q, util.GetTimeAsMsSinceEpoch(from)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { var transfer ActiveTransfer var folderName sql.NullString err = rows.Scan(&transfer.ID, &transfer.ConnID, &transfer.Type, &transfer.Username, &folderName, &transfer.IP, &transfer.TruncatedSize, &transfer.CurrentULSize, &transfer.CurrentDLSize, &transfer.CreatedAt, &transfer.UpdatedAt) if err != nil { return transfers, err } if folderName.Valid { transfer.FolderName = folderName.String } transfers = append(transfers, transfer) } return transfers, rows.Err() } func sqlCommonGetUsers(limit int, offset int, order, role string, dbHandle sqlQuerier) ([]User, error) { users := make([]User, 0, limit) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUsersQuery(order, role) var args []any if role == "" { args = append(args, limit, offset) } else { args = append(args, role, limit, offset) } rows, err := dbHandle.QueryContext(ctx, q, args...) 
// Remainder of sqlCommonGetUsers: decode the rows, then attach virtual
// folders and group mappings.
	if err != nil {
		return users, err
	}
	defer rows.Close()

	for rows.Next() {
		u, err := getUserFromDbRow(rows)
		if err != nil {
			return users, err
		}
		users = append(users, u)
	}
	err = rows.Err()
	if err != nil {
		return users, err
	}
	users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
	if err != nil {
		return users, err
	}
	users, err = getUsersWithGroups(ctx, users, dbHandle)
	if err != nil {
		return users, err
	}
	for idx := range users {
		users[idx].PrepareForRendering()
	}
	return users, nil
}

// sqlCommonGetDefenderHosts returns up to limit defender hosts selected using
// from.
// Hosts whose ban time is still in the future keep it as BanTime. For every
// other host the score is loaded by getDefenderHostsWithScores, which drops
// hosts that have neither a positive score nor a ban time.
func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]DefenderEntry, error) {
	hosts := make([]DefenderEntry, 0, 100)
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderHostsQuery()
	rows, err := dbHandle.QueryContext(ctx, q, from, limit)
	if err != nil {
		providerLog(logger.LevelError, "unable to get defender hosts: %v", err)
		return hosts, err
	}
	defer rows.Close()

	// IDs of the hosts not currently banned: their scores must be looked up
	var idForScores []int64

	for rows.Next() {
		var banTime sql.NullInt64
		host := DefenderEntry{}
		err = rows.Scan(&host.ID, &host.IP, &banTime)
		if err != nil {
			providerLog(logger.LevelError, "unable to scan defender host row: %v", err)
			return hosts, err
		}
		var hostBanTime time.Time
		if banTime.Valid && banTime.Int64 > 0 {
			hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64)
		}
		// a missing or already expired ban means the score decides
		if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) {
			idForScores = append(idForScores, host.ID)
		} else {
			host.BanTime = hostBanTime
		}
		hosts = append(hosts, host)
	}
	err = rows.Err()
	if err != nil {
		providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err)
		return hosts, err
	}

	return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle)
}

// sqlCommonIsDefenderHostBanned runs the "is banned" query for ip using the
// current time (ms since epoch).
// If no row matches, it returns a RecordNotFoundError. Otherwise it returns an
// entry with only the ID populated.
func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (DefenderEntry, error) {
	var host DefenderEntry
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderIsHostBannedQuery()
	row := dbHandle.QueryRowContext(ctx, q, ip,
		util.GetTimeAsMsSinceEpoch(time.Now()))
	err := row.Scan(&host.ID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return host, util.NewRecordNotFoundError("host not found")
		}
		providerLog(logger.LevelError, "unable to check ban status for host %q: %v", ip, err)
		return host, err
	}

	return host, nil
}

// sqlCommonGetDefenderHostByIP returns the defender entry for ip.
// If the host is banned until a future time, it is returned with BanTime set
// and its score is not loaded. Otherwise its score is loaded through
// getDefenderHostsWithScores using from; a host without a positive score is
// reported as a RecordNotFoundError.
func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (DefenderEntry, error) {
	var host DefenderEntry
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderHostQuery()
	row := dbHandle.QueryRowContext(ctx, q, ip, from)
	var banTime sql.NullInt64
	err := row.Scan(&host.ID, &host.IP, &banTime)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return host, util.NewRecordNotFoundError("host not found")
		}
		providerLog(logger.LevelError, "unable to get host for ip %q: %v", ip, err)
		return host, err
	}
	if banTime.Valid && banTime.Int64 > 0 {
		hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64)
		if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) {
			host.BanTime = hostBanTime
			return host, nil
		}
	}
	hosts, err := getDefenderHostsWithScores(ctx, []DefenderEntry{host}, from, []int64{host.ID}, dbHandle)
	if err != nil {
		return host, err
	}
	// the host was filtered out: no positive score and no ban time
	if len(hosts) == 0 {
		return host, util.NewRecordNotFoundError("host not found")
	}

	return hosts[0], nil
}

// sqlCommonDefenderIncrementBanTime runs the increment-ban-time query for ip,
// passing minutesToAdd converted to milliseconds.
func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderIncrementBanTimeQuery()
	_, err := dbHandle.ExecContext(ctx, q, minutesToAdd*60000, ip)
	if err == nil {
		providerLog(logger.LevelDebug, "ban time updated for ip %q, increment (minutes): %v", ip, minutesToAdd)
	} else {
		providerLog(logger.LevelError, "error updating ban time for ip %q: %v", ip, err)
	}
	return err
}

// sqlCommonSetDefenderBanTime sets the ban time of ip to banTime, expressed in
// ms since epoch (it is logged through GetTimeFromMsecSinceEpoch).
func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(),
		defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderSetBanTimeQuery()
	_, err := dbHandle.ExecContext(ctx, q, banTime, ip)
	if err == nil {
		providerLog(logger.LevelDebug, "ip %q banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime))
	} else {
		providerLog(logger.LevelError, "error setting ban time for ip %q: %v", ip, err)
	}
	return err
}

// sqlCommonDeleteDefenderHost runs the delete query for ip and checks the
// result with sqlCommonRequireRowAffected (presumably so that deleting a
// missing host is reported as an error).
func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDeleteDefenderHostQuery()
	res, err := dbHandle.ExecContext(ctx, q, ip)
	if err != nil {
		providerLog(logger.LevelError, "unable to delete defender host %q: %v", ip, err)
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonAddDefenderHostAndEvent adds the host and then an event with the
// given score for ip, both inside one sqlCommonExecuteTx transaction.
func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		// the host is added before its event; both queries receive the same ip
		if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil {
			return err
		}
		return sqlCommonAddDefenderEvent(ctx, ip, score, tx)
	})
}

// sqlCommonDefenderCleanup runs the events cleanup and then the hosts cleanup
// with the same cutoff. It stops at the first error.
func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error {
	if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil {
		return err
	}
	return sqlCommonCleanupDefenderHosts(from, dbHandler)
}

// sqlCommonAddDefenderHost runs the add-host query for ip with the current
// time inside tx, logging failures.
func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error {
	q := getAddDefenderHostQuery()
	_, err := tx.ExecContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
	if err != nil {
		providerLog(logger.LevelError, "unable to add defender host %q: %v", ip, err)
	}
	return err
}

// sqlCommonAddDefenderEvent records an event with the given score for ip,
// timestamped with the current time, inside tx.
func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error {
	q := getAddDefenderEventQuery()
	_, err := tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip)
	if err != nil {
		providerLog(logger.LevelError, "unable to add defender event for %q: %v", ip, err)
	}
	return err
}

// sqlCommonCleanupDefenderHosts runs the hosts cleanup query with the current
// time and the from cutoff.
func sqlCommonCleanupDefenderHosts(from
int64, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderHostsCleanupQuery()
	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), from)
	if err != nil {
		providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err)
	}
	return err
}

// sqlCommonCleanupDefenderEvents runs the events cleanup query with the from cutoff.
func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDefenderEventsCleanupQuery()
	_, err := dbHandle.ExecContext(ctx, q, from)
	if err != nil {
		providerLog(logger.LevelError, "unable to cleanup defender events: %v", err)
	}
	return err
}

// getShareFromDbRow decodes a share row. sql.ErrNoRows is converted to a
// RecordNotFoundError. The paths column must be valid JSON; an allow-from
// list that cannot be decoded is ignored.
func getShareFromDbRow(row sqlScanner) (Share, error) {
	var share Share
	var description, password sql.NullString
	var allowFrom, paths []byte

	err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope, &paths, &share.Username,
		&share.CreatedAt, &share.UpdatedAt, &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens,
		&share.UsedTokens, &allowFrom)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return share, util.NewRecordNotFoundError(err.Error())
		}
		return share, err
	}

	var list []string
	err = json.Unmarshal(paths, &list)
	if err != nil {
		return share, err
	}
	share.Paths = list
	if description.Valid {
		share.Description = description.String
	}
	if password.Valid {
		share.Password = password.String
	}
	// Reset before reusing the variable. json.Unmarshal truncates a non-nil
	// slice and appends into it, which would overwrite the backing array now
	// shared with share.Paths.
	list = nil
	err = json.Unmarshal(allowFrom, &list)
	if err == nil {
		share.AllowFrom = list
	}

	return share, nil
}

// getAPIKeyFromDbRow decodes an API key row. sql.ErrNoRows is converted to a
// RecordNotFoundError. The nullable user ID, admin ID and description are
// copied only when not NULL.
func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
	var apiKey APIKey
	var userID, adminID sql.NullInt64
	var description sql.NullString

	err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
		&apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return apiKey, util.NewRecordNotFoundError(err.Error())
		}
		return apiKey, err
	}

	if userID.Valid {
		apiKey.userID = userID.Int64
	}
	if adminID.Valid {
		apiKey.adminID = adminID.Int64
	}
	if description.Valid {
		apiKey.Description = description.String
	}

	return apiKey, nil
}

// getAdminFromDbRow decodes an admin row. sql.ErrNoRows is converted to a
// RecordNotFoundError. Permissions must be valid JSON; filters that cannot be
// decoded are ignored. Nullable text columns are copied only when not NULL,
// and SetEmptySecretsIfNil is called before returning.
func getAdminFromDbRow(row sqlScanner) (Admin, error) {
	var admin Admin
	var email, additionalInfo, description, role sql.NullString
	var permissions, filters []byte

	err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions,
		&filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin, &role)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return admin, util.NewRecordNotFoundError(err.Error())
		}
		return admin, err
	}

	var perms []string
	err = json.Unmarshal(permissions, &perms)
	if err != nil {
		return admin, err
	}
	admin.Permissions = perms
	if email.Valid {
		admin.Email = email.String
	}
	var adminFilters AdminFilters
	err = json.Unmarshal(filters, &adminFilters)
	if err == nil {
		admin.Filters = adminFilters
	}
	if additionalInfo.Valid {
		admin.AdditionalInfo = additionalInfo.String
	}
	if description.Valid {
		admin.Description = description.String
	}
	if role.Valid {
		admin.Role = role.String
	}

	admin.SetEmptySecretsIfNil()

	return admin, nil
}

// getEventActionFromDbRow decodes an event action row. sql.ErrNoRows is
// converted to a RecordNotFoundError; options that cannot be decoded are ignored.
func getEventActionFromDbRow(row sqlScanner) (BaseEventAction, error) {
	var action BaseEventAction
	var description sql.NullString
	var options []byte

	err := row.Scan(&action.ID, &action.Name, &description, &action.Type, &options)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return action, util.NewRecordNotFoundError(err.Error())
		}
		return action, err
	}

	if description.Valid {
		action.Description = description.String
	}
	var actionOptions BaseEventActionOptions
	err = json.Unmarshal(options, &actionOptions)
	if err == nil {
		action.Options = actionOptions
	}

	return action, nil
}

// getEventRuleFromDbRow decodes an event rule row. sql.ErrNoRows is converted
// to a RecordNotFoundError; conditions that cannot be decoded are ignored.
func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
	var rule EventRule
	var description sql.NullString
	var conditions []byte

	err := row.Scan(&rule.ID, &rule.Name, &description, &rule.CreatedAt,
		&rule.UpdatedAt, &rule.Trigger, &conditions, &rule.DeletedAt, &rule.Status)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return rule, util.NewRecordNotFoundError(err.Error())
		}
		return rule, err
	}

	var ruleConditions EventConditions
	err = json.Unmarshal(conditions, &ruleConditions)
	if err == nil {
		rule.Conditions = ruleConditions
	}
	if description.Valid {
		rule.Description = description.String
	}

	return rule, nil
}

// getIPListEntryFromDbRow decodes an IP list entry row. sql.ErrNoRows is
// converted to a RecordNotFoundError.
func getIPListEntryFromDbRow(row sqlScanner) (IPListEntry, error) {
	var entry IPListEntry
	var description sql.NullString

	err := row.Scan(&entry.Type, &entry.IPOrNet, &entry.Mode, &entry.Protocols, &description,
		&entry.CreatedAt, &entry.UpdatedAt, &entry.DeletedAt)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return entry, util.NewRecordNotFoundError(err.Error())
		}
		return entry, err
	}

	if description.Valid {
		entry.Description = description.String
	}

	// err is always nil here because the scan succeeded
	return entry, err
}

// getRoleFromDbRow decodes a role row. sql.ErrNoRows is converted to a
// RecordNotFoundError.
func getRoleFromDbRow(row sqlScanner) (Role, error) {
	var role Role
	var description sql.NullString

	err := row.Scan(&role.ID, &role.Name, &description, &role.CreatedAt, &role.UpdatedAt)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return role, util.NewRecordNotFoundError(err.Error())
		}
		return role, err
	}

	if description.Valid {
		role.Description = description.String
	}

	return role, nil
}

// getGroupFromDbRow decodes a group row. sql.ErrNoRows is converted to a
// RecordNotFoundError; user settings that cannot be decoded are ignored.
func getGroupFromDbRow(row sqlScanner) (Group, error) {
	var group Group
	var description sql.NullString
	var userSettings []byte

	err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return group, util.NewRecordNotFoundError(err.Error())
		}
		return group, err
	}

	if description.Valid {
		group.Description = description.String
	}

	var settings GroupUserSettings
	err = json.Unmarshal(userSettings, &settings)
	if err == nil {
		group.UserSettings = settings
	}

	return group, nil
}

// getUserFromDbRow decodes a user row. sql.ErrNoRows is converted to a
// RecordNotFoundError. Permissions must be valid JSON. Public keys, filters
// and the filesystem config are set only when they decode, and nullable text
// columns are copied only when not NULL.
func getUserFromDbRow(row sqlScanner) (User, error) {
	var user User
	var password sql.NullString
	var permissions, publicKey, filters, fsConfig []byte
	var additionalInfo, description, email, role sql.NullString

	err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID,
		&user.MaxSessions, &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize,
		&user.UsedQuotaFiles, &user.LastQuotaUpdate, &user.UploadBandwidth, &user.DownloadBandwidth,
		&user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig, &additionalInfo,
		&description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer,
		&user.DownloadDataTransfer, &user.TotalDataTransfer, &user.UsedUploadDataTransfer,
		&user.UsedDownloadDataTransfer, &user.DeletedAt, &user.FirstDownload, &user.FirstUpload, &role,
		&user.LastPasswordChange)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return user, util.NewRecordNotFoundError(err.Error())
		}
		return user, err
	}

	if password.Valid {
		user.Password = password.String
	}

	// permissions are mandatory: a decode failure is logged and returned
	perms := make(map[string][]string)
	err = json.Unmarshal(permissions, &perms)
	if err != nil {
		providerLog(logger.LevelError, "unable to deserialize permissions for user %q: %v", user.Username, err)
		return user, fmt.Errorf("unable to deserialize permissions for user %q: %v", user.Username, err)
	}
	user.Permissions = perms

	// Optional columns may contain an empty string or invalid JSON, so they are
	// decoded in a relaxed way: for example, public keys are populated only if
	// unmarshal does not return an error.
	var pKeys []string
	err = json.Unmarshal(publicKey, &pKeys)
	if err == nil {
		user.PublicKeys = pKeys
	}
	var userFilters UserFilters
	err = json.Unmarshal(filters, &userFilters)
	if err == nil {
		user.Filters = userFilters
	}
	var fs vfs.Filesystem
	err = json.Unmarshal(fsConfig, &fs)
	if err == nil {
		user.FsConfig = fs
	}

	if additionalInfo.Valid {
		user.AdditionalInfo = additionalInfo.String
	}
	if description.Valid {
		user.Description = description.String
	}
	if email.Valid {
		user.Email = email.String
	}
	if role.Valid {
		user.Role = role.String
	}
// Remainder of getUserFromDbRow, which begins on the previous line.
	user.SetEmptySecretsIfNil()
	return user, nil
}

// sqlCommonGetFolder loads the folder with the given name, without attaching
// users or groups. sql.ErrNoRows is converted to a RecordNotFoundError.
// A stored filesystem config that cannot be decoded leaves FsConfig at its zero
// value, exactly as in sqlCommonDumpFolders and sqlCommonGetFolders.
func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
	var folder vfs.BaseVirtualFolder
	q := getFolderByNameQuery()
	row := dbHandle.QueryRowContext(ctx, q, name)
	var mappedPath, description sql.NullString
	var fsConfig []byte
	err := row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
		&folder.Name, &description, &fsConfig)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return folder, util.NewRecordNotFoundError(err.Error())
		}
		return folder, err
	}
	if mappedPath.Valid {
		folder.MappedPath = mappedPath.String
	}
	if description.Valid {
		folder.Description = description.String
	}
	var fs vfs.Filesystem
	// The decode error is deliberately not returned. Every other folder reader
	// treats this decode as best-effort; returning it here made a folder that
	// lists fine impossible to fetch by name (json.Unmarshal fails even on an
	// empty value).
	if err = json.Unmarshal(fsConfig, &fs); err == nil {
		folder.FsConfig = fs
	}
	return folder, nil
}

// sqlCommonGetFolderByName loads the folder with the given name, then attaches
// its users (getVirtualFoldersWithUsers) and groups (getVirtualFoldersWithGroups).
func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
	folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
	if err != nil {
		return folder, err
	}
	folders, err := getVirtualFoldersWithUsers([]vfs.BaseVirtualFolder{folder}, dbHandle)
	if err != nil {
		return folder, err
	}
	if len(folders) != 1 {
		return folder, fmt.Errorf("unable to associate users with folder %q", name)
	}
	folders, err = getVirtualFoldersWithGroups([]vfs.BaseVirtualFolder{folders[0]}, dbHandle)
	if err != nil {
		return folder, err
	}
	if len(folders) != 1 {
		return folder, fmt.Errorf("unable to associate groups with folder %q", name)
	}
	return folders[0], nil
}

// sqlCommonAddFolder validates folder and inserts it, storing FsConfig as JSON.
func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
	err := ValidateFolder(folder)
	if err != nil {
		return err
	}
	fsConfig, err := json.Marshal(folder.FsConfig)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getAddFolderQuery()
	_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
		folder.LastQuotaUpdate, folder.Name, folder.Description, fsConfig)
	return err
}

// sqlCommonUpdateFolder validates folder and updates the mapped path,
// description and filesystem config (as JSON) of the folder with its name.
// sqlCommonRequireRowAffected checks the update result.
func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
	err := ValidateFolder(folder)
	if err != nil {
		return err
	}
	fsConfig, err := json.Marshal(folder.FsConfig)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateFolderQuery()
	res, err := dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, fsConfig, folder.Name)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonDeleteFolder deletes the folder by name and checks the result with
// sqlCommonRequireRowAffected.
func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDeleteFolderQuery()
	res, err := dbHandle.ExecContext(ctx, q, folder.Name)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonDumpFolders returns all the folders selected by the dump query,
// using the longer SQL query timeout. A filesystem config that cannot be
// decoded is ignored.
func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
	folders := make([]vfs.BaseVirtualFolder, 0, 50)
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	q := getDumpFoldersQuery()
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return folders, err
	}
	defer rows.Close()

	for rows.Next() {
		var folder vfs.BaseVirtualFolder
		var mappedPath, description sql.NullString
		var fsConfig []byte
		err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
			&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
		if err != nil {
			return folders, err
		}
		if mappedPath.Valid {
			folder.MappedPath = mappedPath.String
		}
		if description.Valid {
			folder.Description = description.String
		}
		var fs vfs.Filesystem
		err = json.Unmarshal(fsConfig, &fs)
		if err == nil {
			folder.FsConfig = fs
		}
		folders = append(folders, folder)
	}

	return folders, rows.Err()
}

// sqlCommonGetFolders returns a page of folders (limit/offset, sorted by order),
// each prepared for rendering.
// In minimal mode only ID and name are scanned and no users or groups are
// attached. Otherwise the full row is decoded and related users and groups
// are loaded.
func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
	folders := make([]vfs.BaseVirtualFolder, 0, limit)
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getFoldersQuery(order, minimal)
	rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
	if err != nil {
		return folders, err
	}
	defer rows.Close()

	for rows.Next() {
		var folder vfs.BaseVirtualFolder
		if minimal {
			err = rows.Scan(&folder.ID, &folder.Name)
			if err != nil {
				return folders, err
			}
		} else {
			var mappedPath, description sql.NullString
			var fsConfig []byte
			err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
				&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
			if err != nil {
				return folders, err
			}
			if mappedPath.Valid {
				folder.MappedPath = mappedPath.String
			}
			if description.Valid {
				folder.Description = description.String
			}
			var fs vfs.Filesystem
			err = json.Unmarshal(fsConfig, &fs)
			if err == nil {
				folder.FsConfig = fs
			}
		}
		folder.PrepareForRendering()
		folders = append(folders, folder)
	}

	err = rows.Err()
	if err != nil {
		return folders, err
	}
	if minimal {
		return folders, nil
	}
	folders, err = getVirtualFoldersWithUsers(folders, dbHandle)
	if err != nil {
		return folders, err
	}
	return getVirtualFoldersWithGroups(folders, dbHandle)
}

// sqlCommonClearUserFolderMapping runs the clear query for user's folder mappings.
func sqlCommonClearUserFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	q := getClearUserFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, user.Username)
	return err
}

// sqlCommonClearGroupFolderMapping runs the clear query for group's folder mappings.
func sqlCommonClearGroupFolderMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
	q := getClearGroupFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, group.Name)
	return err
}

// sqlCommonClearUserGroupMapping runs the clear query for user's group mappings.
func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	q := getClearUserGroupMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, user.Username)
	return err
}

// sqlCommonAddUserFolderMapping inserts the mapping between user and folder,
// including virtual path, quota limits and sort order.
func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, sortOrder int, dbHandle sqlQuerier) error {
// Body of sqlCommonAddUserFolderMapping, whose signature is on the previous line.
	q := getAddUserFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name,
		user.Username, sortOrder)
	return err
}

// sqlCommonClearAdminGroupMapping runs the clear query for admin's group mappings.
func sqlCommonClearAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
	q := getClearAdminGroupMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, admin.Username)
	return err
}

// sqlCommonAddGroupFolderMapping inserts the mapping between group and folder,
// including virtual path, quota limits and sort order.
func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, sortOrder int,
	dbHandle sqlQuerier,
) error {
	q := getAddGroupFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name,
		group.Name, sortOrder)
	return err
}

// sqlCommonAddUserGroupMapping inserts a user/group mapping with the given
// group type and sort order.
func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType, sortOrder int, dbHandle sqlQuerier) error {
	q := getAddUserGroupMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType, sortOrder)
	return err
}

// sqlCommonAddAdminGroupMapping inserts an admin/group mapping; the mapping
// options are stored as JSON.
func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName string, mappingOptions AdminGroupMappingOptions,
	sortOrder int, dbHandle sqlQuerier,
) error {
	options, err := json.Marshal(mappingOptions)
	if err != nil {
		return err
	}
	q := getAddAdminGroupMappingQuery()
	_, err = dbHandle.ExecContext(ctx, q, username, groupName, options, sortOrder)
	return err
}

// generateGroupVirtualFoldersMapping replaces the group's folder mappings.
// It clears the existing ones and re-adds group.VirtualFolders, using each
// slice index as the sort order.
func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
	err := sqlCommonClearGroupFolderMapping(ctx, group, dbHandle)
	if err != nil {
		return err
	}
	for idx := range group.VirtualFolders {
		vfolder := &group.VirtualFolders[idx]
		err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// generateUserVirtualFoldersMapping replaces the user's folder mappings.
// It clears the existing ones and re-adds user.VirtualFolders, using each
// slice index as the sort order.
func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	err := sqlCommonClearUserFolderMapping(ctx, user, dbHandle)
	if err != nil {
		return err
	}
	for idx := range user.VirtualFolders {
		vfolder :=
			&user.VirtualFolders[idx]
		err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// generateUserGroupMapping replaces the user's group mappings. It clears the
// existing ones and re-adds user.Groups, using each slice index as the sort
// order.
func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	err := sqlCommonClearUserGroupMapping(ctx, user, dbHandle)
	if err != nil {
		return err
	}
	for idx, group := range user.Groups {
		err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// generateAdminGroupMapping replaces the admin's group mappings. It clears the
// existing ones and re-adds admin.Groups with their options, using each slice
// index as the sort order.
func generateAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
	err := sqlCommonClearAdminGroupMapping(ctx, admin, dbHandle)
	if err != nil {
		return err
	}
	for idx, group := range admin.Groups {
		err = sqlCommonAddAdminGroupMapping(ctx, admin.Username, group.Name, group.Options, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// getDefenderHostsWithScores loads the scores of the hosts listed in
// idForScores from the events query, using from as its argument.
// Every entry in hosts gets its Score set (0 when absent), and only hosts with
// a positive score or a non-zero BanTime are returned. If idForScores is
// empty, hosts is returned unchanged.
// NOTE: hosts is modified in place (Score field).
func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from int64, idForScores []int64,
	dbHandle sqlQuerier) (
	[]DefenderEntry, error,
) {
	if len(idForScores) == 0 {
		return hosts, nil
	}

	hostsWithScores := make(map[int64]int)
	q := getDefenderEventsQuery(idForScores)
	rows, err := dbHandle.QueryContext(ctx, q, from)
	if err != nil {
		providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var hostID int64
		var score int
		err = rows.Scan(&hostID, &score)
		if err != nil {
			providerLog(logger.LevelError, "error scanning host score row: %v", err)
			return hosts, err
		}
		if score > 0 {
			hostsWithScores[hostID] = score
		}
	}

	err = rows.Err()
	if err != nil {
		return hosts, err
	}

	result := make([]DefenderEntry, 0, len(hosts))
	for idx := range hosts {
		hosts[idx].Score = hostsWithScores[hosts[idx].ID]
		if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
			result = append(result, hosts[idx])
		}
	}

	return result, nil
}

// getAdminWithGroups attaches group mappings to a single admin.
func getAdminWithGroups(ctx context.Context, admin Admin, dbHandle sqlQuerier) (Admin, error) {
	admins, err :=
		getAdminsWithGroups(ctx, []Admin{admin}, dbHandle)
	if err != nil {
		return admin, err
	}
	if len(admins) == 0 {
		return admin, errSQLGroupsAssociation
	}
	return admins[0], err
}

// getAdminsWithGroups loads the group mappings (name and JSON-decoded options)
// for admins and assigns them by admin ID. Unlike the relaxed decoding used
// elsewhere, an options value that cannot be decoded is returned as an error.
// On a query error it returns nil.
func getAdminsWithGroups(ctx context.Context, admins []Admin, dbHandle sqlQuerier) ([]Admin, error) {
	if len(admins) == 0 {
		return admins, nil
	}
	adminsGroups := make(map[int64][]AdminGroupMapping)
	q := getRelatedGroupsForAdminsQuery(admins)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var group AdminGroupMapping
		var adminID int64
		var options []byte
		err = rows.Scan(&group.Name, &options, &adminID)
		if err != nil {
			return admins, err
		}
		err = json.Unmarshal(options, &group.Options)
		if err != nil {
			return admins, err
		}
		adminsGroups[adminID] = append(adminsGroups[adminID], group)
	}
	err = rows.Err()
	if err != nil {
		return admins, err
	}
	// no mappings found: leave the admins untouched
	if len(adminsGroups) == 0 {
		return admins, err
	}
	for idx := range admins {
		ref := &admins[idx]
		ref.Groups = adminsGroups[ref.ID]
	}
	return admins, err
}

// getUserWithVirtualFolders attaches virtual folders to a single user.
func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
	users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
	if err != nil {
		return user, err
	}
	if len(users) == 0 {
		return user, errSQLFoldersAssociation
	}
	return users[0], err
}

// getUsersWithVirtualFolders loads the folders mapped to users and assigns
// them by user ID. A filesystem config that cannot be decoded is ignored.
// On a query error it returns nil.
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
	if len(users) == 0 {
		return users, nil
	}

	usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
	q := getRelatedFoldersForUsersQuery(users)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var folder vfs.VirtualFolder
		var userID int64
		var mappedPath, description sql.NullString
		var fsConfig []byte
		err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
			&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
// (continuation of the rows.Scan call in getUsersWithVirtualFolders)
			&description)
		if err != nil {
			return users, err
		}
		if mappedPath.Valid {
			folder.MappedPath = mappedPath.String
		}
		if description.Valid {
			folder.Description = description.String
		}
		var fs vfs.Filesystem
		err = json.Unmarshal(fsConfig, &fs)
		if err == nil {
			folder.FsConfig = fs
		}
		usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
	}
	err = rows.Err()
	if err != nil {
		return users, err
	}
	// no mappings found: leave the users untouched
	if len(usersVirtualFolders) == 0 {
		return users, err
	}
	for idx := range users {
		ref := &users[idx]
		ref.VirtualFolders = usersVirtualFolders[ref.ID]
	}
	return users, err
}

// getUserWithGroups attaches group mappings to a single user.
func getUserWithGroups(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
	users, err := getUsersWithGroups(ctx, []User{user}, dbHandle)
	if err != nil {
		return user, err
	}
	if len(users) == 0 {
		return user, errSQLGroupsAssociation
	}
	return users[0], err
}

// getUsersWithGroups loads group mappings (name and type) for users and
// assigns them by user ID. On a query error it returns nil.
func getUsersWithGroups(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
	if len(users) == 0 {
		return users, nil
	}
	usersGroups := make(map[int64][]sdk.GroupMapping)
	q := getRelatedGroupsForUsersQuery(users)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var group sdk.GroupMapping
		var userID int64
		err = rows.Scan(&group.Name, &group.Type, &userID)
		if err != nil {
			return users, err
		}
		usersGroups[userID] = append(usersGroups[userID], group)
	}
	err = rows.Err()
	if err != nil {
		return users, err
	}
	// no mappings found: leave the users untouched
	if len(usersGroups) == 0 {
		return users, err
	}
	for idx := range users {
		ref := &users[idx]
		ref.Groups = usersGroups[ref.ID]
	}
	return users, err
}

// getGroupWithUsers attaches the referencing usernames to a single group.
func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
	groups, err := getGroupsWithUsers(ctx, []Group{group}, dbHandle)
	if err != nil {
		return group, err
	}
	if len(groups) == 0 {
		return group, errSQLUsersAssociation
	}
	return groups[0], err
}

// getRoleWithUsers attaches the usernames assigned to a single role.
func getRoleWithUsers(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
	roles, err := getRolesWithUsers(ctx,
		[]Role{role}, dbHandle)
	if err != nil {
		return role, err
	}
	if len(roles) == 0 {
		return role, errors.New("unable to associate users with role")
	}
	return roles[0], err
}

// getRoleWithAdmins attaches the admin usernames assigned to a single role.
func getRoleWithAdmins(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
	roles, err := getRolesWithAdmins(ctx, []Role{role}, dbHandle)
	if err != nil {
		return role, err
	}
	if len(roles) == 0 {
		return role, errors.New("unable to associate admins with role")
	}
	return roles[0], err
}

// getGroupWithAdmins attaches the admin usernames to a single group.
// NOTE(review): this returns errSQLUsersAssociation although it associates
// admins (getRoleWithAdmins uses an admin-specific message); confirm the
// intended error value.
func getGroupWithAdmins(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
	groups, err := getGroupsWithAdmins(ctx, []Group{group}, dbHandle)
	if err != nil {
		return group, err
	}
	if len(groups) == 0 {
		return group, errSQLUsersAssociation
	}
	return groups[0], err
}

// getGroupWithVirtualFolders attaches virtual folders to a single group.
func getGroupWithVirtualFolders(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
	groups, err := getGroupsWithVirtualFolders(ctx, []Group{group}, dbHandle)
	if err != nil {
		return group, err
	}
	if len(groups) == 0 {
		return group, errSQLFoldersAssociation
	}
	return groups[0], err
}

// getGroupsWithVirtualFolders loads the folders mapped to groups and assigns
// them by group ID. A filesystem config that cannot be decoded is ignored.
// On a query error it returns nil.
func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
	if len(groups) == 0 {
		return groups, nil
	}

	q := getRelatedFoldersForGroupsQuery(groups)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	groupsVirtualFolders := make(map[int64][]vfs.VirtualFolder)
	for rows.Next() {
		var groupID int64
		var folder vfs.VirtualFolder
		var mappedPath, description sql.NullString
		var fsConfig []byte
		err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
			&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
			&description)
		if err != nil {
			return groups, err
		}
		if mappedPath.Valid {
			folder.MappedPath = mappedPath.String
		}
		if description.Valid {
			folder.Description = description.String
		}
		var fs vfs.Filesystem
		err = json.Unmarshal(fsConfig, &fs)
		if err == nil {
			folder.FsConfig = fs
		}
		groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	// no mappings found: leave the groups untouched
	if len(groupsVirtualFolders) == 0 {
		return groups, err
	}
	for idx := range groups {
		ref := &groups[idx]
		ref.VirtualFolders = groupsVirtualFolders[ref.ID]
	}
	return groups, err
}

// getGroupsWithUsers loads the usernames related to groups and assigns them
// by group ID. On a query error it returns nil.
func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
	if len(groups) == 0 {
		return groups, nil
	}

	q := getRelatedUsersForGroupsQuery(groups)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	groupsUsers := make(map[int64][]string)
	for rows.Next() {
		var username string
		var groupID int64
		err = rows.Scan(&groupID, &username)
		if err != nil {
			return groups, err
		}
		groupsUsers[groupID] = append(groupsUsers[groupID], username)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	if len(groupsUsers) == 0 {
		return groups, err
	}
	for idx := range groups {
		ref := &groups[idx]
		ref.Users = groupsUsers[ref.ID]
	}
	return groups, err
}

// getRolesWithUsers loads the usernames assigned to roles and assigns them by
// role ID. On a query error it returns nil.
func getRolesWithUsers(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
	if len(roles) == 0 {
		return roles, nil
	}

	rows, err := dbHandle.QueryContext(ctx, getUsersWithRolesQuery(roles))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rolesUsers := make(map[int64][]string)
	for rows.Next() {
		var roleID int64
		var username string
		err = rows.Scan(&roleID, &username)
		if err != nil {
			return roles, err
		}
		rolesUsers[roleID] = append(rolesUsers[roleID], username)
	}
	err = rows.Err()
	if err != nil {
		return roles, err
	}
	if len(rolesUsers) > 0 {
		for idx := range roles {
			ref := &roles[idx]
			ref.Users = rolesUsers[ref.ID]
		}
	}
	return roles, nil
}

// getRolesWithAdmins loads the admin usernames assigned to roles and assigns
// them by role ID. On a query error it returns nil.
func getRolesWithAdmins(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
	if len(roles) == 0 {
		return roles, nil
	}

	rows, err := dbHandle.QueryContext(ctx, getAdminsWithRolesQuery(roles))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rolesAdmins :=
		make(map[int64][]string)
	for rows.Next() {
		var roleID int64
		var username string
		err = rows.Scan(&roleID, &username)
		if err != nil {
			return roles, err
		}
		rolesAdmins[roleID] = append(rolesAdmins[roleID], username)
	}
	if err = rows.Err(); err != nil {
		return roles, err
	}
	if len(rolesAdmins) > 0 {
		for idx := range roles {
			ref := &roles[idx]
			ref.Admins = rolesAdmins[ref.ID]
		}
	}
	return roles, nil
}

// getGroupsWithAdmins loads the admin usernames related to groups and assigns
// them by group ID. On a query error it returns nil.
func getGroupsWithAdmins(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
	if len(groups) == 0 {
		return groups, nil
	}

	q := getRelatedAdminsForGroupsQuery(groups)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	groupsAdmins := make(map[int64][]string)
	for rows.Next() {
		var groupID int64
		var username string
		err = rows.Scan(&groupID, &username)
		if err != nil {
			return groups, err
		}
		groupsAdmins[groupID] = append(groupsAdmins[groupID], username)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	if len(groupsAdmins) > 0 {
		for idx := range groups {
			ref := &groups[idx]
			ref.Admins = groupsAdmins[ref.ID]
		}
	}
	return groups, nil
}

// getVirtualFoldersWithGroups loads the names of the groups related to
// folders and assigns them by folder ID.
// It creates its own context with defaultSQLQueryTimeout, since it takes no
// ctx parameter. On a query error it returns nil.
func getVirtualFoldersWithGroups(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
	if len(folders) == 0 {
		return folders, nil
	}

	vFoldersGroups := make(map[int64][]string)
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getRelatedGroupsForFoldersQuery(folders)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var name string
		var folderID int64
		err = rows.Scan(&folderID, &name)
		if err != nil {
			return folders, err
		}
		vFoldersGroups[folderID] = append(vFoldersGroups[folderID], name)
	}
	err = rows.Err()
	if err != nil {
		return folders, err
	}
	// no mappings found: leave the folders untouched
	if len(vFoldersGroups) == 0 {
		return folders, err
	}
	for idx := range folders {
		ref := &folders[idx]
		ref.Groups = vFoldersGroups[ref.ID]
	}
	return folders, err
}

func getVirtualFoldersWithUsers(folders
[]vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) { if len(folders) == 0 { return folders, nil } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getRelatedUsersForFoldersQuery(folders) rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() vFoldersUsers := make(map[int64][]string) for rows.Next() { var username string var folderID int64 err = rows.Scan(&folderID, &username) if err != nil { return folders, err } vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username) } err = rows.Err() if err != nil { return folders, err } if len(vFoldersUsers) == 0 { return folders, err } for idx := range folders { ref := &folders[idx] ref.Users = vFoldersUsers[ref.ID] } return folders, err } func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateFolderQuotaQuery(reset) _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name) if err == nil { providerLog(logger.LevelDebug, "quota updated for folder %q, files increment: %d size increment: %d is reset? 
%t", name, filesAdd, sizeAdd, reset) } else { providerLog(logger.LevelWarn, "error updating quota for folder %q: %v", name, err) } return err } func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getQuotaFolderQuery() var usedFiles int var usedSize int64 err := dbHandle.QueryRowContext(ctx, q, mappedPath).Scan(&usedSize, &usedFiles) if err != nil { providerLog(logger.LevelError, "error getting quota for folder: %v, error: %v", mappedPath, err) return 0, 0, err } return usedFiles, usedSize, err } func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) { var apiKeys []APIKey var err error scope := APIKeyScopeAdmin if apiKey.userID > 0 { scope = APIKeyScopeUser } apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope) if err != nil { return apiKey, err } if len(apiKeys) > 0 { apiKey = apiKeys[0] } return apiKey, nil } func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) { if len(apiKeys) == 0 { return apiKeys, nil } values := make(map[int64]string) var q string if scope == APIKeyScopeUser { q = getRelatedUsersForAPIKeysQuery(apiKeys) } else { q = getRelatedAdminsForAPIKeysQuery(apiKeys) } rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() for rows.Next() { var valueID int64 var valueName string err = rows.Scan(&valueID, &valueName) if err != nil { return apiKeys, err } values[valueID] = valueName } err = rows.Err() if err != nil { return apiKeys, err } if len(values) == 0 { return apiKeys, nil } for idx := range apiKeys { ref := &apiKeys[idx] if scope == APIKeyScopeUser { ref.User = values[ref.userID] } else { ref.Admin = values[ref.adminID] } } return apiKeys, nil } func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, 
sql.NullInt64, error) { var userID, adminID sql.NullInt64 if apiKey.User != "" { u, err := provider.userExists(apiKey.User, "") if err != nil { return userID, adminID, util.NewGenericError(fmt.Sprintf("unable to validate user %v", apiKey.User)) } userID.Valid = true userID.Int64 = u.ID } if apiKey.Admin != "" { a, err := provider.adminExists(apiKey.Admin) if err != nil { return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin)) } adminID.Valid = true adminID.Int64 = a.ID } return userID, adminID, nil } func sqlCommonAddSession(session Session, dbHandle *sql.DB) error { if err := session.validate(); err != nil { return err } data, err := json.Marshal(session.Data) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddSessionQuery() _, err = dbHandle.ExecContext(ctx, q, session.Key, data, session.Type, session.Timestamp) return err } func sqlCommonGetSession(key string, sessionType SessionType, dbHandle sqlQuerier) (Session, error) { var session Session ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getSessionQuery() var data []byte // type hint, some driver will use string instead of []byte if the type is any err := dbHandle.QueryRowContext(ctx, q, key, sessionType).Scan(&session.Key, &data, &session.Type, &session.Timestamp) if err != nil { if errors.Is(err, sql.ErrNoRows) { return session, util.NewRecordNotFoundError(err.Error()) } return session, err } session.Data = data return session, nil } func sqlCommonDeleteSession(key string, sessionType SessionType, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteSessionQuery() res, err := dbHandle.ExecContext(ctx, q, key, sessionType) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonCleanupSessions(sessionType 
SessionType, before int64, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getCleanupSessionsQuery() _, err := dbHandle.ExecContext(ctx, q, sessionType, before) return err } func getActionsWithRuleNames(ctx context.Context, actions []BaseEventAction, dbHandle sqlQuerier, ) ([]BaseEventAction, error) { if len(actions) == 0 { return actions, nil } q := getRelatedRulesForActionsQuery(actions) rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() actionsRules := make(map[int64][]string) for rows.Next() { var name string var actionID int64 if err = rows.Scan(&actionID, &name); err != nil { return nil, err } actionsRules[actionID] = append(actionsRules[actionID], name) } err = rows.Err() if err != nil { return nil, err } if len(actionsRules) == 0 { return actions, nil } for idx := range actions { ref := &actions[idx] ref.Rules = actionsRules[ref.ID] } return actions, nil } func getRulesWithActions(ctx context.Context, rules []EventRule, dbHandle sqlQuerier) ([]EventRule, error) { if len(rules) == 0 { return rules, nil } rulesActions := make(map[int64][]EventAction) q := getRelatedActionsForRulesQuery(rules) rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return nil, err } defer rows.Close() for rows.Next() { var action EventAction var ruleID int64 var description sql.NullString var baseOptions, options []byte err = rows.Scan(&action.ID, &action.Name, &description, &action.Type, &baseOptions, &options, &action.Order, &ruleID) if err != nil { return rules, err } if len(baseOptions) > 0 { err = json.Unmarshal(baseOptions, &action.BaseEventAction.Options) if err != nil { return rules, err } } if len(options) > 0 { err = json.Unmarshal(options, &action.Options) if err != nil { return rules, err } } action.BaseEventAction.Options.SetEmptySecretsIfNil() rulesActions[ruleID] = append(rulesActions[ruleID], action) } err = rows.Err() if err != nil { 
return rules, err } if len(rulesActions) == 0 { return rules, nil } for idx := range rules { ref := &rules[idx] ref.Actions = rulesActions[ref.ID] } return rules, nil } func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHandle sqlQuerier) error { q := getClearRuleActionMappingQuery() _, err := dbHandle.ExecContext(ctx, q, rule.Name) if err != nil { return err } for _, action := range rule.Actions { options, err := json.Marshal(action.Options) if err != nil { return err } q = getAddRuleActionMappingQuery() _, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, options) if err != nil { return err } } return nil } func sqlCommonGetEventActionByName(name string, dbHandle sqlQuerier) (BaseEventAction, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getEventActionByNameQuery() row := dbHandle.QueryRowContext(ctx, q, name) action, err := getEventActionFromDbRow(row) if err != nil { return action, err } actions, err := getActionsWithRuleNames(ctx, []BaseEventAction{action}, dbHandle) if err != nil { return action, err } if len(actions) != 1 { return action, fmt.Errorf("unable to associate rules with action %q", name) } return actions[0], nil } func sqlCommonDumpEventActions(dbHandle sqlQuerier) ([]BaseEventAction, error) { actions := make([]BaseEventAction, 0, 10) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() q := getDumpEventActionsQuery() rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return actions, err } defer rows.Close() for rows.Next() { action, err := getEventActionFromDbRow(rows) if err != nil { return actions, err } actions = append(actions, action) } return actions, rows.Err() } func sqlCommonGetEventActions(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier, ) ([]BaseEventAction, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer 
cancel() q := getEventsActionsQuery(order, minimal) actions := make([]BaseEventAction, 0, limit) rows, err := dbHandle.QueryContext(ctx, q, limit, offset) if err != nil { return actions, err } defer rows.Close() for rows.Next() { var action BaseEventAction if minimal { err = rows.Scan(&action.ID, &action.Name) } else { action, err = getEventActionFromDbRow(rows) } if err != nil { return actions, err } actions = append(actions, action) } err = rows.Err() if err != nil { return nil, err } if minimal { return actions, nil } actions, err = getActionsWithRuleNames(ctx, actions, dbHandle) if err != nil { return nil, err } for idx := range actions { actions[idx].PrepareForRendering() } return actions, nil } func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error { if err := action.validate(); err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddEventActionQuery() options, err := json.Marshal(action.Options) if err != nil { return err } _, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, options) return err } func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error { if err := action.validate(); err != nil { return err } options, err := json.Marshal(action.Options) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateEventActionQuery() res, err := tx.ExecContext(ctx, q, action.Description, action.Type, options, action.Name) if err != nil { return err } if err := sqlCommonRequireRowAffected(res); err != nil { return err } q = getUpdateRulesTimestampQuery() _, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.Name) return err }) } func sqlCommonDeleteEventAction(action BaseEventAction, dbHandle *sql.DB) error { ctx, cancel := 
context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteEventActionQuery() res, err := dbHandle.ExecContext(ctx, q, action.Name) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetEventRuleByName(name string, dbHandle sqlQuerier) (EventRule, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getEventRulesByNameQuery() row := dbHandle.QueryRowContext(ctx, q, name) rule, err := getEventRuleFromDbRow(row) if err != nil { return rule, err } rules, err := getRulesWithActions(ctx, []EventRule{rule}, dbHandle) if err != nil { return rule, err } if len(rules) != 1 { return rule, fmt.Errorf("unable to associate rule %q with actions", name) } return rules[0], nil } func sqlCommonDumpEventRules(dbHandle sqlQuerier) ([]EventRule, error) { rules := make([]EventRule, 0, 10) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() q := getDumpEventRulesQuery() rows, err := dbHandle.QueryContext(ctx, q) if err != nil { return rules, err } defer rows.Close() for rows.Next() { rule, err := getEventRuleFromDbRow(rows) if err != nil { return rules, err } rules = append(rules, rule) } err = rows.Err() if err != nil { return rules, err } return getRulesWithActions(ctx, rules, dbHandle) } func sqlCommonGetRecentlyUpdatedRules(after int64, dbHandle sqlQuerier) ([]EventRule, error) { rules := make([]EventRule, 0, 10) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() q := getRecentlyUpdatedRulesQuery() rows, err := dbHandle.QueryContext(ctx, q, after) if err != nil { return rules, err } defer rows.Close() for rows.Next() { rule, err := getEventRuleFromDbRow(rows) if err != nil { return rules, err } rules = append(rules, rule) } err = rows.Err() if err != nil { return rules, err } return getRulesWithActions(ctx, rules, dbHandle) } func sqlCommonGetEventRules(limit int, offset int, 
order string, dbHandle sqlQuerier) ([]EventRule, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getEventRulesQuery(order) rules := make([]EventRule, 0, limit) rows, err := dbHandle.QueryContext(ctx, q, limit, offset) if err != nil { return rules, err } defer rows.Close() for rows.Next() { rule, err := getEventRuleFromDbRow(rows) if err != nil { return rules, err } rules = append(rules, rule) } err = rows.Err() if err != nil { return rules, err } rules, err = getRulesWithActions(ctx, rules, dbHandle) if err != nil { return rules, err } for idx := range rules { rules[idx].PrepareForRendering() } return rules, nil } func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error { if err := rule.validate(); err != nil { return err } conditions, err := json.Marshal(rule.Conditions) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { if config.IsShared == 1 { _, err := tx.ExecContext(ctx, getRemoveSoftDeletedRuleQuery(), rule.Name) if err != nil { return err } } q := getAddEventRuleQuery() _, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, rule.Status) if err != nil { return err } return generateEventRuleActionsMapping(ctx, rule, tx) }) } func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error { if err := rule.validate(); err != nil { return err } conditions, err := json.Marshal(rule.Conditions) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateEventRuleQuery() _, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, 
rule.Status, rule.Name) if err != nil { return err } return generateEventRuleActionsMapping(ctx, rule, tx) }) } func sqlCommonDeleteEventRule(rule EventRule, softDelete bool, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { if softDelete { q := getClearRuleActionMappingQuery() _, err := tx.ExecContext(ctx, q, rule.Name) if err != nil { return err } } q := getDeleteEventRuleQuery(softDelete) if softDelete { ts := util.GetTimeAsMsSinceEpoch(time.Now()) res, err := tx.ExecContext(ctx, q, ts, ts, rule.Name) if err != nil { return err } return sqlCommonRequireRowAffected(res) } res, err := tx.ExecContext(ctx, q, rule.Name) if err != nil { return err } if err = sqlCommonRequireRowAffected(res); err != nil { return err } return sqlCommonDeleteTask(rule.Name, tx) }) } func sqlCommonGetTaskByName(name string, dbHandle sqlQuerier) (Task, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() task := Task{ Name: name, } q := getTaskByNameQuery() row := dbHandle.QueryRowContext(ctx, q, name) err := row.Scan(&task.UpdateAt, &task.Version) if err != nil { if errors.Is(err, sql.ErrNoRows) { return task, util.NewRecordNotFoundError(err.Error()) } } return task, err } func sqlCommonAddTask(name string, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddTaskQuery() _, err := dbHandle.ExecContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now())) return err } func sqlCommonUpdateTask(name string, version int64, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateTaskQuery() res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name, version) if err != nil { return err } return 
sqlCommonRequireRowAffected(res) } func sqlCommonUpdateTaskTimestamp(name string, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateTaskTimestampQuery() res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteTaskQuery() _, err := dbHandle.ExecContext(ctx, q, name) return err } func sqlCommonAddNode(dbHandle *sql.DB) error { if err := currentNode.validate(); err != nil { return fmt.Errorf("unable to register cluster node: %w", err) } data, err := json.Marshal(currentNode.Data) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddNodeQuery() _, err = dbHandle.ExecContext(ctx, q, currentNode.Name, data, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now())) if err != nil { return fmt.Errorf("unable to register cluster node: %w", err) } providerLog(logger.LevelInfo, "registered as cluster node %q, port: %d, proto: %s", currentNode.Name, currentNode.Data.Port, currentNode.Data.Proto) return nil } func sqlCommonGetNodeByName(name string, dbHandle *sql.DB) (Node, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() var data []byte var node Node q := getNodeByNameQuery() row := dbHandle.QueryRowContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff))) err := row.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { return node, util.NewRecordNotFoundError(err.Error()) } return node, err } err = json.Unmarshal(data, &node.Data) return node, err } func 
sqlCommonGetNodes(dbHandle *sql.DB) ([]Node, error) { var nodes []Node ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getNodesQuery() rows, err := dbHandle.QueryContext(ctx, q, currentNode.Name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff))) if err != nil { return nodes, err } defer rows.Close() for rows.Next() { var node Node var data []byte err = rows.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt) if err != nil { return nodes, err } err = json.Unmarshal(data, &node.Data) if err != nil { return nodes, err } nodes = append(nodes, node) } return nodes, rows.Err() } func sqlCommonUpdateNodeTimestamp(dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateNodeTimestampQuery() res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), currentNode.Name) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonCleanupNodes(dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getCleanupNodesQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now().Add(10*activeNodeTimeDiff))) return err } func sqlCommonGetConfigs(dbHandle sqlQuerier) (Configs, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() var result Configs var configs []byte q := getConfigsQuery() err := dbHandle.QueryRowContext(ctx, q).Scan(&configs) if err != nil { return result, err } err = json.Unmarshal(configs, &result) return result, err } func sqlCommonSetConfigs(configs *Configs, dbHandle *sql.DB) error { if err := configs.validate(); err != nil { return err } asJSON, err := json.Marshal(configs) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateConfigsQuery() 
res, err := dbHandle.ExecContext(ctx, q, asJSON) if err != nil { return err } return sqlCommonRequireRowAffected(res) } func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) { var result schemaVersion ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDatabaseVersionQuery() stmt, err := dbHandle.PrepareContext(ctx, q) if err != nil { providerLog(logger.LevelError, "error preparing database query %q: %v", q, err) if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) { logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?") } return result, err } defer stmt.Close() row := stmt.QueryRowContext(ctx) err = row.Scan(&result.Version) return result, err } func sqlCommonRequireRowAffected(res sql.Result) error { affected, err := res.RowsAffected() if err == nil && affected == 0 { return util.NewRecordNotFoundError(sql.ErrNoRows.Error()) } return nil } func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error { q := getUpdateDBVersionQuery() _, err := dbHandle.ExecContext(ctx, q, version) return err } func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() conn, err := dbHandle.Conn(ctx) if err != nil { return fmt.Errorf("unable to get connection from pool: %w", err) } defer conn.Close() if err := sqlAcquireLock(conn); err != nil { return err } defer sqlReleaseLock(conn) if newVersion > 0 { currentVersion, err := sqlCommonGetDatabaseVersion(conn, false) if err == nil { if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) { providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?", currentVersion.Version, newVersion) return nil } } } 
return sqlCommonExecuteTxOnConn(ctx, conn, func(tx *sql.Tx) error { for _, q := range sqlQueries { if strings.TrimSpace(q) == "" { continue } _, err := tx.ExecContext(ctx, q) if err != nil { return err } } if newVersion == 0 { return nil } return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion) }) } func sqlAcquireLock(dbHandle *sql.Conn) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() switch config.Driver { case PGSQLDataProviderName: _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_lock(101,1)`) if err != nil { return fmt.Errorf("unable to get advisory lock: %w", err) } providerLog(logger.LevelInfo, "acquired database lock") case MySQLDataProviderName: var lockResult sql.NullInt64 err := dbHandle.QueryRowContext(ctx, `SELECT GET_LOCK('sftpgo.migration',30)`).Scan(&lockResult) if err != nil { return fmt.Errorf("unable to get lock: %w", err) } if !lockResult.Valid { return errors.New("unable to get lock: null value returned") } if lockResult.Int64 != 1 { return fmt.Errorf("unable to get lock, result: %d", lockResult.Int64) } providerLog(logger.LevelInfo, "acquired database lock") } return nil } func sqlReleaseLock(dbHandle *sql.Conn) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() switch config.Driver { case PGSQLDataProviderName: _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_unlock(101,1)`) if err != nil { providerLog(logger.LevelWarn, "unable to release lock: %v", err) } else { providerLog(logger.LevelInfo, "released database lock") } case MySQLDataProviderName: _, err := dbHandle.ExecContext(ctx, `SELECT RELEASE_LOCK('sftpgo.migration')`) if err != nil { providerLog(logger.LevelWarn, "unable to release lock: %v", err) } else { providerLog(logger.LevelInfo, "released database lock") } } } func sqlCommonExecuteTxOnConn(ctx context.Context, conn *sql.Conn, txFn func(*sql.Tx) error) error { tx, err := conn.BeginTx(ctx, nil) if err != nil { 
return err } err = txFn(tx) if err != nil { tx.Rollback() //nolint:errcheck return err } return tx.Commit() } func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error { if config.Driver == CockroachDataProviderName { return crdb.ExecuteTx(ctx, dbHandle, nil, txFn) } tx, err := dbHandle.BeginTx(ctx, nil) if err != nil { return err } err = txFn(tx) if err != nil { // we don't change the returned error tx.Rollback() //nolint:errcheck return err } return tx.Commit() } ================================================ FILE: internal/dataprovider/sqlite.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !nosqlite && cgo package dataprovider import ( "context" "crypto/x509" "database/sql" "errors" "fmt" "path/filepath" "strings" "time" "github.com/mattn/go-sqlite3" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( sqliteResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}"; DROP TABLE IF EXISTS "{{users_folders_mapping}}"; DROP TABLE IF EXISTS "{{users_groups_mapping}}"; DROP TABLE IF EXISTS "{{admins_groups_mapping}}"; DROP TABLE IF EXISTS "{{groups_folders_mapping}}"; DROP TABLE IF EXISTS "{{shares_groups_mapping}}"; DROP TABLE IF EXISTS "{{admins}}"; DROP TABLE IF EXISTS "{{folders}}"; DROP TABLE IF EXISTS "{{shares}}"; DROP TABLE IF EXISTS "{{users}}"; DROP TABLE IF EXISTS "{{groups}}"; DROP TABLE IF EXISTS "{{defender_events}}"; DROP TABLE IF EXISTS "{{defender_hosts}}"; DROP TABLE IF EXISTS "{{active_transfers}}"; DROP TABLE IF EXISTS "{{shared_sessions}}"; DROP TABLE IF EXISTS "{{rules_actions_mapping}}"; DROP TABLE IF EXISTS "{{events_rules}}"; DROP TABLE IF EXISTS "{{events_actions}}"; DROP TABLE IF EXISTS "{{tasks}}"; DROP TABLE IF EXISTS "{{roles}}"; DROP TABLE IF EXISTS "{{ip_lists}}"; DROP TABLE IF EXISTS "{{configs}}"; DROP TABLE IF EXISTS "{{schema_version}}"; ` sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY, "version" integer NOT NULL); CREATE TABLE "{{roles}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, "permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL, "role_id" 
integer NULL REFERENCES "{{roles}}" ("id") ON DELETE NO ACTION, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL, "transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL, "folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL, "current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{defender_hosts}}" ("id" integer NOT NULL PRIMARY KEY, "ip" varchar(50) NOT NULL UNIQUE, "ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL); CREATE TABLE "{{defender_events}}" ("id" integer NOT NULL PRIMARY KEY, "date_time" bigint NOT NULL, "score" integer NOT NULL, "host_id" integer NOT NULL REFERENCES "{{defender_hosts}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED); CREATE TABLE "{{folders}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "filesystem" text NULL); CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL); CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL, "type" integer NOT NULL, "data" text NOT NULL, "timestamp" bigint NOT NULL, PRIMARY KEY ("key", "type")); CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, "expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, "public_keys" text NULL, "home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, "max_sessions" integer NOT NULL, 
"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL, "upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, "used_upload_data_transfer" bigint NOT NULL, "used_download_data_transfer" bigint NOT NULL, "deleted_at" bigint NOT NULL, "first_download" bigint NOT NULL, "first_upload" bigint NOT NULL, "last_password_change" bigint NOT NULL, "role_id" integer NULL REFERENCES "{{roles}}" ("id") ON DELETE SET NULL); CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY, "folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id")); CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY, "user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE NO ACTION, "group_type" integer NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id")); CREATE TABLE "{{users_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY, "user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "folder_id" 
integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id")); CREATE TABLE "{{shares}}" ("id" integer NOT NULL PRIMARY KEY, "share_id" varchar(60) NOT NULL UNIQUE, "name" varchar(255) NOT NULL, "description" varchar(512) NULL, "scope" integer NOT NULL, "paths" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL, "password" text NULL, "max_tokens" integer NOT NULL, "used_tokens" integer NOT NULL, "allow_from" text NULL, "options" text NULL, "user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED); CREATE TABLE "{{api_keys}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL, "key_id" varchar(50) NOT NULL UNIQUE, "api_key" varchar(255) NOT NULL UNIQUE, "scope" integer NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL, "description" text NULL, "admin_id" integer NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "user_id" integer NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED); CREATE TABLE "{{events_rules}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "trigger" integer NOT NULL, "conditions" text NOT NULL, "deleted_at" bigint NOT NULL); CREATE TABLE "{{events_actions}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL); CREATE TABLE "{{rules_actions_mapping}}" ("id" integer NOT NULL 
PRIMARY KEY, "rule_id" integer NOT NULL REFERENCES "{{events_rules}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "action_id" integer NOT NULL REFERENCES "{{events_actions}}" ("id") ON DELETE NO ACTION DEFERRABLE INITIALLY DEFERRED, "order" integer NOT NULL, "options" text NOT NULL, CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id")); CREATE TABLE "{{tasks}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "updated_at" bigint NOT NULL, "version" bigint NOT NULL); CREATE TABLE "{{admins_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY, "admin_id" integer NOT NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "options" text NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id")); CREATE TABLE "{{ip_lists}}" ("id" integer NOT NULL PRIMARY KEY, "type" integer NOT NULL, "ipornet" varchar(50) NOT NULL, "mode" integer NOT NULL, "description" varchar(512) NULL, "first" BLOB NOT NULL, "last" BLOB NOT NULL, "ip_type" integer NOT NULL, "protocols" integer NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "deleted_at" bigint NOT NULL, CONSTRAINT "{{prefix}}unique_ipornet_type_mapping" UNIQUE ("type", "ipornet")); CREATE TABLE "{{configs}}" ("id" integer NOT NULL PRIMARY KEY, "configs" text NOT NULL); INSERT INTO {{configs}} (configs) VALUES ('{}'); CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id"); CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id"); CREATE INDEX "{{prefix}}users_folders_mapping_sort_order_idx" ON "{{users_folders_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id"); CREATE INDEX 
"{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id"); CREATE INDEX "{{prefix}}users_groups_mapping_sort_order_idx" ON "{{users_groups_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id"); CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id"); CREATE INDEX "{{prefix}}groups_folders_mapping_sort_order_idx" ON "{{groups_folders_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id"); CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id"); CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at"); CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at"); CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id"); CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at"); CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time"); CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at"); CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at"); CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON 
"{{events_rules}}" ("trigger"); CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id"); CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id"); CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order"); CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id"); CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id"); CREATE INDEX "{{prefix}}admins_groups_mapping_sort_order_idx" ON "{{admins_groups_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}users_role_id_idx" ON "{{users}}" ("role_id"); CREATE INDEX "{{prefix}}admins_role_id_idx" ON "{{admins}}" ("role_id"); CREATE INDEX "{{prefix}}ip_lists_type_idx" ON "{{ip_lists}}" ("type"); CREATE INDEX "{{prefix}}ip_lists_ipornet_idx" ON "{{ip_lists}}" ("ipornet"); CREATE INDEX "{{prefix}}ip_lists_ip_type_idx" ON "{{ip_lists}}" ("ip_type"); CREATE INDEX "{{prefix}}ip_lists_ip_updated_at_idx" ON "{{ip_lists}}" ("updated_at"); CREATE INDEX "{{prefix}}ip_lists_ip_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at"); CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last"); INSERT INTO {{schema_version}} (version) VALUES (33); ` sqliteV34SQL = ` CREATE TABLE "{{shares_groups_mapping}}" ( "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "share_id" integer NOT NULL REFERENCES "{{shares}}" ("id") ON DELETE CASCADE, "group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE, "permissions" integer NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_share_group_mapping" UNIQUE ("share_id", "group_id") ); CREATE INDEX "{{prefix}}shares_groups_mapping_sort_order_idx" ON "{{shares_groups_mapping}}" ("sort_order"); CREATE INDEX "{{prefix}}shares_groups_mapping_group_id_idx" ON "{{shares_groups_mapping}}" ("group_id"); CREATE INDEX 
"{{prefix}}shares_groups_mapping_share_id_idx" ON "{{shares_groups_mapping}}" ("share_id"); ` sqliteV34DownSQL = `DROP TABLE IF EXISTS "{{shares_groups_mapping}}";` ) // SQLiteProvider defines the auth provider for SQLite database type SQLiteProvider struct { dbHandle *sql.DB } func init() { version.AddFeature("+sqlite") } func initializeSQLiteProvider(basePath string) error { var connectionString string if config.ConnectionString == "" { dbPath := config.Name if !util.IsFileInputValid(dbPath) { return fmt.Errorf("invalid database path: %q", dbPath) } if !filepath.IsAbs(dbPath) { dbPath = filepath.Join(basePath, dbPath) } connectionString = fmt.Sprintf("file:%s?cache=shared&_foreign_keys=1", dbPath) } else { connectionString = config.ConnectionString } dbHandle, err := sql.Open("sqlite3", connectionString) if err != nil { providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %q, error: %v", connectionString, err) return err } providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString) dbHandle.SetMaxOpenConns(1) provider = &SQLiteProvider{dbHandle: dbHandle} return executePragmaOptimize(dbHandle) } func (p *SQLiteProvider) checkAvailability() error { return sqlCommonCheckAvailability(p.dbHandle) } func (p *SQLiteProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle) } func (p *SQLiteProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { return sqlCommonValidateUserAndTLSCertificate(username, protocol, tlsCert, p.dbHandle) } func (p *SQLiteProvider) validateUserAndPubKey(username string, publicKey []byte, isSSHCert bool) (User, string, error) { return sqlCommonValidateUserAndPubKey(username, publicKey, isSSHCert, p.dbHandle) } func (p *SQLiteProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset 
bool) error { return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) } func (p *SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) } func (p *SQLiteProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { return sqlCommonGetUsedQuota(username, p.dbHandle) } func (p *SQLiteProvider) getAdminSignature(username string) (string, error) { return sqlCommonGetAdminSignature(username, p.dbHandle) } func (p *SQLiteProvider) getUserSignature(username string) (string, error) { return sqlCommonGetUserSignature(username, p.dbHandle) } func (p *SQLiteProvider) setUpdatedAt(username string) { sqlCommonSetUpdatedAt(username, p.dbHandle) } func (p *SQLiteProvider) updateLastLogin(username string) error { return sqlCommonUpdateLastLogin(username, p.dbHandle) } func (p *SQLiteProvider) updateAdminLastLogin(username string) error { return sqlCommonUpdateAdminLastLogin(username, p.dbHandle) } func (p *SQLiteProvider) userExists(username, role string) (User, error) { return sqlCommonGetUserByUsername(username, role, p.dbHandle) } func (p *SQLiteProvider) addUser(user *User) error { return p.normalizeError(sqlCommonAddUser(user, p.dbHandle), fieldUsername) } func (p *SQLiteProvider) updateUser(user *User) error { return p.normalizeError(sqlCommonUpdateUser(user, p.dbHandle), -1) } func (p *SQLiteProvider) deleteUser(user User, softDelete bool) error { return sqlCommonDeleteUser(user, softDelete, p.dbHandle) } func (p *SQLiteProvider) updateUserPassword(username, password string) error { return sqlCommonUpdateUserPassword(username, password, p.dbHandle) } func (p *SQLiteProvider) dumpUsers() ([]User, error) { return sqlCommonDumpUsers(p.dbHandle) } func (p *SQLiteProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { return sqlCommonGetRecentlyUpdatedUsers(after, p.dbHandle) } func (p *SQLiteProvider) 
getUsers(limit int, offset int, order, role string) ([]User, error) { return sqlCommonGetUsers(limit, offset, order, role, p.dbHandle) } func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle) } func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { return sqlCommonDumpFolders(p.dbHandle) } func (p *SQLiteProvider) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) { return sqlCommonGetFolders(limit, offset, order, minimal, p.dbHandle) } func (p *SQLiteProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() return sqlCommonGetFolderByName(ctx, name, p.dbHandle) } func (p *SQLiteProvider) addFolder(folder *vfs.BaseVirtualFolder) error { return p.normalizeError(sqlCommonAddFolder(folder, p.dbHandle), fieldName) } func (p *SQLiteProvider) updateFolder(folder *vfs.BaseVirtualFolder) error { return sqlCommonUpdateFolder(folder, p.dbHandle) } func (p *SQLiteProvider) deleteFolder(folder vfs.BaseVirtualFolder) error { return sqlCommonDeleteFolder(folder, p.dbHandle) } func (p *SQLiteProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateFolderQuota(name, filesAdd, sizeAdd, reset, p.dbHandle) } func (p *SQLiteProvider) getUsedFolderQuota(name string) (int, int64, error) { return sqlCommonGetFolderUsedQuota(name, p.dbHandle) } func (p *SQLiteProvider) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) { return sqlCommonGetGroups(limit, offset, order, minimal, p.dbHandle) } func (p *SQLiteProvider) getGroupsWithNames(names []string) ([]Group, error) { return sqlCommonGetGroupsWithNames(names, p.dbHandle) } func (p *SQLiteProvider) getUsersInGroups(names []string) ([]string, error) { return sqlCommonGetUsersInGroups(names, p.dbHandle) } 
func (p *SQLiteProvider) groupExists(name string) (Group, error) { return sqlCommonGetGroupByName(name, p.dbHandle) } func (p *SQLiteProvider) addGroup(group *Group) error { return p.normalizeError(sqlCommonAddGroup(group, p.dbHandle), fieldName) } func (p *SQLiteProvider) updateGroup(group *Group) error { return sqlCommonUpdateGroup(group, p.dbHandle) } func (p *SQLiteProvider) deleteGroup(group Group) error { return sqlCommonDeleteGroup(group, p.dbHandle) } func (p *SQLiteProvider) dumpGroups() ([]Group, error) { return sqlCommonDumpGroups(p.dbHandle) } func (p *SQLiteProvider) adminExists(username string) (Admin, error) { return sqlCommonGetAdminByUsername(username, p.dbHandle) } func (p *SQLiteProvider) addAdmin(admin *Admin) error { return p.normalizeError(sqlCommonAddAdmin(admin, p.dbHandle), fieldUsername) } func (p *SQLiteProvider) updateAdmin(admin *Admin) error { return p.normalizeError(sqlCommonUpdateAdmin(admin, p.dbHandle), -1) } func (p *SQLiteProvider) deleteAdmin(admin Admin) error { return sqlCommonDeleteAdmin(admin, p.dbHandle) } func (p *SQLiteProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) { return sqlCommonGetAdmins(limit, offset, order, p.dbHandle) } func (p *SQLiteProvider) dumpAdmins() ([]Admin, error) { return sqlCommonDumpAdmins(p.dbHandle) } func (p *SQLiteProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { return sqlCommonValidateAdminAndPass(username, password, ip, p.dbHandle) } func (p *SQLiteProvider) apiKeyExists(keyID string) (APIKey, error) { return sqlCommonGetAPIKeyByID(keyID, p.dbHandle) } func (p *SQLiteProvider) addAPIKey(apiKey *APIKey) error { return p.normalizeError(sqlCommonAddAPIKey(apiKey, p.dbHandle), -1) } func (p *SQLiteProvider) updateAPIKey(apiKey *APIKey) error { return p.normalizeError(sqlCommonUpdateAPIKey(apiKey, p.dbHandle), -1) } func (p *SQLiteProvider) deleteAPIKey(apiKey APIKey) error { return sqlCommonDeleteAPIKey(apiKey, p.dbHandle) } func (p 
*SQLiteProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) { return sqlCommonGetAPIKeys(limit, offset, order, p.dbHandle) } func (p *SQLiteProvider) dumpAPIKeys() ([]APIKey, error) { return sqlCommonDumpAPIKeys(p.dbHandle) } func (p *SQLiteProvider) updateAPIKeyLastUse(keyID string) error { return sqlCommonUpdateAPIKeyLastUse(keyID, p.dbHandle) } func (p *SQLiteProvider) shareExists(shareID, username string) (Share, error) { return sqlCommonGetShareByID(shareID, username, p.dbHandle) } func (p *SQLiteProvider) addShare(share *Share) error { return p.normalizeError(sqlCommonAddShare(share, p.dbHandle), fieldName) } func (p *SQLiteProvider) updateShare(share *Share) error { return p.normalizeError(sqlCommonUpdateShare(share, p.dbHandle), -1) } func (p *SQLiteProvider) deleteShare(share Share) error { return sqlCommonDeleteShare(share, p.dbHandle) } func (p *SQLiteProvider) getShares(limit int, offset int, order, username string) ([]Share, error) { return sqlCommonGetShares(limit, offset, order, username, p.dbHandle) } func (p *SQLiteProvider) dumpShares() ([]Share, error) { return sqlCommonDumpShares(p.dbHandle) } func (p *SQLiteProvider) updateShareLastUse(shareID string, numTokens int) error { return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle) } func (p *SQLiteProvider) getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) { return sqlCommonGetDefenderHosts(from, limit, p.dbHandle) } func (p *SQLiteProvider) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) { return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle) } func (p *SQLiteProvider) isDefenderHostBanned(ip string) (DefenderEntry, error) { return sqlCommonIsDefenderHostBanned(ip, p.dbHandle) } func (p *SQLiteProvider) updateDefenderBanTime(ip string, minutes int) error { return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle) } func (p *SQLiteProvider) deleteDefenderHost(ip string) error { return sqlCommonDeleteDefenderHost(ip, 
p.dbHandle) } func (p *SQLiteProvider) addDefenderEvent(ip string, score int) error { return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle) } func (p *SQLiteProvider) setDefenderBanTime(ip string, banTime int64) error { return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle) } func (p *SQLiteProvider) cleanupDefender(from int64) error { return sqlCommonDefenderCleanup(from, p.dbHandle) } func (p *SQLiteProvider) addActiveTransfer(transfer ActiveTransfer) error { return sqlCommonAddActiveTransfer(transfer, p.dbHandle) } func (p *SQLiteProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle) } func (p *SQLiteProvider) removeActiveTransfer(transferID int64, connectionID string) error { return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle) } func (p *SQLiteProvider) cleanupActiveTransfers(before time.Time) error { return sqlCommonCleanupActiveTransfers(before, p.dbHandle) } func (p *SQLiteProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { return sqlCommonGetActiveTransfers(from, p.dbHandle) } func (p *SQLiteProvider) addSharedSession(session Session) error { return sqlCommonAddSession(session, p.dbHandle) } func (p *SQLiteProvider) deleteSharedSession(key string, sessionType SessionType) error { return sqlCommonDeleteSession(key, sessionType, p.dbHandle) } func (p *SQLiteProvider) getSharedSession(key string, sessionType SessionType) (Session, error) { return sqlCommonGetSession(key, sessionType, p.dbHandle) } func (p *SQLiteProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { return sqlCommonCleanupSessions(sessionType, before, p.dbHandle) } func (p *SQLiteProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) { return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle) } func (p 
*SQLiteProvider) dumpEventActions() ([]BaseEventAction, error) { return sqlCommonDumpEventActions(p.dbHandle) } func (p *SQLiteProvider) eventActionExists(name string) (BaseEventAction, error) { return sqlCommonGetEventActionByName(name, p.dbHandle) } func (p *SQLiteProvider) addEventAction(action *BaseEventAction) error { return p.normalizeError(sqlCommonAddEventAction(action, p.dbHandle), fieldName) } func (p *SQLiteProvider) updateEventAction(action *BaseEventAction) error { return sqlCommonUpdateEventAction(action, p.dbHandle) } func (p *SQLiteProvider) deleteEventAction(action BaseEventAction) error { return sqlCommonDeleteEventAction(action, p.dbHandle) } func (p *SQLiteProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { return sqlCommonGetEventRules(limit, offset, order, p.dbHandle) } func (p *SQLiteProvider) dumpEventRules() ([]EventRule, error) { return sqlCommonDumpEventRules(p.dbHandle) } func (p *SQLiteProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle) } func (p *SQLiteProvider) eventRuleExists(name string) (EventRule, error) { return sqlCommonGetEventRuleByName(name, p.dbHandle) } func (p *SQLiteProvider) addEventRule(rule *EventRule) error { return p.normalizeError(sqlCommonAddEventRule(rule, p.dbHandle), fieldName) } func (p *SQLiteProvider) updateEventRule(rule *EventRule) error { return sqlCommonUpdateEventRule(rule, p.dbHandle) } func (p *SQLiteProvider) deleteEventRule(rule EventRule, softDelete bool) error { return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle) } func (p *SQLiteProvider) getTaskByName(name string) (Task, error) { return sqlCommonGetTaskByName(name, p.dbHandle) } func (p *SQLiteProvider) addTask(name string) error { return sqlCommonAddTask(name, p.dbHandle) } func (p *SQLiteProvider) updateTask(name string, version int64) error { return sqlCommonUpdateTask(name, version, p.dbHandle) } func (p *SQLiteProvider) 
updateTaskTimestamp(name string) error { return sqlCommonUpdateTaskTimestamp(name, p.dbHandle) } func (*SQLiteProvider) addNode() error { return ErrNotImplemented } func (*SQLiteProvider) getNodeByName(_ string) (Node, error) { return Node{}, ErrNotImplemented } func (*SQLiteProvider) getNodes() ([]Node, error) { return nil, ErrNotImplemented } func (*SQLiteProvider) updateNodeTimestamp() error { return ErrNotImplemented } func (*SQLiteProvider) cleanupNodes() error { return ErrNotImplemented } func (p *SQLiteProvider) roleExists(name string) (Role, error) { return sqlCommonGetRoleByName(name, p.dbHandle) } func (p *SQLiteProvider) addRole(role *Role) error { return p.normalizeError(sqlCommonAddRole(role, p.dbHandle), fieldName) } func (p *SQLiteProvider) updateRole(role *Role) error { return sqlCommonUpdateRole(role, p.dbHandle) } func (p *SQLiteProvider) deleteRole(role Role) error { return sqlCommonDeleteRole(role, p.dbHandle) } func (p *SQLiteProvider) getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) { return sqlCommonGetRoles(limit, offset, order, minimal, p.dbHandle) } func (p *SQLiteProvider) dumpRoles() ([]Role, error) { return sqlCommonDumpRoles(p.dbHandle) } func (p *SQLiteProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { return sqlCommonGetIPListEntry(ipOrNet, listType, p.dbHandle) } func (p *SQLiteProvider) addIPListEntry(entry *IPListEntry) error { return p.normalizeError(sqlCommonAddIPListEntry(entry, p.dbHandle), fieldIPNet) } func (p *SQLiteProvider) updateIPListEntry(entry *IPListEntry) error { return sqlCommonUpdateIPListEntry(entry, p.dbHandle) } func (p *SQLiteProvider) deleteIPListEntry(entry IPListEntry, softDelete bool) error { return sqlCommonDeleteIPListEntry(entry, softDelete, p.dbHandle) } func (p *SQLiteProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { return sqlCommonGetIPListEntries(listType, filter, from, 
order, limit, p.dbHandle) } func (p *SQLiteProvider) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) { return sqlCommonGetRecentlyUpdatedIPListEntries(after, p.dbHandle) } func (p *SQLiteProvider) dumpIPListEntries() ([]IPListEntry, error) { return sqlCommonDumpIPListEntries(p.dbHandle) } func (p *SQLiteProvider) countIPListEntries(listType IPListType) (int64, error) { return sqlCommonCountIPListEntries(listType, p.dbHandle) } func (p *SQLiteProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { return sqlCommonGetListEntriesForIP(ip, listType, p.dbHandle) } func (p *SQLiteProvider) getConfigs() (Configs, error) { return sqlCommonGetConfigs(p.dbHandle) } func (p *SQLiteProvider) setConfigs(configs *Configs) error { return sqlCommonSetConfigs(configs, p.dbHandle) } func (p *SQLiteProvider) setFirstDownloadTimestamp(username string) error { return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle) } func (p *SQLiteProvider) setFirstUploadTimestamp(username string) error { return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle) } func (p *SQLiteProvider) close() error { return p.dbHandle.Close() } func (p *SQLiteProvider) reloadConfig() error { return nil } // initializeDatabase creates the initial database structure func (p *SQLiteProvider) initializeDatabase() error { dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false) if err == nil && dbVersion.Version > 0 { return ErrNoInitRequired } if errors.Is(err, sql.ErrNoRows) { return errSchemaVersionEmpty } logger.InfoToConsole("creating initial database schema, version 33") providerLog(logger.LevelInfo, "creating initial database schema, version 33") sql := sqlReplaceAll(sqliteInitialSQL) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 33, true) } func (p *SQLiteProvider) migrateDatabase() error { dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err } switch version := dbVersion.Version; 
{ case version == sqlDatabaseVersion: providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version) return ErrNoInitRequired case version < 33: err = errSchemaVersionTooOld(version) providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err case version == 33: return updateSQLiteDatabaseFromV33(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, sqlDatabaseVersion) logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version, sqlDatabaseVersion) return nil } return fmt.Errorf("database schema version not handled: %d", version) } } func (p *SQLiteProvider) revertDatabase(targetVersion int) error { dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err } if dbVersion.Version == targetVersion { return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { case 34: return downgradeSQLiteDatabaseFromV34(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } } func (p *SQLiteProvider) resetDatabase() error { sql := sqlReplaceAll(sqliteResetSQL) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false) } func (p *SQLiteProvider) normalizeError(err error, fieldType int) error { if err == nil { return nil } if e, ok := err.(sqlite3.Error); ok { switch e.ExtendedCode { case 1555, 2067: var message string switch fieldType { case fieldUsername: message = util.I18nErrorDuplicatedUsername case fieldIPNet: message = util.I18nErrorDuplicatedIPNet default: message = util.I18nErrorDuplicatedName } return util.NewI18nError( fmt.Errorf("%w: %s", ErrDuplicatedKey, err.Error()), message, ) case 787: return fmt.Errorf("%w: %s", ErrForeignKeyViolated, err.Error()) } } return err } func executePragmaOptimize(dbHandle *sql.DB) error { ctx, cancel := 
context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() _, err := dbHandle.ExecContext(ctx, "PRAGMA optimize;") return err } func updateSQLiteDatabaseFromV33(dbHandle *sql.DB) error { return updateSQLiteDatabaseFrom33To34(dbHandle) } func downgradeSQLiteDatabaseFromV34(dbHandle *sql.DB) error { return downgradeSQLiteDatabaseFrom34To33(dbHandle) } func updateSQLiteDatabaseFrom33To34(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 33 -> 34") providerLog(logger.LevelInfo, "updating database schema version: 33 -> 34") sql := strings.ReplaceAll(sqliteV34SQL, "{{prefix}}", config.SQLTablesPrefix) sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 34, true) } func downgradeSQLiteDatabaseFrom34To33(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database schema version: 34 -> 33") providerLog(logger.LevelInfo, "downgrading database schema version: 34 -> 33") sql := strings.ReplaceAll(sqliteV34DownSQL, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 33, false) } /*func setPragmaFK(dbHandle *sql.DB, value string) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() sql := fmt.Sprintf("PRAGMA foreign_keys=%v;", value) _, err := dbHandle.ExecContext(ctx, sql) return err }*/ ================================================ FILE: internal/dataprovider/sqlite_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see https://www.gnu.org/licenses/.

//go:build nosqlite || !cgo

package dataprovider

import (
	"errors"

	"github.com/drakkan/sftpgo/v2/internal/version"
)

// init adds "-sqlite" to the feature list kept by the version package. This
// file is only compiled under the "nosqlite || !cgo" build constraint, so the
// marker identifies builds without SQLite support.
func init() {
	version.AddFeature("-sqlite")
}

// initializeSQLiteProvider is the stub compiled under the
// "nosqlite || !cgo" build constraint. It ignores its argument and always
// returns an error.
func initializeSQLiteProvider(_ string) error {
	return errors.New("SQLite disabled at build time")
}

================================================ FILE: internal/dataprovider/sqlqueries.go ================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see https://www.gnu.org/licenses/.
package dataprovider

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// Column lists shared by the SELECT query builders in this file. The "u.",
// "a.", "r." and "s." prefixes refer to the aliases those queries give to the
// users, admins, roles and shares tables respectively.
const (
	selectUserFields = "u.id,u.username,u.password,u.public_keys,u.home_dir,u.uid,u.gid,u.max_sessions,u.quota_size,u.quota_files," +
		"u.permissions,u.used_quota_size,u.used_quota_files,u.last_quota_update,u.upload_bandwidth,u.download_bandwidth," +
		"u.expiration_date,u.last_login,u.status,u.filters,u.filesystem,u.additional_info,u.description,u.email,u.created_at," +
		"u.updated_at,u.upload_data_transfer,u.download_data_transfer,u.total_data_transfer," +
		"u.used_upload_data_transfer,u.used_download_data_transfer,u.deleted_at,u.first_download,u.first_upload,r.name,u.last_password_change"
	selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem"
	selectAdminFields  = "a.id,a.username,a.password,a.status,a.email,a.permissions,a.filters,a.additional_info,a.description,a.created_at,a.updated_at,a.last_login,r.name"
	selectAPIKeyFields = "key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id"
	selectShareFields  = "s.share_id,s.name,s.description,s.scope,s.paths,u.username,s.created_at,s.updated_at,s.last_use_at," +
		"s.expires_at,s.password,s.max_tokens,s.used_tokens,s.allow_from"
	selectGroupFields       = "id,name,description,created_at,updated_at,user_settings"
	selectEventActionFields = "id,name,description,type,options"
	selectRoleFields        = "id,name,description,created_at,updated_at"
	selectIPListEntryFields = "type,ipornet,mode,protocols,description,created_at,updated_at,deleted_at"
	selectMinimalFields     = "id,name"
)

// getSQLPlaceholders returns 100 positional bind-parameter placeholders.
// PostgreSQL and CockroachDB get "$1".."$100"; every other driver gets "?".
// The query builders below index sqlPlaceholders (presumably initialised from
// this function; verify), so no single query may use more than 100 parameters.
func getSQLPlaceholders() []string {
	const numPlaceholders = 100

	// Evaluate the driver check once instead of on each iteration.
	numbered := config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName

	placeholders := make([]string, 0, numPlaceholders)
	for i := 1; i <= numPlaceholders; i++ {
		if numbered {
			placeholders = append(placeholders, fmt.Sprintf("$%d", i))
		} else {
			placeholders = append(placeholders, "?")
		}
	}
	return placeholders
}

// getSQLQuotedName quotes an identifier for the configured driver: backticks
// for MySQL, double quotes for every other driver.
func getSQLQuotedName(name string) string { if
config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("`%s`", name)
	}
	return fmt.Sprintf(`"%s"`, name)
}

// getSelectEventRuleFields returns the event rule column list. "trigger" is
// quoted with backticks for MySQL and with double quotes for the other drivers.
func getSelectEventRuleFields() string {
	if config.Driver == MySQLDataProviderName {
		return "id,name,description,created_at,updated_at,`trigger`,conditions,deleted_at,status"
	}
	return `id,name,description,created_at,updated_at,"trigger",conditions,deleted_at,status`
}

// getCoalesceDefaultForRole returns the fallback expression used in the
// COALESCE((SELECT id ... WHERE name = ?), fallback) role lookups of the
// add/update user and admin queries. It returns "0" when a role name was
// supplied and "NULL" when role is empty.
// NOTE(review): "0" presumably makes an unknown role name fail (e.g. on a
// foreign key) instead of silently clearing the role. Confirm against the schema.
func getCoalesceDefaultForRole(role string) string {
	if role != "" {
		return "0"
	}
	return "NULL"
}

// getAddSessionQuery returns an upsert into the shared sessions table
// (placeholders: key, data, type, timestamp). On an existing row only data and
// timestamp are overwritten. MySQL uses ON DUPLICATE KEY UPDATE with quoted
// identifiers; the other drivers use ON CONFLICT(key,type) DO UPDATE.
func getAddSessionQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("INSERT INTO %s (`key`,`data`,`type`,`timestamp`) VALUES (%s,%s,%s,%s) "+
			"ON DUPLICATE KEY UPDATE `data`=VALUES(`data`), `timestamp`=VALUES(`timestamp`)",
			sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
	}
	return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key,type) DO UPDATE SET data= EXCLUDED.data, timestamp=EXCLUDED.timestamp`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getDeleteSessionQuery deletes the shared session matching (key, type).
// MySQL gets backtick-quoted column names.
func getDeleteSessionQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s AND `type` = %s",
			sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`DELETE FROM %s WHERE key = %s AND type = %s`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getSessionQuery selects the shared session matching (key, type).
// MySQL gets backtick-quoted column names.
func getSessionQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s AND `type` = %s",
			sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s AND type = %s`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getCleanupSessionsQuery deletes the shared sessions of one type whose
// timestamp is lower than the given value. Unlike the session queries above,
// it uses the same unquoted identifiers for every driver.
func getCleanupSessionsQuery() string {
	return fmt.Sprintf(`DELETE from %s WHERE type = %s AND timestamp < %s`, sqlTableSharedSessions,
		sqlPlaceholders[0], sqlPlaceholders[1])
}

// getAddDefenderHostQuery inserts a defender host with ban_time 0. If the ip
// already exists, only updated_at is refreshed. The non-MySQL variant also
// returns the row id (RETURNING id).
func getAddDefenderHostQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("INSERT INTO %s (`ip`,`updated_at`,`ban_time`) VALUES (%s,%s,0) ON DUPLICATE KEY UPDATE `updated_at`=VALUES(`updated_at`)",
			sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`INSERT INTO %s (ip,updated_at,ban_time) VALUES (%s,%s,0) ON CONFLICT (ip) DO UPDATE SET updated_at = EXCLUDED.updated_at RETURNING id`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getAddDefenderEventQuery inserts a scored event. host_id is resolved from
// the host's ip (third placeholder) with a subquery.
func getAddDefenderEventQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (date_time,score,host_id) VALUES (%s,%s,(SELECT id from %s WHERE ip = %s))`,
		sqlTableDefenderEvents, sqlPlaceholders[0], sqlPlaceholders[1], sqlTableDefenderHosts, sqlPlaceholders[2])
}

// getDefenderHostsQuery lists hosts updated at or after the first parameter,
// or with ban_time > 0. Results are ordered by updated_at descending and
// limited by the second parameter.
func getDefenderHostsQuery() string {
	return fmt.Sprintf(`SELECT id,ip,ban_time FROM %s WHERE updated_at >= %s OR ban_time > 0 ORDER BY updated_at DESC LIMIT %s`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDefenderHostQuery selects the host with the given ip, provided it was
// updated at or after the second parameter or has ban_time > 0.
func getDefenderHostQuery() string {
	return fmt.Sprintf(`SELECT id,ip,ban_time FROM %s WHERE ip = %s AND (updated_at >= %s OR ban_time > 0)`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDefenderEventsQuery sums event scores per host for the given host IDs,
// counting only events with date_time at or after the bound parameter. The
// IDs are int64 values formatted directly into the IN list. An empty slice
// yields "IN (0)" so the statement stays valid.
func getDefenderEventsQuery(hostIDs []int64) string {
	var sb strings.Builder
	for _, hID := range hostIDs {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(hID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("(0)")
	}
	return fmt.Sprintf(`SELECT host_id,SUM(score) FROM %s WHERE date_time >= %s AND host_id IN %s GROUP BY host_id`,
		sqlTableDefenderEvents, sqlPlaceholders[0], sb.String())
}

// getDefenderIsHostBannedQuery selects the host id for the given ip when its
// ban_time is at or after the second parameter.
func getDefenderIsHostBannedQuery() string {
	return fmt.Sprintf(`SELECT id FROM %s WHERE ip = %s AND ban_time >= %s`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDefenderIncrementBanTimeQuery adds the first parameter to ban_time of
// the host whose ip matches the second parameter.
func getDefenderIncrementBanTimeQuery() string {
	return fmt.Sprintf(`UPDATE %s SET ban_time = 
ban_time + %s WHERE ip = %s`, sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDefenderSetBanTimeQuery overwrites ban_time (first parameter) for the
// host with the given ip (second parameter).
func getDefenderSetBanTimeQuery() string {
	return fmt.Sprintf(`UPDATE %s SET ban_time = %s WHERE ip = %s`, sqlTableDefenderHosts,
		sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDeleteDefenderHostQuery deletes the defender host with the given ip.
func getDeleteDefenderHostQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE ip = %s`, sqlTableDefenderHosts, sqlPlaceholders[0])
}

// getDefenderHostsCleanupQuery deletes hosts that meet both conditions:
// ban_time is lower than the first parameter, and the host has no event with
// date_time after the second parameter.
func getDefenderHostsCleanupQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE ban_time < %s AND NOT EXISTS ( SELECT id FROM %s WHERE %s.host_id = %s.id AND %s.date_time > %s)`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlTableDefenderEvents, sqlTableDefenderEvents, sqlTableDefenderHosts,
		sqlTableDefenderEvents, sqlPlaceholders[1])
}

// getDefenderEventsCleanupQuery deletes events whose date_time is lower than
// the given value.
func getDefenderEventsCleanupQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE date_time < %s`, sqlTableDefenderEvents, sqlPlaceholders[0])
}

// getIPListEntryQuery selects one IP list entry by (type, ipornet). Entries
// with deleted_at > 0 (soft-deleted) are excluded.
func getIPListEntryQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE type = %s AND ipornet = %s AND deleted_at = 0`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getIPListEntriesQuery builds a cursor-style page of non soft-deleted IP list
// entries of one type, ordered by ipornet in the given direction.
//
// Placeholders are consumed in this order, and callers must bind their
// arguments the same way:
//   1. type.
//   2. from, only when non-empty. With OrderASC it selects entries strictly
//      greater than from; otherwise entries strictly lower.
//   3. filter, only when non-empty, matched with LIKE.
//   4. The LIMIT, only when limit > 0.
//
// NOTE(review): order is written into the SQL verbatim. Callers are assumed to
// pass only a fixed sort direction such as OrderASC.
func getIPListEntriesQuery(filter, from, order string, limit int) string {
	var sb strings.Builder
	var idx int

	sb.WriteString("SELECT ")
	sb.WriteString(selectIPListEntryFields)
	sb.WriteString(" FROM ")
	sb.WriteString(sqlTableIPLists)
	sb.WriteString(" WHERE type = ")
	sb.WriteString(sqlPlaceholders[idx])
	idx++
	if from != "" {
		if order == OrderASC {
			sb.WriteString(" AND ipornet > ")
		} else {
			sb.WriteString(" AND ipornet < ")
		}
		sb.WriteString(sqlPlaceholders[idx])
		idx++
	}
	if filter != "" {
		sb.WriteString(" AND ipornet LIKE ")
		sb.WriteString(sqlPlaceholders[idx])
		idx++
	}
	sb.WriteString(" AND deleted_at = 0 ")
	sb.WriteString(" ORDER BY ipornet ")
	sb.WriteString(order)
	if limit > 0 {
		sb.WriteString(" LIMIT ")
		sb.WriteString(sqlPlaceholders[idx])
	}
	return sb.String()
}

// getCountIPListEntriesQuery counts the non soft-deleted entries of one type.
func getCountIPListEntriesQuery() string {
	return fmt.Sprintf(`SELECT count(ipornet) FROM %s WHERE type = %s AND 
deleted_at = 0`, sqlTableIPLists, sqlPlaceholders[0])
}

// getCountAllIPListEntriesQuery counts all non soft-deleted IP list entries.
func getCountAllIPListEntriesQuery() string {
	return fmt.Sprintf(`SELECT count(ipornet) FROM %s WHERE deleted_at = 0`, sqlTableIPLists)
}

// getIPListEntriesForIPQueryPg selects the entries of one type whose
// [first, last] range contains the address in the second placeholder. The
// address is cast with PostgreSQL's ::inet.
func getIPListEntriesForIPQueryPg() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE type = %s AND deleted_at = 0 AND %s::inet BETWEEN first AND last`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getIPListEntriesForIPQueryNoPg is the non-PostgreSQL variant. It filters on
// ip_type (second placeholder) and compares the third placeholder against
// first/last. The expected encoding of those columns is defined by the
// callers; verify before changing.
func getIPListEntriesForIPQueryNoPg() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE type = %s AND deleted_at = 0 AND ip_type = %s AND %s BETWEEN first AND last`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}

// getRecentlyUpdatedIPListQuery selects the entries updated at or after the
// given value, plus every soft-deleted entry.
func getRecentlyUpdatedIPListQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE updated_at >= %s OR deleted_at > 0`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0])
}

// getDumpListEntriesQuery selects every non soft-deleted IP list entry.
func getDumpListEntriesQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at = 0`, selectIPListEntryFields, sqlTableIPLists)
}

// getAddIPListEntryQuery inserts an IP list entry. The ten placeholders follow
// the column list, and deleted_at is always 0.
func getAddIPListEntryQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (type,ipornet,first,last,ip_type,protocols,description,mode,created_at,updated_at,deleted_at) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0)`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9])
}

// getUpdateIPListEntryQuery updates mode, protocols, description and
// updated_at of the entry identified by (type, ipornet).
func getUpdateIPListEntryQuery() string {
	return fmt.Sprintf(`UPDATE %s SET mode=%s,protocols=%s,description=%s,updated_at=%s WHERE type = %s AND ipornet = %s`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5])
}

// getDeleteIPListEntryQuery removes the entry identified by (type, ipornet).
// With softDelete, the row is kept and only updated_at/deleted_at are set
// (placeholders: updated_at, deleted_at, type, ipornet). Otherwise the row is
// deleted (placeholders: type, ipornet).
func getDeleteIPListEntryQuery(softDelete bool) string {
	if softDelete {
		return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE type = %s AND ipornet = %s`,
			sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
	}
	return fmt.Sprintf(`DELETE FROM %s WHERE type = %s AND ipornet = %s`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getRemoveSoftDeletedIPListEntryQuery deletes the entry identified by
// (type, ipornet), but only if it is already soft-deleted.
func getRemoveSoftDeletedIPListEntryQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE type = %s AND ipornet = %s AND deleted_at > 0`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getConfigsQuery reads the configs column of the configs table (one row).
func getConfigsQuery() string {
	return fmt.Sprintf(`SELECT configs FROM %s LIMIT 1`, sqlTableConfigs)
}

// getUpdateConfigsQuery overwrites the configs column. There is no WHERE
// clause, so every row of the configs table is updated.
func getUpdateConfigsQuery() string {
	return fmt.Sprintf(`UPDATE %s SET configs = %s`, sqlTableConfigs, sqlPlaceholders[0])
}

// getRoleByNameQuery selects a role by name.
func getRoleByNameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectRoleFields, sqlTableRoles, sqlPlaceholders[0])
}

// getRolesQuery returns a page of roles ordered by name; LIMIT and OFFSET are
// the two placeholders. When minimal is true, only id and name are selected.
// NOTE(review): order is written into the SQL verbatim. Callers are assumed to
// pass only a fixed sort direction.
func getRolesQuery(order string, minimal bool) string {
	var fieldSelection string
	if minimal {
		fieldSelection = selectMinimalFields
	} else {
		fieldSelection = selectRoleFields
	}
	return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection,
		sqlTableRoles, order, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getUsersWithRolesQuery lists (role id, username) pairs for the users that
// belong to any of the given roles. Role IDs are int64 values formatted
// directly into the IN list.
//
// An empty roles slice produces "IN (0)", the same fallback used by the other
// ID-list builders in this file. Previously it produced "IN " with no list,
// which is a SQL syntax error. The "(0)" list presumably matches no role,
// since IDs start at 1.
func getUsersWithRolesQuery(roles []Role) string {
	var sb strings.Builder
	for _, r := range roles {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(r.ID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("(0)")
	}
	return fmt.Sprintf(`SELECT r.id, u.username FROM %s u INNER JOIN %s r ON u.role_id = r.id WHERE u.role_id IN %s`,
		sqlTableUsers, sqlTableRoles, sb.String())
}

// getAdminsWithRolesQuery lists (role id, username) pairs for the admins that
// belong to any of the given roles. Role IDs are int64 values formatted
// directly into the IN list. An empty roles slice produces "IN (0)" instead of
// invalid SQL.
func getAdminsWithRolesQuery(roles []Role) string {
	var sb strings.Builder
	for _, r := range roles {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(r.ID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("(0)")
	}
	return fmt.Sprintf(`SELECT r.id, a.username FROM %s a INNER JOIN %s r ON a.role_id = r.id WHERE a.role_id IN %s`,
		sqlTableAdmins, sqlTableRoles, sb.String())
}

// getDumpRolesQuery selects all roles, with no filtering or pagination.
func getDumpRolesQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s`,
selectRoleFields, sqlTableRoles)
}

// getAddRoleQuery inserts a role (name, description, created_at, updated_at).
func getAddRoleQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at) VALUES (%s,%s,%s,%s)`,
		sqlTableRoles, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getUpdateRoleQuery updates description and updated_at of the role with the
// given name.
func getUpdateRoleQuery() string {
	return fmt.Sprintf(`UPDATE %s SET description=%s,updated_at=%s WHERE name = %s`,
		sqlTableRoles, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}

// getDeleteRoleQuery deletes the role with the given name.
func getDeleteRoleQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableRoles, sqlPlaceholders[0])
}

// getGroupByNameQuery selects a group by name. The groups table name always
// goes through getSQLQuotedName, presumably because GROUPS is a reserved word
// in some SQL dialects.
func getGroupByNameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectGroupFields,
		getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0])
}

// getGroupsQuery returns a page of groups ordered by name; LIMIT and OFFSET
// are bound. When minimal is true, only id and name are selected. order is
// written into the SQL verbatim.
func getGroupsQuery(order string, minimal bool) string {
	var fieldSelection string
	if minimal {
		fieldSelection = selectMinimalFields
	} else {
		fieldSelection = selectGroupFields
	}
	return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection,
		getSQLQuotedName(sqlTableGroups), order, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getGroupsWithNamesQuery selects the groups whose name matches one of
// numArgs bound parameters. When numArgs is 0, the list becomes ('') so the
// SQL stays valid. numArgs must not exceed len(sqlPlaceholders).
func getGroupsWithNamesQuery(numArgs int) string {
	var sb strings.Builder
	for idx := 0; idx < numArgs; idx++ {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(sqlPlaceholders[idx])
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("('')")
	}
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name in %s`, selectGroupFields,
		getSQLQuotedName(sqlTableGroups), sb.String())
}

// getUsersInGroupsQuery selects the usernames of users mapped to any of the
// groups named by numArgs bound parameters. When numArgs is 0, the list
// becomes ('').
func getUsersInGroupsQuery(numArgs int) string {
	var sb strings.Builder
	for idx := 0; idx < numArgs; idx++ {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(sqlPlaceholders[idx])
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("('')")
	}
	return fmt.Sprintf(`SELECT username FROM %s WHERE id IN (SELECT user_id from %s WHERE group_id IN (SELECT id FROM %s WHERE name IN %s))`,
		sqlTableUsers, sqlTableUsersGroupsMapping, getSQLQuotedName(sqlTableGroups), sb.String())
}

// getDumpGroupsQuery selects every group, with no filtering or pagination.
func getDumpGroupsQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s`, selectGroupFields, getSQLQuotedName(sqlTableGroups))
}

// getAddGroupQuery inserts a group (name, description, created_at,
// updated_at, user_settings).
func getAddGroupQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at,user_settings) VALUES (%s,%s,%s,%s,%s)`,
		getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
		sqlPlaceholders[3], sqlPlaceholders[4])
}

// getUpdateGroupQuery updates description, user_settings and updated_at of
// the group with the given name.
func getUpdateGroupQuery() string {
	return fmt.Sprintf(`UPDATE %s SET description=%s,user_settings=%s,updated_at=%s WHERE name = %s`,
		getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
		sqlPlaceholders[3])
}

// getDeleteGroupQuery deletes the group with the given name.
func getDeleteGroupQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0])
}

// getAdminByUsernameQuery selects an admin by username. Roles are
// left-joined, so r.name is NULL for admins without a role.
func getAdminByUsernameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s r on r.id = a.role_id WHERE a.username = %s`,
		selectAdminFields, sqlTableAdmins, sqlTableRoles, sqlPlaceholders[0])
}

// getAdminsQuery returns a page of admins ordered by username. order is
// written into the SQL verbatim; LIMIT and OFFSET are bound.
func getAdminsQuery(order string) string {
	return fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s r on r.id = a.role_id ORDER BY a.username %s LIMIT %s OFFSET %s`,
		selectAdminFields, sqlTableAdmins, sqlTableRoles, order, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDumpAdminsQuery selects every admin together with its role name.
func getDumpAdminsQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s r on r.id = a.role_id`,
		selectAdminFields, sqlTableAdmins, sqlTableRoles)
}

// getAddAdminQuery inserts an admin. Placeholders 0-9 fill the columns up to
// updated_at, and last_login starts at 0. role_id is resolved from the role
// name in placeholder 10, with the fallback from getCoalesceDefaultForRole.
func getAddAdminQuery(role string) string {
	return fmt.Sprintf(`INSERT INTO %s (username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login,role_id) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,COALESCE((SELECT id from %s WHERE name = %s),%s))`,
		sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
sqlPlaceholders[9], sqlTableRoles, sqlPlaceholders[10], getCoalesceDefaultForRole(role))
}

// getUpdateAdminQuery updates the admin identified by username (placeholder
// 9). Placeholders 0-7 are the updated columns. role_id is resolved from the
// role name in placeholder 8, with the fallback from getCoalesceDefaultForRole.
func getUpdateAdminQuery(role string) string {
	return fmt.Sprintf(`UPDATE %s SET password=%s,status=%s,email=%s,permissions=%s,filters=%s,additional_info=%s,description=%s,updated_at=%s, role_id=COALESCE((SELECT id from %s WHERE name = %s),%s) WHERE username = %s`,
		sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlTableRoles,
		sqlPlaceholders[8], getCoalesceDefaultForRole(role), sqlPlaceholders[9])
}

// getDeleteAdminQuery deletes the admin with the given username.
func getDeleteAdminQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0])
}

// getShareByIDQuery selects a share by share_id, joined with its owning user.
// When filterUser is true, the owner's username (second placeholder) must
// also match.
func getShareByIDQuery(filterUser bool) string {
	if filterUser {
		return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id WHERE s.share_id = %s AND u.username = %s`,
			selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id WHERE s.share_id = %s`,
		selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0])
}

// getSharesQuery returns a page of one user's shares ordered by share_id.
// Placeholders are username, LIMIT and OFFSET; order is written verbatim.
func getSharesQuery(order string) string {
	return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id WHERE u.username = %s ORDER BY s.share_id %s LIMIT %s OFFSET %s`,
		selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2])
}

// getDumpSharesQuery selects every share, joined with its owning user.
func getDumpSharesQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id`,
		selectShareFields, sqlTableShares, sqlTableUsers)
}

// getAddShareQuery inserts a share. The 14 placeholders follow the column list.
func getAddShareQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (share_id,name,description,scope,paths,created_at,updated_at,last_use_at, expires_at,password,max_tokens,used_tokens,allow_from,user_id) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)`,
		sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13])
}

// getUpdateShareRestoreQuery overwrites every share column, including
// created_at, last_use_at and used_tokens, for the share identified by
// share_id (placeholder 13). The "Restore" in the name suggests it is used
// when restoring backups; verify with callers.
func getUpdateShareRestoreQuery() string {
	return fmt.Sprintf(`UPDATE %s SET name=%s,description=%s,scope=%s,paths=%s,created_at=%s,updated_at=%s, last_use_at=%s,expires_at=%s,password=%s,max_tokens=%s,used_tokens=%s,allow_from=%s,user_id=%s WHERE share_id = %s`,
		sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13])
}

// getUpdateShareQuery updates the editable columns of the share identified by
// share_id (placeholder 10). created_at, last_use_at and used_tokens are not
// touched.
func getUpdateShareQuery() string {
	return fmt.Sprintf(`UPDATE %s SET name=%s,description=%s,scope=%s,paths=%s,updated_at=%s,expires_at=%s, password=%s,max_tokens=%s,allow_from=%s,user_id=%s WHERE share_id = %s`,
		sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9], sqlPlaceholders[10])
}

// getDeleteShareQuery deletes the share with the given share_id.
func getDeleteShareQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE share_id = %s`, sqlTableShares, sqlPlaceholders[0])
}

// getAPIKeyByIDQuery selects an API key by key_id.
func getAPIKeyByIDQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE key_id = %s`, selectAPIKeyFields, sqlTableAPIKeys, sqlPlaceholders[0])
}

// getAPIKeysQuery returns a page of API keys ordered by key_id. order is
// written verbatim; LIMIT and OFFSET are bound.
func getAPIKeysQuery(order string) string {
	return fmt.Sprintf(`SELECT %s FROM %s ORDER BY key_id %s LIMIT %s OFFSET %s`,
		selectAPIKeyFields, sqlTableAPIKeys, order, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDumpAPIKeysQuery selects every API key.
func getDumpAPIKeysQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s`, selectAPIKeyFields, sqlTableAPIKeys)
}

// getAddAPIKeyQuery inserts an API key. The 11 placeholders follow the column
// list.
func getAddAPIKeyQuery() string {
	return fmt.Sprintf(`INSERT INTO %s 
(key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)`,
		sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9], sqlPlaceholders[10])
}

// getUpdateAPIKeyQuery updates name, scope, expires_at, user_id, admin_id,
// description and updated_at of the API key with the given key_id
// (placeholder 7).
func getUpdateAPIKeyQuery() string {
	return fmt.Sprintf(`UPDATE %s SET name=%s,scope=%s,expires_at=%s,user_id=%s,admin_id=%s,description=%s,updated_at=%s WHERE key_id = %s`,
		sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7])
}

// getDeleteAPIKeyQuery deletes the API key with the given key_id.
func getDeleteAPIKeyQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE key_id = %s`, sqlTableAPIKeys, sqlPlaceholders[0])
}

// getRelatedUsersForAPIKeysQuery selects id and username of the users
// referenced by the given API keys; keys with userID 0 are skipped. The IDs
// are int64 values formatted directly into the IN list. If none remain, the
// list is (0).
func getRelatedUsersForAPIKeysQuery(apiKeys []APIKey) string {
	var sb strings.Builder
	for _, k := range apiKeys {
		if k.userID == 0 {
			continue
		}
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(k.userID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("(0)")
	}
	return fmt.Sprintf(`SELECT id,username FROM %s WHERE id IN %s ORDER BY username`, sqlTableUsers, sb.String())
}

// getRelatedAdminsForAPIKeysQuery selects id and username of the admins
// referenced by the given API keys; keys with adminID 0 are skipped. If none
// remain, the IN list is (0).
func getRelatedAdminsForAPIKeysQuery(apiKeys []APIKey) string {
	var sb strings.Builder
	for _, k := range apiKeys {
		if k.adminID == 0 {
			continue
		}
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(k.adminID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("(0)")
	}
	return fmt.Sprintf(`SELECT id,username FROM %s WHERE id IN %s ORDER BY username`, sqlTableAdmins, sb.String())
}

// getUserByUsernameQuery selects a non soft-deleted user by username. When
// role is non-empty, the user must also belong to the role named by the second
// placeholder; the role value itself is not embedded in the SQL.
func getUserByUsernameQuery(role string) string {
	if role == "" {
		return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.username = %s AND u.deleted_at = 0`,
			selectUserFields, sqlTableUsers,
sqlTableRoles, sqlPlaceholders[0]) } return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.username = %s AND u.deleted_at = 0 AND u.role_id is NOT NULL AND r.name = %s`, selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0], sqlPlaceholders[1]) } func getUsersQuery(order, role string) string { if role == "" { return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.deleted_at = 0 ORDER BY u.username %s LIMIT %s OFFSET %s`, selectUserFields, sqlTableUsers, sqlTableRoles, order, sqlPlaceholders[0], sqlPlaceholders[1]) } return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.deleted_at = 0 AND u.role_id is NOT NULL AND r.name = %s ORDER BY u.username %s LIMIT %s OFFSET %s`, selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2]) } func getUsersForQuotaCheckQuery(numArgs int) string { var sb strings.Builder for idx := 0; idx < numArgs; idx++ { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(sqlPlaceholders[idx]) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size,total_data_transfer,upload_data_transfer, download_data_transfer,used_upload_data_transfer,used_download_data_transfer,filters FROM %s WHERE username IN %s`, sqlTableUsers, sb.String()) } func getRecentlyUpdatedUsersQuery() string { return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.updated_at >= %s OR u.deleted_at > 0`, selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0]) } func getDumpUsersQuery() string { return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.deleted_at = 0`, selectUserFields, sqlTableUsers, sqlTableRoles) } func getDumpFoldersQuery() string { return fmt.Sprintf(`SELECT %s FROM %s`, selectFolderFields, sqlTableFolders) } func getUpdateTransferQuotaQuery(reset bool) string { 
if reset {
		return fmt.Sprintf(`UPDATE %s SET used_upload_data_transfer = %s,used_download_data_transfer = %s,last_quota_update = %s WHERE username = %s`,
			sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
	}
	return fmt.Sprintf(`UPDATE %s SET used_upload_data_transfer = used_upload_data_transfer + %s, used_download_data_transfer = used_download_data_transfer + %s,last_quota_update = %s WHERE username = %s`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getUpdateQuotaQuery updates a user's used_quota_size, used_quota_files and
// last_quota_update. Placeholders are size, files, last_quota_update and
// username. With reset the values are overwritten; otherwise they are added.
func getUpdateQuotaQuery(reset bool) string {
	if reset {
		return fmt.Sprintf(`UPDATE %s SET used_quota_size = %s,used_quota_files = %s,last_quota_update = %s WHERE username = %s`,
			sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
	}
	return fmt.Sprintf(`UPDATE %s SET used_quota_size = used_quota_size + %s,used_quota_files = used_quota_files + %s,last_quota_update = %s WHERE username = %s`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getAdminSignatureQuery selects updated_at of the admin with the given
// username.
func getAdminSignatureQuery() string {
	return fmt.Sprintf(`SELECT updated_at FROM %s WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0])
}

// getUserSignatureQuery selects updated_at of the user with the given
// username.
func getUserSignatureQuery() string {
	return fmt.Sprintf(`SELECT updated_at FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0])
}

// getSetUpdateAtQuery sets updated_at (first parameter) for the user with the
// given username.
func getSetUpdateAtQuery() string {
	return fmt.Sprintf(`UPDATE %s SET updated_at = %s WHERE username = %s`, sqlTableUsers,
		sqlPlaceholders[0], sqlPlaceholders[1])
}

// getSetFirstUploadQuery sets first_upload for a user, but only while it is
// still 0, so an existing value is never overwritten.
func getSetFirstUploadQuery() string {
	return fmt.Sprintf(`UPDATE %s SET first_upload = %s WHERE username = %s AND first_upload = 0`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getSetFirstDownloadQuery sets first_download for a user, but only while it
// is still 0, so an existing value is never overwritten.
func getSetFirstDownloadQuery() string {
	return fmt.Sprintf(`UPDATE %s SET first_download = %s WHERE username = %s AND first_download = 0`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getUpdateLastLoginQuery sets last_login for the user with the given
// username.
func getUpdateLastLoginQuery() string {
	return fmt.Sprintf(`UPDATE %s SET 
last_login = %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getUpdateAdminLastLoginQuery sets last_login for the admin with the given
// username.
func getUpdateAdminLastLoginQuery() string {
	return fmt.Sprintf(`UPDATE %s SET last_login = %s WHERE username = %s`, sqlTableAdmins,
		sqlPlaceholders[0], sqlPlaceholders[1])
}

// getUpdateAPIKeyLastUseQuery sets last_use_at for the API key with the given
// key_id.
func getUpdateAPIKeyLastUseQuery() string {
	return fmt.Sprintf(`UPDATE %s SET last_use_at = %s WHERE key_id = %s`, sqlTableAPIKeys,
		sqlPlaceholders[0], sqlPlaceholders[1])
}

// getUpdateShareLastUseQuery sets last_use_at and adds the second parameter
// to used_tokens, for the share with the given share_id.
func getUpdateShareLastUseQuery() string {
	return fmt.Sprintf(`UPDATE %s SET last_use_at = %s, used_tokens = used_tokens +%s WHERE share_id = %s`,
		sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}

// getQuotaQuery selects a user's used quota size/files and used
// upload/download transfer.
func getQuotaQuery() string {
	return fmt.Sprintf(`SELECT used_quota_size,used_quota_files,used_upload_data_transfer, used_download_data_transfer FROM %s WHERE username = %s`,
		sqlTableUsers, sqlPlaceholders[0])
}

// getAddUserQuery inserts a user. These columns start at 0:
// used_quota_size, used_quota_files, last_quota_update, last_login,
// used_upload_data_transfer, used_download_data_transfer, deleted_at,
// first_download and first_upload.
// Placeholders 0-23 fill the remaining columns in list order. Placeholder 24
// is the role name resolved into role_id (fallback from
// getCoalesceDefaultForRole), and placeholder 25 is last_password_change.
func getAddUserQuery(role string) string {
	return fmt.Sprintf(`INSERT INTO %s (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions, used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date,filters, filesystem,additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer, used_upload_data_transfer,used_download_data_transfer,deleted_at,first_download,first_upload,role_id,last_password_change) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,0,0,%s,%s,%s,0,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,0,0,0,0, COALESCE((SELECT id from %s WHERE name=%s),%s),%s)`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13],
		sqlPlaceholders[14], sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18],
		sqlPlaceholders[19], sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22], sqlPlaceholders[23],
		sqlTableRoles, sqlPlaceholders[24], getCoalesceDefaultForRole(role), sqlPlaceholders[25])
}

// getUpdateUserQuery updates the user identified by username (placeholder
// 24). Placeholders 0-21 are the updated columns, 22 is the role name resolved
// into role_id (fallback from getCoalesceDefaultForRole), and 23 is
// last_password_change.
func getUpdateUserQuery(role string) string {
	return fmt.Sprintf(`UPDATE %s SET password=%s,public_keys=%s,home_dir=%s,uid=%s,gid=%s,max_sessions=%s,quota_size=%s, quota_files=%s,permissions=%s,upload_bandwidth=%s,download_bandwidth=%s,status=%s,expiration_date=%s,filters=%s,filesystem=%s, additional_info=%s,description=%s,email=%s,updated_at=%s,upload_data_transfer=%s,download_data_transfer=%s, total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE username = %s`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
		sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13],
		sqlPlaceholders[14], sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18],
		sqlPlaceholders[19], sqlPlaceholders[20], sqlPlaceholders[21], sqlTableRoles, sqlPlaceholders[22],
		getCoalesceDefaultForRole(role), sqlPlaceholders[23], sqlPlaceholders[24])
}

// getUpdateUserPasswordQuery sets password and updated_at for the user with
// the given username.
func getUpdateUserPasswordQuery() string {
	return fmt.Sprintf(`UPDATE %s SET password=%s,updated_at=%s WHERE username = %s`,
		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}

// getDeleteUserQuery removes the user with the given username. With
// softDelete, only updated_at/deleted_at are set (placeholders: updated_at,
// deleted_at, username). Otherwise the row is deleted.
func getDeleteUserQuery(softDelete bool) string {
	if softDelete {
		return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE username = %s`,
			sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
	}
	return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0])
}

// getRemoveSoftDeletedUserQuery deletes the user with the given username, but
// only if it is already soft-deleted.
func getRemoveSoftDeletedUserQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE username = %s AND deleted_at > 0`, sqlTableUsers, sqlPlaceholders[0])
}

// getFolderByNameQuery selects a folder by name.
func
getFolderByNameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0])
}

// getAddFolderQuery inserts a folder. The 7 placeholders follow the column
// list.
func getAddFolderQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem) VALUES (%s,%s,%s,%s,%s,%s,%s)`,
		sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
}

// getUpdateFolderQuery updates path, description and filesystem of the folder
// with the given name.
func getUpdateFolderQuery() string {
	return fmt.Sprintf(`UPDATE %s SET path=%s,description=%s,filesystem=%s WHERE name = %s`,
		sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getDeleteFolderQuery deletes the folder with the given name.
func getDeleteFolderQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0])
}

// getClearUserGroupMappingQuery removes every group mapping of the user with
// the given username.
func getClearUserGroupMappingQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE user_id = (SELECT id FROM %s WHERE username = %s)`,
		sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0])
}

// getAddUserGroupMappingQuery maps a user (by username) to a group (by name)
// with the given group_type and sort_order.
func getAddUserGroupMappingQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (user_id,group_id,group_type,sort_order) VALUES ((SELECT id FROM %s WHERE username = %s), (SELECT id FROM %s WHERE name = %s),%s,%s)`,
		sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0], getSQLQuotedName(sqlTableGroups),
		sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getClearAdminGroupMappingQuery removes every group mapping of the admin with
// the given username.
func getClearAdminGroupMappingQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE admin_id = (SELECT id FROM %s WHERE username = %s)`,
		sqlTableAdminsGroupsMapping, sqlTableAdmins, sqlPlaceholders[0])
}

// getAddAdminGroupMappingQuery maps an admin (by username) to a group (by
// name) with the given options and sort_order.
func getAddAdminGroupMappingQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (admin_id,group_id,options,sort_order) VALUES ((SELECT id FROM %s WHERE username = %s), (SELECT id FROM %s WHERE name = %s),%s,%s)`,
		sqlTableAdminsGroupsMapping, sqlTableAdmins, sqlPlaceholders[0], getSQLQuotedName(sqlTableGroups),
		sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

// getClearGroupFolderMappingQuery removes every folder mapping of the group
// with the given name.
func getClearGroupFolderMappingQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE group_id = (SELECT id FROM %s WHERE name = %s)`,
		sqlTableGroupsFoldersMapping, getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0])
}

// getAddGroupFolderMappingQuery maps a folder (by name) into a group (by
// name). The mapping carries virtual_path, quota_size, quota_files and
// sort_order.
func getAddGroupFolderMappingQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,group_id,sort_order) VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE name = %s),%s)`,
		sqlTableGroupsFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
		sqlTableFolders, sqlPlaceholders[3], getSQLQuotedName(sqlTableGroups), sqlPlaceholders[4], sqlPlaceholders[5])
}

// getClearUserFolderMappingQuery removes every folder mapping of the user with
// the given username.
func getClearUserFolderMappingQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE user_id = (SELECT id FROM %s WHERE username = %s)`,
		sqlTableUsersFoldersMapping, sqlTableUsers, sqlPlaceholders[0])
}

// getAddUserFolderMappingQuery maps a folder (by name) to a user (by
// username). The mapping carries virtual_path, quota_size, quota_files and
// sort_order.
func getAddUserFolderMappingQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,user_id,sort_order) VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE username = %s),%s)`,
		sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
		sqlTableFolders, sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4], sqlPlaceholders[5])
}

// getFoldersQuery returns a page of folders ordered by name; LIMIT and OFFSET
// are bound. When minimal is true, only id and name are selected. order is
// written into the SQL verbatim.
func getFoldersQuery(order string, minimal bool) string {
	var fieldSelection string
	if minimal {
		fieldSelection = selectMinimalFields
	} else {
		fieldSelection = selectFolderFields
	}
	return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection,
		sqlTableFolders, order, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getUpdateFolderQuotaQuery(reset bool) string { if reset { return fmt.Sprintf(`UPDATE %s SET used_quota_size = %s,used_quota_files = %s,last_quota_update = %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } return fmt.Sprintf(`UPDATE %s SET used_quota_size = 
used_quota_size + %s,used_quota_files = used_quota_files + %s,last_quota_update = %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getQuotaFolderQuery() string { return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0]) } func getRelatedGroupsForUsersQuery(users []User) string { var sb strings.Builder for _, u := range users { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(u.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT g.name,ug.group_type,ug.user_id FROM %s g INNER JOIN %s ug ON g.id = ug.group_id WHERE ug.user_id IN %s ORDER BY ug.sort_order`, getSQLQuotedName(sqlTableGroups), sqlTableUsersGroupsMapping, sb.String()) } func getRelatedGroupsForAdminsQuery(admins []Admin) string { var sb strings.Builder for _, a := range admins { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(a.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT g.name,ag.options,ag.admin_id FROM %s g INNER JOIN %s ag ON g.id = ag.group_id WHERE ag.admin_id IN %s ORDER BY ag.sort_order`, getSQLQuotedName(sqlTableGroups), sqlTableAdminsGroupsMapping, sb.String()) } func getRelatedFoldersForUsersQuery(users []User) string { var sb strings.Builder for _, u := range users { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(u.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE fm.user_id IN %s ORDER BY fm.sort_order`, sqlTableFolders, sqlTableUsersFoldersMapping, sb.String()) } func 
getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string { var sb strings.Builder for _, f := range folders { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(f.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT fm.folder_id,u.username FROM %s fm INNER JOIN %s u ON fm.user_id = u.id WHERE fm.folder_id IN %s ORDER BY u.username`, sqlTableUsersFoldersMapping, sqlTableUsers, sb.String()) } func getRelatedGroupsForFoldersQuery(folders []vfs.BaseVirtualFolder) string { var sb strings.Builder for _, f := range folders { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(f.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT fm.folder_id,g.name FROM %s fm INNER JOIN %s g ON fm.group_id = g.id WHERE fm.folder_id IN %s ORDER BY g.name`, sqlTableGroupsFoldersMapping, getSQLQuotedName(sqlTableGroups), sb.String()) } func getRelatedUsersForGroupsQuery(groups []Group) string { var sb strings.Builder for _, g := range groups { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(g.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT um.group_id,u.username FROM %s um INNER JOIN %s u ON um.user_id = u.id WHERE um.group_id IN %s ORDER BY u.username`, sqlTableUsersGroupsMapping, sqlTableUsers, sb.String()) } func getRelatedAdminsForGroupsQuery(groups []Group) string { var sb strings.Builder for _, g := range groups { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(g.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT am.group_id,a.username FROM %s am INNER JOIN %s a ON am.admin_id = a.id WHERE am.group_id IN %s ORDER BY a.username`, sqlTableAdminsGroupsMapping, sqlTableAdmins, sb.String()) } func getRelatedFoldersForGroupsQuery(groups []Group) string { var sb 
strings.Builder for _, g := range groups { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(g.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, fm.quota_size,fm.quota_files,fm.group_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE fm.group_id IN %s ORDER BY fm.sort_order`, sqlTableFolders, sqlTableGroupsFoldersMapping, sb.String()) } func getActiveTransfersQuery() string { return fmt.Sprintf(`SELECT transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size, current_ul_size,current_dl_size,created_at,updated_at FROM %s WHERE updated_at > %s`, sqlTableActiveTransfers, sqlPlaceholders[0]) } func getAddActiveTransferQuery() string { return fmt.Sprintf(`INSERT INTO %s (transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size, current_ul_size,current_dl_size,created_at,updated_at) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)`, sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10]) } func getUpdateActiveTransferSizesQuery() string { return fmt.Sprintf(`UPDATE %s SET current_ul_size=%s,current_dl_size=%s,updated_at=%s WHERE connection_id = %s AND transfer_id = %s`, sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4]) } func getRemoveActiveTransferQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE connection_id = %s AND transfer_id = %s`, sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1]) } func getCleanupActiveTransfersQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE updated_at < %s`, sqlTableActiveTransfers, sqlPlaceholders[0]) } func 
getRelatedRulesForActionsQuery(actions []BaseEventAction) string { var sb strings.Builder for _, a := range actions { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(a.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT am.action_id,r.name FROM %s am INNER JOIN %s r ON am.rule_id = r.id WHERE am.action_id IN %s ORDER BY r.name ASC`, sqlTableRulesActionsMapping, sqlTableEventsRules, sb.String()) } func getEventsActionsQuery(order string, minimal bool) string { var fieldSelection string if minimal { fieldSelection = selectMinimalFields } else { fieldSelection = selectEventActionFields } return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection, sqlTableEventsActions, order, sqlPlaceholders[0], sqlPlaceholders[1]) } func getDumpEventActionsQuery() string { return fmt.Sprintf(`SELECT %s FROM %s`, selectEventActionFields, sqlTableEventsActions) } func getEventActionByNameQuery() string { return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectEventActionFields, sqlTableEventsActions, sqlPlaceholders[0]) } func getAddEventActionQuery() string { return fmt.Sprintf(`INSERT INTO %s (name,description,type,options) VALUES (%s,%s,%s,%s)`, sqlTableEventsActions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getUpdateEventActionQuery() string { return fmt.Sprintf(`UPDATE %s SET description=%s,type=%s,options=%s WHERE name = %s`, sqlTableEventsActions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getDeleteEventActionQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableEventsActions, sqlPlaceholders[0]) } func getEventRulesQuery(order string) string { return fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at = 0 ORDER BY name %s LIMIT %s OFFSET %s`, getSelectEventRuleFields(), sqlTableEventsRules, order, sqlPlaceholders[0], sqlPlaceholders[1]) } func 
getDumpEventRulesQuery() string { return fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at = 0`, getSelectEventRuleFields(), sqlTableEventsRules) } func getRecentlyUpdatedRulesQuery() string { return fmt.Sprintf(`SELECT %s FROM %s WHERE updated_at >= %s OR deleted_at > 0`, getSelectEventRuleFields(), sqlTableEventsRules, sqlPlaceholders[0]) } func getEventRulesByNameQuery() string { return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s AND deleted_at = 0`, getSelectEventRuleFields(), sqlTableEventsRules, sqlPlaceholders[0]) } func getAddEventRuleQuery() string { return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at,%s,conditions,deleted_at,status) VALUES (%s,%s,%s,%s,%s,%s,0,%s)`, sqlTableEventsRules, getSQLQuotedName("trigger"), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6]) } func getUpdateEventRuleQuery() string { return fmt.Sprintf(`UPDATE %s SET description=%s,updated_at=%s,%s=%s,conditions=%s,status=%s WHERE name = %s`, sqlTableEventsRules, sqlPlaceholders[0], sqlPlaceholders[1], getSQLQuotedName("trigger"), sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5]) } func getDeleteEventRuleQuery(softDelete bool) string { if softDelete { return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE name = %s`, sqlTableEventsRules, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) } return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableEventsRules, sqlPlaceholders[0]) } func getRemoveSoftDeletedRuleQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE name = %s AND deleted_at > 0`, sqlTableEventsRules, sqlPlaceholders[0]) } func getClearRuleActionMappingQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE rule_id = (SELECT id FROM %s WHERE name = %s)`, sqlTableRulesActionsMapping, sqlTableEventsRules, sqlPlaceholders[0]) } func getUpdateRulesTimestampQuery() string { return fmt.Sprintf(`UPDATE %s 
SET updated_at=%s WHERE id IN (SELECT rule_id FROM %s WHERE action_id = (SELECT id from %s WHERE name = %s))`, sqlTableEventsRules, sqlPlaceholders[0], sqlTableRulesActionsMapping, sqlTableEventsActions, sqlPlaceholders[1]) } func getRelatedActionsForRulesQuery(rules []EventRule) string { var sb strings.Builder for _, r := range rules { if sb.Len() == 0 { sb.WriteString("(") } else { sb.WriteString(",") } sb.WriteString(strconv.FormatInt(r.ID, 10)) } if sb.Len() > 0 { sb.WriteString(")") } return fmt.Sprintf(`SELECT a.id,a.name,a.description,a.type,a.options,am.options,am.%s, am.rule_id FROM %s a INNER JOIN %s am ON a.id = am.action_id WHERE am.rule_id IN %s ORDER BY am.%s ASC`, getSQLQuotedName("order"), sqlTableEventsActions, sqlTableRulesActionsMapping, sb.String(), getSQLQuotedName("order")) } func getAddRuleActionMappingQuery() string { return fmt.Sprintf(`INSERT INTO %s (rule_id,action_id,%s,options) VALUES ((SELECT id FROM %s WHERE name = %s), (SELECT id FROM %s WHERE name = %s),%s,%s)`, sqlTableRulesActionsMapping, getSQLQuotedName("order"), sqlTableEventsRules, sqlPlaceholders[0], sqlTableEventsActions, sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getTaskByNameQuery() string { return fmt.Sprintf(`SELECT updated_at,version FROM %s WHERE name = %s`, sqlTableTasks, sqlPlaceholders[0]) } func getAddTaskQuery() string { return fmt.Sprintf(`INSERT INTO %s (name,updated_at,version) VALUES (%s,%s,0)`, sqlTableTasks, sqlPlaceholders[0], sqlPlaceholders[1]) } func getUpdateTaskQuery() string { return fmt.Sprintf(`UPDATE %s SET updated_at=%s,version = version + 1 WHERE name = %s AND version = %s`, sqlTableTasks, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) } func getUpdateTaskTimestampQuery() string { return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE name = %s`, sqlTableTasks, sqlPlaceholders[0], sqlPlaceholders[1]) } func getDeleteTaskQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableTasks, 
sqlPlaceholders[0]) } func getAddNodeQuery() string { if config.Driver == MySQLDataProviderName { return fmt.Sprintf("INSERT INTO %s (`name`,`data`,created_at,`updated_at`) VALUES (%s,%s,%s,%s) ON DUPLICATE KEY UPDATE "+ "`data`=VALUES(`data`), `created_at`=VALUES(`created_at`), `updated_at`=VALUES(`updated_at`)", sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } return fmt.Sprintf(`INSERT INTO %s (name,data,created_at,updated_at) VALUES (%s,%s,%s,%s) ON CONFLICT(name) DO UPDATE SET data=EXCLUDED.data, created_at=EXCLUDED.created_at, updated_at=EXCLUDED.updated_at`, sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getUpdateNodeTimestampQuery() string { return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE name = %s`, sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1]) } func getNodeByNameQuery() string { return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name = %s AND updated_at > %s`, sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1]) } func getNodesQuery() string { return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name != %s AND updated_at > %s`, sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1]) } func getCleanupNodesQuery() string { return fmt.Sprintf(`DELETE FROM %s WHERE updated_at < %s`, sqlTableNodes, sqlPlaceholders[0]) } func getDatabaseVersionQuery() string { return fmt.Sprintf("SELECT version from %s LIMIT 1", sqlTableSchemaVersion) } func getUpdateDBVersionQuery() string { return fmt.Sprintf(`UPDATE %s SET version=%s`, sqlTableSchemaVersion, sqlPlaceholders[0]) } ================================================ FILE: internal/dataprovider/unixcrypt.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free 
Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build unixcrypt && cgo package dataprovider import ( "strings" "github.com/amoghe/go-crypt" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("+unixcrypt") } func compareYescryptPassword(hashedPwd, plainPwd string) (bool, error) { lastIdx := strings.LastIndex(hashedPwd, "$") pwd, err := crypt.Crypt(plainPwd, hashedPwd[:lastIdx+1]) if err != nil { return false, err } return pwd == hashedPwd, nil } ================================================ FILE: internal/dataprovider/unixcrypt_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !unixcrypt || !cgo package dataprovider import ( "errors" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-unixcrypt") } func compareYescryptPassword(_, _ string) (bool, error) { return false, errors.New("yescrypt hash format is not supported or disabled") } ================================================ FILE: internal/dataprovider/user.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package dataprovider import ( "encoding/json" "errors" "fmt" "math" "net" "os" "path" "path/filepath" "slices" "strconv" "strings" "time" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // Available permissions for SFTPGo users const ( // All permissions are granted PermAny = "*" // List items such as files and directories is allowed PermListItems = "list" // download files is allowed PermDownload = "download" // upload files is allowed PermUpload = "upload" // overwrite an existing file, while uploading, is allowed // upload permission is required to allow file overwrite PermOverwrite = "overwrite" // delete files or directories is allowed PermDelete = "delete" // delete files is allowed PermDeleteFiles = "delete_files" // delete directories is allowed PermDeleteDirs = "delete_dirs" // rename files or directories is allowed PermRename = "rename" // rename files is allowed PermRenameFiles = "rename_files" // rename directories is allowed PermRenameDirs = "rename_dirs" // create directories is allowed PermCreateDirs = "create_dirs" // create symbolic links is allowed PermCreateSymlinks = "create_symlinks" // changing file or directory permissions is allowed PermChmod = "chmod" // changing file or directory owner and group is allowed PermChown = "chown" // changing file or directory access and modification time is allowed PermChtimes = "chtimes" // copying files or directories is allowed PermCopy = "copy" ) // Available login methods const ( LoginMethodNoAuthTried = "no_auth_tried" LoginMethodPassword = "password" SSHLoginMethodPassword = "password-over-SSH" SSHLoginMethodPublicKey = "publickey" SSHLoginMethodKeyboardInteractive = "keyboard-interactive" SSHLoginMethodKeyAndPassword = 
"publickey+password" SSHLoginMethodKeyAndKeyboardInt = "publickey+keyboard-interactive" LoginMethodTLSCertificate = "TLSCertificate" LoginMethodTLSCertificateAndPwd = "TLSCertificate+password" LoginMethodIDP = "IDP" ) var ( errNoMatchingVirtualFolder = errors.New("no matching virtual folder found") permsRenameAny = []string{PermRename, PermRenameDirs, PermRenameFiles} permsDeleteAny = []string{PermDelete, PermDeleteDirs, PermDeleteFiles} ) // RecoveryCode defines a 2FA recovery code type RecoveryCode struct { Secret *kms.Secret `json:"secret"` Used bool `json:"used,omitempty"` } // UserTOTPConfig defines the time-based one time password configuration type UserTOTPConfig struct { Enabled bool `json:"enabled,omitempty"` ConfigName string `json:"config_name,omitempty"` Secret *kms.Secret `json:"secret,omitempty"` // TOTP will be required for the specified protocols. // SSH protocol (SFTP/SCP/SSH commands) will ask for the TOTP passcode if the client uses keyboard interactive // authentication. // FTP have no standard way to support two factor authentication, if you // enable the support for this protocol you have to add the TOTP passcode after the password. // For example if your password is "password" and your one time passcode is // "123456" you have to use "password123456" as password. Protocols []string `json:"protocols,omitempty"` } // UserFilters defines additional restrictions for a user // TODO: rename to UserOptions in v3 type UserFilters struct { sdk.BaseUserFilters // User must change password from WebClient/REST API at next login. RequirePasswordChange bool `json:"require_password_change,omitempty"` // AdditionalEmails defines additional email addresses AdditionalEmails []string `json:"additional_emails,omitempty"` // Time-based one time passwords configuration TOTPConfig UserTOTPConfig `json:"totp_config,omitempty"` // Recovery codes to use if the user loses access to their second factor auth device. 
// Each code can only be used once, you should use these codes to login and disable or // reset 2FA for your account RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"` } // User defines a SFTPGo user type User struct { sdk.BaseUser // Additional restrictions Filters UserFilters `json:"filters"` // Mapping between virtual paths and virtual folders VirtualFolders []vfs.VirtualFolder `json:"virtual_folders,omitempty"` // Filesystem configuration details FsConfig vfs.Filesystem `json:"filesystem"` // groups associated with this user Groups []sdk.GroupMapping `json:"groups,omitempty"` // we store the filesystem here using the base path as key. fsCache map[string]vfs.Fs `json:"-"` // true if group settings are already applied for this user groupSettingsApplied bool `json:"-"` // in multi node setups we mark the user as deleted to be able to update the webdav cache DeletedAt int64 `json:"-"` } // GetFilesystem returns the base filesystem for this user func (u *User) GetFilesystem(connectionID string) (fs vfs.Fs, err error) { return u.GetFilesystemForPath("/", connectionID) } func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) { switch u.FsConfig.Provider { case sdk.S3FilesystemProvider: return vfs.NewS3Fs(connectionID, u.GetHomeDir(), "", u.FsConfig.S3Config) case sdk.GCSFilesystemProvider: return vfs.NewGCSFs(connectionID, u.GetHomeDir(), "", u.FsConfig.GCSConfig) case sdk.AzureBlobFilesystemProvider: return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), "", u.FsConfig.AzBlobConfig) case sdk.CryptedFilesystemProvider: return vfs.NewCryptFs(connectionID, u.GetHomeDir(), "", u.FsConfig.CryptConfig) case sdk.SFTPFilesystemProvider: forbiddenSelfUsers, err := u.getForbiddenSFTPSelfUsers(u.FsConfig.SFTPConfig.Username) if err != nil { return nil, err } forbiddenSelfUsers = append(forbiddenSelfUsers, u.Username) return vfs.NewSFTPFs(connectionID, "", u.GetHomeDir(), forbiddenSelfUsers, u.FsConfig.SFTPConfig) case sdk.HTTPFilesystemProvider: return 
vfs.NewHTTPFs(connectionID, u.GetHomeDir(), "", u.FsConfig.HTTPConfig) default: return vfs.NewOsFs(connectionID, u.GetHomeDir(), "", &u.FsConfig.OSConfig), nil } } func (u *User) checkDirWithParents(virtualDirPath, connectionID string) error { dirs := util.GetDirsForVirtualPath(virtualDirPath) for idx := len(dirs) - 1; idx >= 0; idx-- { vPath := dirs[idx] if vPath == "/" { continue } fs, err := u.GetFilesystemForPath(vPath, connectionID) if err != nil { return fmt.Errorf("unable to get fs for path %q: %w", vPath, err) } if fs.HasVirtualFolders() { continue } fsPath, err := fs.ResolvePath(vPath) if err != nil { return fmt.Errorf("unable to resolve path %q: %w", vPath, err) } _, err = fs.Stat(fsPath) if err == nil { continue } if fs.IsNotExist(err) { err = fs.Mkdir(fsPath) if err != nil { return err } vfs.SetPathPermissions(fs, fsPath, u.GetUID(), u.GetGID()) } else { return fmt.Errorf("unable to stat path %q: %w", vPath, err) } } return nil } func (u *User) checkLocalHomeDir(connectionID string) { switch u.FsConfig.Provider { case sdk.LocalFilesystemProvider, sdk.CryptedFilesystemProvider: return default: osFs := vfs.NewOsFs(connectionID, u.GetHomeDir(), "", nil) osFs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) } } func (u *User) checkRootPath(connectionID string) error { fs, err := u.GetFilesystemForPath("/", connectionID) if err != nil { logger.Warn(logSender, connectionID, "could not create main filesystem for user %q err: %v", u.Username, err) return fmt.Errorf("could not create root filesystem: %w", err) } fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) return nil } // CheckFsRoot check the root directory for the main fs and the virtual folders. 
// It returns an error if the main filesystem cannot be created func (u *User) CheckFsRoot(connectionID string) error { if u.Filters.DisableFsChecks { return nil } delay := lastLoginMinDelay if u.Filters.ExternalAuthCacheTime > 0 { cacheTime := time.Duration(u.Filters.ExternalAuthCacheTime) * time.Second if cacheTime > delay { delay = cacheTime } } if isLastActivityRecent(u.LastLogin, delay) { if u.LastLogin > u.UpdatedAt { if config.IsShared == 1 { u.checkLocalHomeDir(connectionID) } return nil } } err := u.checkRootPath(connectionID) if err != nil { return err } if u.Filters.StartDirectory != "" { err = u.checkDirWithParents(u.Filters.StartDirectory, connectionID) if err != nil { logger.Warn(logSender, connectionID, "could not create start directory %q, err: %v", u.Filters.StartDirectory, err) } } for idx := range u.VirtualFolders { v := &u.VirtualFolders[idx] fs, err := u.GetFilesystemForPath(v.VirtualPath, connectionID) if err == nil { fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) } // now check intermediary folders err = u.checkDirWithParents(path.Dir(v.VirtualPath), connectionID) if err != nil { logger.Warn(logSender, connectionID, "could not create intermediary dir to %q, err: %v", v.VirtualPath, err) } } return nil } // GetCleanedPath returns a clean POSIX absolute path using the user start directory as base // if the provided rawVirtualPath is relative func (u *User) GetCleanedPath(rawVirtualPath string) string { if u.Filters.StartDirectory != "" { if !path.IsAbs(rawVirtualPath) { var b strings.Builder b.Grow(len(u.Filters.StartDirectory) + 1 + len(rawVirtualPath)) b.WriteString(u.Filters.StartDirectory) b.WriteString("/") b.WriteString(rawVirtualPath) return util.CleanPath(b.String()) } } return util.CleanPath(rawVirtualPath) } // isFsEqual returns true if the filesystem configurations are the same func (u *User) isFsEqual(other *User) bool { if u.FsConfig.Provider == sdk.LocalFilesystemProvider && u.GetHomeDir() != other.GetHomeDir() { return 
false } if !u.FsConfig.IsEqual(other.FsConfig) { return false } if u.Filters.StartDirectory != other.Filters.StartDirectory { return false } if len(u.VirtualFolders) != len(other.VirtualFolders) { return false } for idx := range u.VirtualFolders { f := &u.VirtualFolders[idx] found := false for idx1 := range other.VirtualFolders { f1 := &other.VirtualFolders[idx1] if f.VirtualPath == f1.VirtualPath { found = true if f.FsConfig.Provider == sdk.LocalFilesystemProvider && f.MappedPath != f1.MappedPath { return false } if !f.FsConfig.IsEqual(f1.FsConfig) { return false } } } if !found { return false } } return true } func (u *User) isTimeBasedAccessAllowed(when time.Time) bool { if len(u.Filters.AccessTime) == 0 { return true } if when.IsZero() { when = time.Now() } if UseLocalTime() { when = when.Local() } else { when = when.UTC() } weekDay := when.Weekday() hhMM := when.Format("15:04") for _, p := range u.Filters.AccessTime { if p.DayOfWeek == int(weekDay) { if hhMM >= p.From && hhMM <= p.To { return true } } } return false } // CheckLoginConditions checks user access restrictions func (u *User) CheckLoginConditions() error { if u.Status < 1 { return fmt.Errorf("user %q is disabled", u.Username) } if u.ExpirationDate > 0 && u.ExpirationDate < util.GetTimeAsMsSinceEpoch(time.Now()) { return fmt.Errorf("user %q is expired, expiration timestamp: %v current timestamp: %v", u.Username, u.ExpirationDate, util.GetTimeAsMsSinceEpoch(time.Now())) } if u.isTimeBasedAccessAllowed(time.Now()) { return nil } return errors.New("access is not allowed at this time") } // hideConfidentialData hides user confidential data func (u *User) hideConfidentialData() { u.Password = "" u.FsConfig.HideConfidentialData() if u.Filters.TOTPConfig.Secret != nil { u.Filters.TOTPConfig.Secret.Hide() } for _, code := range u.Filters.RecoveryCodes { if code.Secret != nil { code.Secret.Hide() } } } // CheckMaxShareExpiration returns an error if the share expiration exceed the // maximum allowed date. 
func (u *User) CheckMaxShareExpiration(expiresAt time.Time) error { if u.Filters.MaxSharesExpiration == 0 { return nil } maxAllowedExpiration := time.Now().Add(24 * time.Hour * time.Duration(u.Filters.MaxSharesExpiration+1)) maxAllowedExpiration = time.Date(maxAllowedExpiration.Year(), maxAllowedExpiration.Month(), maxAllowedExpiration.Day(), 0, 0, 0, 0, maxAllowedExpiration.Location()) if util.GetTimeAsMsSinceEpoch(expiresAt) == 0 || expiresAt.After(maxAllowedExpiration) { return util.NewValidationError(fmt.Sprintf("the share must expire before %s", maxAllowedExpiration.Format(time.DateOnly))) } return nil } // GetEmailAddresses returns all the email addresses. func (u *User) GetEmailAddresses() []string { var res []string if u.Email != "" { res = append(res, u.Email) } return slices.Concat(res, u.Filters.AdditionalEmails) } // GetSubDirPermissions returns permissions for sub directories func (u *User) GetSubDirPermissions() []sdk.DirectoryPermissions { var result []sdk.DirectoryPermissions for k, v := range u.Permissions { if k == "/" { continue } dirPerms := sdk.DirectoryPermissions{ Path: k, Permissions: v, } result = append(result, dirPerms) } return result } func (u *User) setAnonymousSettings() { for k := range u.Permissions { u.Permissions[k] = []string{PermListItems, PermDownload} } u.Filters.DeniedProtocols = append(u.Filters.DeniedProtocols, protocolSSH, protocolHTTP) u.Filters.DeniedProtocols = util.RemoveDuplicates(u.Filters.DeniedProtocols, false) for _, method := range ValidLoginMethods { if method != LoginMethodPassword { u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, method) } } u.Filters.DeniedLoginMethods = util.RemoveDuplicates(u.Filters.DeniedLoginMethods, false) } // RenderAsJSON implements the renderer interface used within plugins func (u *User) RenderAsJSON(reload bool) ([]byte, error) { if reload { user, err := provider.userExists(u.Username, "") if err != nil { providerLog(logger.LevelError, "unable to reload user 
before rendering as json: %v", err) return nil, err } user.PrepareForRendering() return json.Marshal(user) } u.PrepareForRendering() return json.Marshal(u) } // PrepareForRendering prepares a user for rendering. // It hides confidential data and set to nil the empty secrets // so they are not serialized func (u *User) PrepareForRendering() { u.hideConfidentialData() u.FsConfig.SetNilSecretsIfEmpty() for idx := range u.VirtualFolders { folder := &u.VirtualFolders[idx] folder.PrepareForRendering() } } // HasRedactedSecret returns true if the user has a redacted secret func (u *User) hasRedactedSecret() bool { if u.FsConfig.HasRedactedSecret() { return true } for idx := range u.VirtualFolders { folder := &u.VirtualFolders[idx] if folder.HasRedactedSecret() { return true } } return u.Filters.TOTPConfig.Secret.IsRedacted() } // CloseFs closes the underlying filesystems func (u *User) CloseFs() error { if u.fsCache == nil { return nil } var err error for _, fs := range u.fsCache { errClose := fs.Close() if err == nil { err = errClose } } return err } // IsPasswordHashed returns true if the password is hashed func (u *User) IsPasswordHashed() bool { return util.IsStringPrefixInSlice(u.Password, hashPwdPrefixes) } // IsTLSVerificationEnabled returns true if we need to check the TLS authentication func (u *User) IsTLSVerificationEnabled() bool { if len(u.Filters.TLSCerts) > 0 { return true } if u.Filters.TLSUsername != "" { return u.Filters.TLSUsername != sdk.TLSUsernameNone } return false } // SetEmptySecrets sets to empty any user secret func (u *User) SetEmptySecrets() { u.FsConfig.SetEmptySecrets() for idx := range u.VirtualFolders { folder := &u.VirtualFolders[idx] folder.FsConfig.SetEmptySecrets() } u.Filters.TOTPConfig.Secret = kms.NewEmptySecret() } // GetPermissionsForPath returns the permissions for the given path. 
// The path must be a SFTPGo virtual path
// The most specific directory wins. The directories of the path are checked
// from the deepest one up to "/". For each of them an exact key match is tried
// first, then every permission key is tried as a path.Match pattern. If nothing
// matches, the root ("/") permissions are returned if defined, otherwise an
// empty slice.
func (u *User) GetPermissionsForPath(p string) []string {
	permissions := []string{}
	if perms, ok := u.Permissions["/"]; ok {
		// if only root permissions are defined returns them unconditionally
		if len(u.Permissions) == 1 {
			return perms
		}
		// fallback permissions
		permissions = perms
	}

	dirsForPath := util.GetDirsForVirtualPath(p)
	// dirsForPath contains all the dirs for a given path in reverse order
	// for example if the path is: /1/2/3/4 it contains:
	// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ]
	// so the first match is the one we are interested to
	for idx := range dirsForPath {
		if perms, ok := u.Permissions[dirsForPath[idx]]; ok {
			return perms
		}
		// NOTE(review): map iteration order is random. If more than one pattern
		// matches the same directory, the returned permissions may differ
		// between calls.
		for dir, perms := range u.Permissions {
			if match, err := path.Match(dir, dirsForPath[idx]); err == nil && match {
				return perms
			}
		}
	}
	return permissions
}

// getForbiddenSFTPSelfUsers returns the extra usernames that an SFTP filesystem
// connecting as "username" must not use. If that user itself has an SFTP root
// filesystem or an SFTP virtual folder (after applying its group settings), its
// username is returned, to prevent nested SFTP self connections.
// It returns nil if allowSelfConnections is 0 or if the user does not exist.
// Any other lookup error is returned.
func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) {
	if allowSelfConnections == 0 {
		return nil, nil
	}
	sftpUser, err := UserExists(username, "")
	if err == nil {
		err = sftpUser.LoadAndApplyGroupSettings()
	}
	if err == nil {
		// we don't allow local nested SFTP folders
		var forbiddens []string
		if sftpUser.FsConfig.Provider == sdk.SFTPFilesystemProvider {
			forbiddens = append(forbiddens, sftpUser.Username)
			return forbiddens, nil
		}
		for idx := range sftpUser.VirtualFolders {
			v := &sftpUser.VirtualFolders[idx]
			if v.FsConfig.Provider == sdk.SFTPFilesystemProvider {
				forbiddens = append(forbiddens, sftpUser.Username)
				return forbiddens, nil
			}
		}
		return forbiddens, nil
	}
	if !errors.Is(err, util.ErrNotFound) {
		return nil, err
	}

	return nil, nil
}

// GetFsConfigForPath returns the file system configuration for the specified virtual path.
// It returns the configuration of the virtual folder containing the path, if any,
// otherwise the user root filesystem configuration. "" and "/" always map to the root one.
func (u *User) GetFsConfigForPath(virtualPath string) vfs.Filesystem {
	if virtualPath != "" && virtualPath != "/" && len(u.VirtualFolders) > 0 {
		folder, err := u.GetVirtualFolderForPath(virtualPath)
		if err == nil {
			return folder.FsConfig
		}
	}

	return u.FsConfig
}

// GetFilesystemForPath returns the filesystem for the given path
// Filesystems are created lazily and cached in u.fsCache. The key is the
// virtual folder path, or "/" for the root filesystem. The user's own username
// is always forbidden for SFTP self connections; for SFTP virtual folders the
// usernames from getForbiddenSFTPSelfUsers are added.
// NOTE(review): u.fsCache is read and written without any lock here.
func (u *User) GetFilesystemForPath(virtualPath, connectionID string) (vfs.Fs, error) {
	if u.fsCache == nil {
		u.fsCache = make(map[string]vfs.Fs)
	}
	// allow to override the `/` path with a virtual folder
	if len(u.VirtualFolders) > 0 {
		folder, err := u.GetVirtualFolderForPath(virtualPath)
		if err == nil {
			if fs, ok := u.fsCache[folder.VirtualPath]; ok {
				return fs, nil
			}
			forbiddenSelfUsers := []string{u.Username}
			if folder.FsConfig.Provider == sdk.SFTPFilesystemProvider {
				forbiddens, err := u.getForbiddenSFTPSelfUsers(folder.FsConfig.SFTPConfig.Username)
				if err != nil {
					return nil, err
				}
				forbiddenSelfUsers = append(forbiddenSelfUsers, forbiddens...)
			}
			fs, err := folder.GetFilesystem(connectionID, forbiddenSelfUsers)
			if err == nil {
				u.fsCache[folder.VirtualPath] = fs
			}
			return fs, err
		}
	}

	if val, ok := u.fsCache["/"]; ok {
		return val, nil
	}
	fs, err := u.getRootFs(connectionID)
	if err != nil {
		return fs, err
	}
	u.fsCache["/"] = fs
	return fs, err
}

// GetVirtualFolderForPath returns the virtual folder containing the specified virtual path.
// If the path is not inside a virtual folder an error is returned func (u *User) GetVirtualFolderForPath(virtualPath string) (vfs.VirtualFolder, error) { var folder vfs.VirtualFolder if len(u.VirtualFolders) == 0 { return folder, errNoMatchingVirtualFolder } dirsForPath := util.GetDirsForVirtualPath(virtualPath) for index := range dirsForPath { for idx := range u.VirtualFolders { v := &u.VirtualFolders[idx] if v.VirtualPath == dirsForPath[index] { return *v, nil } } } return folder, errNoMatchingVirtualFolder } // ScanQuota scans the user home dir and virtual folders, included in its quota, // and returns the number of files and their size func (u *User) ScanQuota() (int, int64, error) { fs, err := u.getRootFs(xid.New().String()) if err != nil { return 0, 0, err } defer fs.Close() numFiles, size, err := fs.ScanRootDirContents() if err != nil { return numFiles, size, err } for idx := range u.VirtualFolders { v := &u.VirtualFolders[idx] if !v.IsIncludedInUserQuota() { continue } num, s, err := v.ScanQuota() if err != nil { return numFiles, size, err } numFiles += num size += s } return numFiles, size, nil } // GetVirtualFoldersInPath returns the virtual folders inside virtualPath including // any parents func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool { result := make(map[string]bool) for idx := range u.VirtualFolders { dirsForPath := util.GetDirsForVirtualPath(u.VirtualFolders[idx].VirtualPath) for index := range dirsForPath { d := dirsForPath[index] if d == "/" { continue } if path.Dir(d) == virtualPath { result[d] = true } } } if u.Filters.StartDirectory != "" { dirsForPath := util.GetDirsForVirtualPath(u.Filters.StartDirectory) for index := range dirsForPath { d := dirsForPath[index] if d == "/" { continue } if path.Dir(d) == virtualPath { result[d] = true } } } return result } func (u *User) hasVirtualDirs() bool { if u.Filters.StartDirectory != "" { return true } numFolders := len(u.VirtualFolders) if numFolders == 1 { return 
u.VirtualFolders[0].VirtualPath != "/" } return numFolders > 0 } // GetVirtualFoldersInfo returns []os.FileInfo for virtual folders func (u *User) GetVirtualFoldersInfo(virtualPath string) []os.FileInfo { filter := u.getPatternsFilterForPath(virtualPath) if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { return nil } vdirs := u.GetVirtualFoldersInPath(virtualPath) result := make([]os.FileInfo, 0, len(vdirs)) for dir := range u.GetVirtualFoldersInPath(virtualPath) { dirName := path.Base(dir) if filter.DenyPolicy == sdk.DenyPolicyHide { if !filter.CheckAllowed(dirName) { continue } } result = append(result, vfs.NewFileInfo(dirName, true, 0, time.Unix(0, 0), false)) } return result } // FilterListDir removes hidden items from the given files list func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { filter := u.getPatternsFilterForPath(virtualPath) if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { return dirContents } vdirs := make(map[string]bool) for dir := range u.GetVirtualFoldersInPath(virtualPath) { dirName := path.Base(dir) if filter.DenyPolicy == sdk.DenyPolicyHide { if !filter.CheckAllowed(dirName) { continue } } vdirs[dirName] = true } validIdx := 0 for idx := range dirContents { fi := dirContents[idx] if fi.Name() != "." && fi.Name() != ".." { if _, ok := vdirs[fi.Name()]; ok { continue } if filter.DenyPolicy == sdk.DenyPolicyHide { if !filter.CheckAllowed(fi.Name()) { continue } } } dirContents[validIdx] = fi validIdx++ } return dirContents[:validIdx] } // IsMappedPath returns true if the specified filesystem path has a virtual folder mapping. 
// The filesystem path must be cleaned before calling this method func (u *User) IsMappedPath(fsPath string) bool { for idx := range u.VirtualFolders { v := &u.VirtualFolders[idx] if fsPath == v.MappedPath { return true } } return false } // IsVirtualFolder returns true if the specified virtual path is a virtual folder func (u *User) IsVirtualFolder(virtualPath string) bool { for idx := range u.VirtualFolders { v := &u.VirtualFolders[idx] if virtualPath == v.VirtualPath { return true } } return false } // HasVirtualFoldersInside returns true if there are virtual folders inside the // specified virtual path. We assume that path are cleaned func (u *User) HasVirtualFoldersInside(virtualPath string) bool { if virtualPath == "/" && len(u.VirtualFolders) > 0 { return true } for idx := range u.VirtualFolders { v := &u.VirtualFolders[idx] if len(v.VirtualPath) > len(virtualPath) { if strings.HasPrefix(v.VirtualPath, virtualPath+"/") { return true } } } return false } // HasPermissionsInside returns true if the specified virtualPath has no permissions itself and // no subdirs with defined permissions func (u *User) HasPermissionsInside(virtualPath string) bool { for dir, perms := range u.Permissions { if len(perms) == 1 && perms[0] == PermAny { continue } if dir == virtualPath { return true } else if len(dir) > len(virtualPath) { if strings.HasPrefix(dir, virtualPath+"/") { return true } } } return false } // HasPerm returns true if the user has the given permission or any permission func (u *User) HasPerm(permission, path string) bool { perms := u.GetPermissionsForPath(path) if slices.Contains(perms, PermAny) { return true } return slices.Contains(perms, permission) } // HasAnyPerm returns true if the user has at least one of the given permissions func (u *User) HasAnyPerm(permissions []string, path string) bool { perms := u.GetPermissionsForPath(path) if slices.Contains(perms, PermAny) { return true } for _, permission := range permissions { if slices.Contains(perms, 
permission) { return true } } return false } // HasPerms returns true if the user has all the given permissions func (u *User) HasPerms(permissions []string, path string) bool { perms := u.GetPermissionsForPath(path) if slices.Contains(perms, PermAny) { return true } for _, permission := range permissions { if !slices.Contains(perms, permission) { return false } } return true } // HasPermsDeleteAll returns true if the user can delete both files and directories // for the given path func (u *User) HasPermsDeleteAll(path string) bool { perms := u.GetPermissionsForPath(path) canDeleteFiles := false canDeleteDirs := false for _, permission := range perms { if permission == PermAny || permission == PermDelete { return true } if permission == PermDeleteFiles { canDeleteFiles = true } if permission == PermDeleteDirs { canDeleteDirs = true } } return canDeleteFiles && canDeleteDirs } // HasPermsRenameAll returns true if the user can rename both files and directories // for the given path func (u *User) HasPermsRenameAll(path string) bool { perms := u.GetPermissionsForPath(path) canRenameFiles := false canRenameDirs := false for _, permission := range perms { if permission == PermAny || permission == PermRename { return true } if permission == PermRenameFiles { canRenameFiles = true } if permission == PermRenameDirs { canRenameDirs = true } } return canRenameFiles && canRenameDirs } // HasNoQuotaRestrictions returns true if no quota restrictions need to be applyed func (u *User) HasNoQuotaRestrictions(checkFiles bool) bool { if u.QuotaSize == 0 && (!checkFiles || u.QuotaFiles == 0) { return true } return false } // IsLoginMethodAllowed returns true if the specified login method is allowed func (u *User) IsLoginMethodAllowed(loginMethod, protocol string) bool { if len(u.Filters.DeniedLoginMethods) == 0 { return true } if slices.Contains(u.Filters.DeniedLoginMethods, loginMethod) { return false } if protocol == protocolSSH && loginMethod == LoginMethodPassword { if 
slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) { return false } } return true } // GetNextAuthMethods returns the list of authentications methods that can // continue for multi-step authentication. We call this method after a // successful public key authentication. func (u *User) GetNextAuthMethods() []string { var methods []string for _, method := range u.GetAllowedLoginMethods() { if method == SSHLoginMethodKeyAndPassword { methods = append(methods, LoginMethodPassword) } if method == SSHLoginMethodKeyAndKeyboardInt { methods = append(methods, SSHLoginMethodKeyboardInteractive) } } return methods } // IsPartialAuth returns true if the specified login method is a step for // a multi-step Authentication. // We support publickey+password and publickey+keyboard-interactive, so // only publickey can returns partial success. // We can have partial success if only multi-step Auth methods are enabled func (u *User) IsPartialAuth() bool { for _, method := range u.GetAllowedLoginMethods() { if method == LoginMethodTLSCertificate || method == LoginMethodTLSCertificateAndPwd || method == SSHLoginMethodPassword { continue } if method == LoginMethodPassword && slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) { continue } if !slices.Contains(SSHMultiStepsLoginMethods, method) { return false } } return true } // GetAllowedLoginMethods returns the allowed login methods func (u *User) GetAllowedLoginMethods() []string { var allowedMethods []string for _, method := range ValidLoginMethods { if method == SSHLoginMethodPassword { continue } if !slices.Contains(u.Filters.DeniedLoginMethods, method) { allowedMethods = append(allowedMethods, method) } } return allowedMethods } func (u *User) getPatternsFilterForPath(virtualPath string) sdk.PatternsFilter { var filter sdk.PatternsFilter if len(u.Filters.FilePatterns) == 0 { return filter } dirsForPath := util.GetDirsForVirtualPath(virtualPath) for idx, dir := range dirsForPath { for _, f := 
range u.Filters.FilePatterns { if f.Path == dir { if idx > 0 && len(f.AllowedPatterns) > 0 && len(f.DeniedPatterns) > 0 && f.DeniedPatterns[0] == "*" { if f.CheckAllowed(path.Base(dirsForPath[idx-1])) { return filter } } filter = f break } } if filter.Path != "" { break } } return filter } func (u *User) isDirHidden(virtualPath string) bool { if len(u.Filters.FilePatterns) == 0 { return false } for _, dirPath := range util.GetDirsForVirtualPath(virtualPath) { if dirPath == "/" { return false } filter := u.getPatternsFilterForPath(dirPath) if filter.DenyPolicy == sdk.DenyPolicyHide && filter.Path != dirPath { if !filter.CheckAllowed(path.Base(dirPath)) { return true } } } return false } func (u *User) getMinPasswordEntropy() float64 { if u.Filters.PasswordStrength > 0 { return float64(u.Filters.PasswordStrength) } return config.PasswordValidation.Users.MinEntropy } // IsFileAllowed returns true if the specified file is allowed by the file restrictions filters. // The second parameter returned is the deny policy func (u *User) IsFileAllowed(virtualPath string) (bool, int) { dirPath := path.Dir(virtualPath) if u.isDirHidden(dirPath) { return false, sdk.DenyPolicyHide } filter := u.getPatternsFilterForPath(dirPath) return filter.CheckAllowed(path.Base(virtualPath)), filter.DenyPolicy } // CanManageMFA returns true if the user can add a multi-factor authentication configuration func (u *User) CanManageMFA() bool { if slices.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) { return false } return len(mfa.GetAvailableTOTPConfigs()) > 0 } func (u *User) skipExternalAuth() bool { if u.Filters.Hooks.ExternalAuthDisabled { return true } if u.ID <= 0 { return false } if u.Filters.ExternalAuthCacheTime <= 0 { return false } return isLastActivityRecent(u.LastLogin, time.Duration(u.Filters.ExternalAuthCacheTime)*time.Second) } // CanManageShares returns true if the user can add, update and list shares func (u *User) CanManageShares() bool { return 
!slices.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled) } // CanResetPassword returns true if this user is allowed to reset its password func (u *User) CanResetPassword() bool { return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled) } // CanChangePassword returns true if this user is allowed to change its password func (u *User) CanChangePassword() bool { return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) } // CanChangeAPIKeyAuth returns true if this user is allowed to enable/disable API key authentication func (u *User) CanChangeAPIKeyAuth() bool { return !slices.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled) } // CanChangeInfo returns true if this user is allowed to change its info such as email and description func (u *User) CanChangeInfo() bool { return !slices.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled) } // CanManagePublicKeys returns true if this user is allowed to manage public keys // from the WebClient. Used in WebClient UI func (u *User) CanManagePublicKeys() bool { return !slices.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled) } // CanManageTLSCerts returns true if this user is allowed to manage TLS certificates // from the WebClient. Used in WebClient UI func (u *User) CanManageTLSCerts() bool { return !slices.Contains(u.Filters.WebClient, sdk.WebClientTLSCertChangeDisabled) } // CanUpdateProfile returns true if the user is allowed to update the profile. // Used in WebClient UI func (u *User) CanUpdateProfile() bool { return u.CanManagePublicKeys() || u.CanChangeAPIKeyAuth() || u.CanChangeInfo() || u.CanManageTLSCerts() } // CanAddFilesFromWeb returns true if the client can add files from the web UI. 
// The specified target is the directory where the files must be uploaded func (u *User) CanAddFilesFromWeb(target string) bool { if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasPerm(PermUpload, target) || u.HasPerm(PermOverwrite, target) } // CanAddDirsFromWeb returns true if the client can add directories from the web UI. // The specified target is the directory where the new directory must be created func (u *User) CanAddDirsFromWeb(target string) bool { if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasPerm(PermCreateDirs, target) } // CanRenameFromWeb returns true if the client can rename objects from the web UI. // The specified src and dest are the source and target directories for the rename. func (u *User) CanRenameFromWeb(src, dest string) bool { if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasAnyPerm(permsRenameAny, src) && u.HasAnyPerm(permsRenameAny, dest) } // CanDeleteFromWeb returns true if the client can delete objects from the web UI. // The specified target is the parent directory for the object to delete func (u *User) CanDeleteFromWeb(target string) bool { if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasAnyPerm(permsDeleteAny, target) } // CanCopyFromWeb returns true if the client can copy objects from the web UI. // The specified src and dest are the source and target directories for the copy. 
func (u *User) CanCopyFromWeb(src, dest string) bool {
	if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) {
		return false
	}
	// copying requires reading the source: list and download permissions on src
	if !u.HasPerm(PermListItems, src) {
		return false
	}
	if !u.HasPerm(PermDownload, src) {
		return false
	}
	return u.HasPerm(PermCopy, src) && u.HasPerm(PermCopy, dest)
}

// InactivityDays returns the number of days of inactivity.
// The last login is used, or the creation time if the user never logged in.
// "when" defaults to time.Now() if zero. It returns 0 if neither timestamp is set.
// The result is truncated to whole days.
func (u *User) InactivityDays(when time.Time) int {
	if when.IsZero() {
		when = time.Now()
	}
	lastActivity := u.LastLogin
	if lastActivity == 0 {
		lastActivity = u.CreatedAt
	}
	if lastActivity == 0 {
		// unable to determine inactivity
		return 0
	}
	return int(float64(when.Sub(util.GetTimeFromMsecSinceEpoch(lastActivity))) / float64(24*time.Hour))
}

// PasswordExpiresIn returns the number of days before the password expires.
// The returned value is negative if the password is expired.
// The caller must ensure that a PasswordExpiration is set
func (u *User) PasswordExpiresIn() int {
	lastPwdChange := util.GetTimeFromMsecSinceEpoch(u.LastPasswordChange)
	pwdExpiration := lastPwdChange.Add(time.Duration(u.Filters.PasswordExpiration) * 24 * time.Hour)
	// rounded to the nearest day
	res := int(math.Round(float64(time.Until(pwdExpiration)) / float64(24*time.Hour)))
	// a password expiring later today rounds to 0: report 1 day instead
	if res == 0 && pwdExpiration.After(time.Now()) {
		res = 1
	}
	return res
}

// MustChangePassword returns true if the user must change the password:
// a password change is explicitly required or the password is expired
func (u *User) MustChangePassword() bool {
	if u.Filters.RequirePasswordChange {
		return true
	}
	if u.Filters.PasswordExpiration == 0 {
		return false
	}
	lastPwdChange := util.GetTimeFromMsecSinceEpoch(u.LastPasswordChange)
	return lastPwdChange.Add(time.Duration(u.Filters.PasswordExpiration) * 24 * time.Hour).Before(time.Now())
}

// MustSetSecondFactor returns true if the user must set a second factor authentication:
// two-factor protocols are required but TOTP is not enabled, or it is not enabled
// for all of them
func (u *User) MustSetSecondFactor() bool {
	if len(u.Filters.TwoFactorAuthProtocols) > 0 {
		if !u.Filters.TOTPConfig.Enabled {
			return true
		}
		for _, p := range u.Filters.TwoFactorAuthProtocols {
			if !slices.Contains(u.Filters.TOTPConfig.Protocols, p) {
				return true
			}
		}
	}
	return false
}

// MustSetSecondFactorForProtocol returns true if the user must set a second factor authentication
// for the specified protocol
func (u *User) MustSetSecondFactorForProtocol(protocol string) bool {
	if slices.Contains(u.Filters.TwoFactorAuthProtocols, protocol) {
		if !u.Filters.TOTPConfig.Enabled {
			return true
		}
		if !slices.Contains(u.Filters.TOTPConfig.Protocols, protocol) {
			return true
		}
	}
	return false
}

// GetSignature returns a signature for this user.
// It will change after an update (it is the UpdatedAt timestamp as a string)
func (u *User) GetSignature() string {
	return strconv.FormatInt(u.UpdatedAt, 10)
}

// GetBandwidthForIP returns the upload and download bandwidth for the specified IP.
// The first bandwidth limit whose sources (CIDRs) contain the IP wins; invalid
// CIDRs are skipped. Otherwise, or if the IP cannot be parsed, the user default
// bandwidth is returned.
func (u *User) GetBandwidthForIP(clientIP, connectionID string) (int64, int64) {
	if len(u.Filters.BandwidthLimits) > 0 {
		ip := net.ParseIP(clientIP)
		if ip != nil {
			for _, bwLimit := range u.Filters.BandwidthLimits {
				for _, source := range bwLimit.Sources {
					_, ipNet, err := net.ParseCIDR(source)
					if err == nil {
						if ipNet.Contains(ip) {
							logger.Debug(logSender, connectionID, "override bandwidth limit for ip %q, upload limit: %v KB/s, download limit: %v KB/s",
								clientIP, bwLimit.UploadBandwidth, bwLimit.DownloadBandwidth)
							return bwLimit.UploadBandwidth, bwLimit.DownloadBandwidth
						}
					}
				}
			}
		}
	}
	return u.UploadBandwidth, u.DownloadBandwidth
}

// IsLoginFromAddrAllowed returns true if the login is allowed from the specified remoteAddr.
// If AllowedIP is defined only the specified IP/Mask can login.
// If DeniedIP is defined the specified IP/Mask cannot login.
// If an IP is both allowed and denied then login will be allowed
// An invalid CIDR in either list denies the login if it is reached before a match.
func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
	if len(u.Filters.AllowedIP) == 0 && len(u.Filters.DeniedIP) == 0 {
		return true
	}
	remoteIP := net.ParseIP(util.GetIPFromRemoteAddress(remoteAddr))
	// if remoteIP is invalid we allow login, this should never happen
	if remoteIP == nil {
		logger.Warn(logSender, "", "login allowed for invalid IP. remote address: %q", remoteAddr)
		return true
	}
	for _, IPMask := range u.Filters.AllowedIP {
		_, IPNet, err := net.ParseCIDR(IPMask)
		if err != nil {
			return false
		}
		if IPNet.Contains(remoteIP) {
			return true
		}
	}
	for _, IPMask := range u.Filters.DeniedIP {
		_, IPNet, err := net.ParseCIDR(IPMask)
		if err != nil {
			return false
		}
		if IPNet.Contains(remoteIP) {
			return false
		}
	}
	// not explicitly allowed: allowed only if there is no allow list
	return len(u.Filters.AllowedIP) == 0
}

// GetPermissionsAsJSON returns the permissions as json byte array
func (u *User) GetPermissionsAsJSON() ([]byte, error) {
	return json.Marshal(u.Permissions)
}

// GetPublicKeysAsJSON returns the public keys as json byte array
func (u *User) GetPublicKeysAsJSON() ([]byte, error) {
	return json.Marshal(u.PublicKeys)
}

// GetFiltersAsJSON returns the filters as json byte array
func (u *User) GetFiltersAsJSON() ([]byte, error) {
	return json.Marshal(u.Filters)
}

// GetFsConfigAsJSON returns the filesystem config as json byte array
func (u *User) GetFsConfigAsJSON() ([]byte, error) {
	return json.Marshal(u.FsConfig)
}

// GetUID returns a valid uid, suitable for use with os.Chown.
// It returns -1 if the uid is not set (<= 0) or greater than math.MaxInt32
func (u *User) GetUID() int {
	if u.UID <= 0 || u.UID > math.MaxInt32 {
		return -1
	}
	return u.UID
}

// GetGID returns a valid gid, suitable for use with os.Chown.
// It returns -1 if the gid is not set (<= 0) or greater than math.MaxInt32
func (u *User) GetGID() int {
	if u.GID <= 0 || u.GID > math.MaxInt32 {
		return -1
	}
	return u.GID
}

// GetHomeDir returns the user's home directory as stored in HomeDir.
// No further cleaning is done here.
func (u *User) GetHomeDir() string {
	return u.HomeDir
}

// HasRecentActivity returns true if the last user login is recent and so we can skip some expensive checks
func (u *User) HasRecentActivity() bool {
	return isLastActivityRecent(u.LastLogin, lastLoginMinDelay)
}

// HasQuotaRestrictions returns true if there are any disk quota restrictions
func (u *User) HasQuotaRestrictions() bool {
	return u.QuotaFiles > 0 || u.QuotaSize > 0
}

// HasTransferQuotaRestrictions returns true if there are any data transfer restrictions
func (u *User) HasTransferQuotaRestrictions() bool {
	return u.UploadDataTransfer > 0 || u.TotalDataTransfer > 0 || u.DownloadDataTransfer > 0
}

// GetDataTransferLimits returns upload, download and total data transfer limits.
// Each configured value is multiplied by 1048576 (presumably converting MB to
// bytes); unset (<= 0) values are returned as 0
func (u *User) GetDataTransferLimits() (int64, int64, int64) {
	var total, ul, dl int64
	if u.TotalDataTransfer > 0 {
		total = u.TotalDataTransfer * 1048576
	}
	if u.DownloadDataTransfer > 0 {
		dl = u.DownloadDataTransfer * 1048576
	}
	if u.UploadDataTransfer > 0 {
		ul = u.UploadDataTransfer * 1048576
	}
	return ul, dl, total
}

// GetAllowedIPAsString returns the allowed IP as comma separated string
func (u *User) GetAllowedIPAsString() string {
	return strings.Join(u.Filters.AllowedIP, ",")
}

// GetDeniedIPAsString returns the denied IP as comma separated string
func (u *User) GetDeniedIPAsString() string {
	return strings.Join(u.Filters.DeniedIP, ",")
}

// HasExternalAuth returns true if the external authentication is globally enabled
// and it is not disabled for this user
func (u *User) HasExternalAuth() bool {
	if u.Filters.Hooks.ExternalAuthDisabled {
		return false
	}
	if config.ExternalAuthHook != "" {
		return true
	}
	return plugin.Handler.HasAuthenticators()
}

// CountUnusedRecoveryCodes returns the number of unused recovery codes
func (u *User) CountUnusedRecoveryCodes() int {
	unused := 0
	for _, code := range u.Filters.RecoveryCodes {
		if !code.Used {
			unused++
		}
	}
	return unused
}

// SetEmptySecretsIfNil sets the secrets to empty if nil.
// It also sets HasPassword according to whether Password is non-empty
func (u *User) SetEmptySecretsIfNil() {
	u.HasPassword = u.Password != ""
	u.FsConfig.SetEmptySecretsIfNil()
	for idx := range u.VirtualFolders {
		vfolder := &u.VirtualFolders[idx]
		vfolder.FsConfig.SetEmptySecretsIfNil()
	}
	if u.Filters.TOTPConfig.Secret == nil {
		u.Filters.TOTPConfig.Secret = kms.NewEmptySecret()
	}
}

// hasMainDataTransferLimits returns true if any user level data transfer limit is set
func (u *User) hasMainDataTransferLimits() bool {
	return u.UploadDataTransfer > 0 || u.DownloadDataTransfer > 0 || u.TotalDataTransfer > 0
}

// HasPrimaryGroup returns true if the user has the specified primary group
func (u *User) HasPrimaryGroup(name string) bool {
	for _, g := range u.Groups {
		if g.Name == name {
			return g.Type == sdk.GroupTypePrimary
		}
	}
	return false
}

// HasSecondaryGroup returns true if the user has the specified secondary group
func (u *User) HasSecondaryGroup(name string) bool {
	for _, g := range u.Groups {
		if g.Name == name {
			return g.Type == sdk.GroupTypeSecondary
		}
	}
	return false
}

// HasMembershipGroup returns true if the user has the specified membership group
func (u *User) HasMembershipGroup(name string) bool {
	for _, g := range u.Groups {
		if g.Name == name {
			return g.Type == sdk.GroupTypeMembership
		}
	}
	return false
}

// hasSettingsFromGroups returns true if the user belongs to at least one group
// whose type is not membership, i.e. a group whose settings must be merged
func (u *User) hasSettingsFromGroups() bool {
	for _, g := range u.Groups {
		if g.Type != sdk.GroupTypeMembership {
			return true
		}
	}
	return false
}

// applyGroupSettings merges the settings of the user's groups, taken from
// groupsMapping (keyed by group name), into the user. The primary group is
// merged first, then the secondary ones. Missing mappings are logged and skipped.
// NOTE(review): groupSettingsApplied is not set in this function; presumably
// removeDuplicatesAfterGroupMerge sets it — confirm.
func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
	if !u.hasSettingsFromGroups() {
		return
	}
	if u.groupSettingsApplied {
		return
	}
	replacer := u.getGroupPlacehodersReplacer()
	for _, g := range u.Groups {
		if g.Type == sdk.GroupTypePrimary {
			if group, ok := groupsMapping[g.Name]; ok {
				u.mergeWithPrimaryGroup(&group, replacer)
			} else {
				providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name)
			}
			break
		}
	}
	for _, g := range u.Groups {
		if g.Type == sdk.GroupTypeSecondary {
			if group, ok := groupsMapping[g.Name]; ok {
				u.mergeAdditiveProperties(&group, sdk.GroupTypeSecondary, replacer)
			} else {
				providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name)
			}
		}
	}
	u.removeDuplicatesAfterGroupMerge()
}

// LoadAndApplyGroupSettings update the user by loading and applying the group settings.
// Only primary and secondary groups are loaded from the data provider; membership
// groups carry no settings.
func (u *User) LoadAndApplyGroupSettings() error {
	if !u.hasSettingsFromGroups() {
		return nil
	}
	if u.groupSettingsApplied {
		return nil
	}
	names := make([]string, 0, len(u.Groups))
	var primaryGroupName string
	for _, g := range u.Groups {
		if g.Type == sdk.GroupTypePrimary {
			primaryGroupName = g.Name
		}
		if g.Type != sdk.GroupTypeMembership {
			names = append(names, g.Name)
		}
	}
	groups, err := provider.getGroupsWithNames(names)
	if err != nil {
		return fmt.Errorf("unable to get groups: %w", err)
	}
	replacer := u.getGroupPlacehodersReplacer()
	// make sure to always merge with the primary group first
	for idx := range groups {
		g := groups[idx]
		if g.Name == primaryGroupName {
			u.mergeWithPrimaryGroup(&g, replacer)
			// remove the primary group by swapping in the last element, so the
			// loop below only processes the secondary groups
			lastIdx := len(groups) - 1
			groups[idx] = groups[lastIdx]
			groups = groups[:lastIdx]
			break
		}
	}
	for idx := range groups {
		g := groups[idx]
		u.mergeAdditiveProperties(&g, sdk.GroupTypeSecondary, replacer)
	}
	u.removeDuplicatesAfterGroupMerge()
	return nil
}

// getGroupPlacehodersReplacer returns a replacer for the %username% and %role%
// placeholders used in group settings
func (u *User) getGroupPlacehodersReplacer() *strings.Replacer {
	return strings.NewReplacer("%username%", u.Username, "%role%", u.Role)
}

// replacePlaceholder applies replacer to value; empty values are returned as is
func (u *User) replacePlaceholder(value string, replacer *strings.Replacer) string {
	if value == "" {
		return value
	}
	return replacer.Replace(value)
}

// replaceFsConfigPlaceholders returns a copy of fsConfig with placeholders
// replaced in the provider specific fields that support them: the key prefix for
// S3, GCS and Azure Blob, the username and prefix for SFTP and the username for HTTP
func (u *User) replaceFsConfigPlaceholders(fsConfig vfs.Filesystem, replacer *strings.Replacer) vfs.Filesystem {
	switch fsConfig.Provider {
	case sdk.S3FilesystemProvider:
		fsConfig.S3Config.KeyPrefix = u.replacePlaceholder(fsConfig.S3Config.KeyPrefix, replacer)
	case sdk.GCSFilesystemProvider:
		fsConfig.GCSConfig.KeyPrefix = u.replacePlaceholder(fsConfig.GCSConfig.KeyPrefix, replacer)
	case sdk.AzureBlobFilesystemProvider:
		fsConfig.AzBlobConfig.KeyPrefix = u.replacePlaceholder(fsConfig.AzBlobConfig.KeyPrefix, replacer)
	case sdk.SFTPFilesystemProvider:
		fsConfig.SFTPConfig.Username = u.replacePlaceholder(fsConfig.SFTPConfig.Username, replacer)
		fsConfig.SFTPConfig.Prefix = u.replacePlaceholder(fsConfig.SFTPConfig.Prefix, replacer)
	case sdk.HTTPFilesystemProvider:
		fsConfig.HTTPConfig.Username = u.replacePlaceholder(fsConfig.HTTPConfig.Username, replacer)
	}
	return fsConfig
}

// mergeCryptFsConfig copies the crypt filesystem read/write buffer sizes from the
// group, if the group uses the crypted provider and the user's values are unset
func (u *User) mergeCryptFsConfig(group *Group) {
	if group.UserSettings.FsConfig.Provider == sdk.CryptedFilesystemProvider {
		if u.FsConfig.CryptConfig.ReadBufferSize == 0 {
			u.FsConfig.CryptConfig.ReadBufferSize = group.UserSettings.FsConfig.CryptConfig.ReadBufferSize
		}
		if u.FsConfig.CryptConfig.WriteBufferSize == 0 {
			u.FsConfig.CryptConfig.WriteBufferSize = group.UserSettings.FsConfig.CryptConfig.WriteBufferSize
		}
	}
}

// mergeWithPrimaryGroup merges the primary group settings into the user.
// A group home dir always overrides the user's one. A group filesystem provider
// (non-zero) replaces the user filesystem config entirely. Other scalar settings
// are copied only where the user value is unset (zero). The data transfer limits
// are copied together only if none is set for the user.
func (u *User) mergeWithPrimaryGroup(group *Group, replacer *strings.Replacer) {
	if group.UserSettings.HomeDir != "" {
		u.HomeDir = filepath.Clean(u.replacePlaceholder(group.UserSettings.HomeDir, replacer))
	}
	if group.UserSettings.FsConfig.Provider != 0 {
		u.FsConfig = u.replaceFsConfigPlaceholders(group.UserSettings.FsConfig, replacer)
		u.mergeCryptFsConfig(group)
	} else {
		if u.FsConfig.OSConfig.ReadBufferSize == 0 {
			u.FsConfig.OSConfig.ReadBufferSize = group.UserSettings.FsConfig.OSConfig.ReadBufferSize
		}
		if u.FsConfig.OSConfig.WriteBufferSize == 0 {
			u.FsConfig.OSConfig.WriteBufferSize = group.UserSettings.FsConfig.OSConfig.WriteBufferSize
		}
	}
	if u.MaxSessions == 0 {
		u.MaxSessions = group.UserSettings.MaxSessions
	}
	if u.QuotaSize == 0 {
		u.QuotaSize = group.UserSettings.QuotaSize
	}
	if u.QuotaFiles == 0 {
		u.QuotaFiles = group.UserSettings.QuotaFiles
	}
	if u.UploadBandwidth == 0 {
		u.UploadBandwidth = group.UserSettings.UploadBandwidth
	}
	if u.DownloadBandwidth == 0 {
		u.DownloadBandwidth = group.UserSettings.DownloadBandwidth
	}
	if !u.hasMainDataTransferLimits() {
		u.UploadDataTransfer = group.UserSettings.UploadDataTransfer
		u.DownloadDataTransfer = group.UserSettings.DownloadDataTransfer
		u.TotalDataTransfer = group.UserSettings.TotalDataTransfer
	}
	// ExpiresIn is in days: 86400000 ms per day, counted from the creation time
	if u.ExpirationDate == 0 && group.UserSettings.ExpiresIn > 0 {
		u.ExpirationDate = u.CreatedAt + int64(group.UserSettings.ExpiresIn)*86400000
	}
	u.mergePrimaryGroupFilters(&group.UserSettings.Filters, replacer)
	u.mergeAdditiveProperties(group, sdk.GroupTypePrimary, replacer)
}

// mergePrimaryGroupFilters copies the primary group filters into the user filters
// where the user value is unset. Boolean flags can only be turned on by the
// group, never off.
func (u *User) mergePrimaryGroupFilters(filters *sdk.BaseUserFilters, replacer *strings.Replacer) { //nolint:gocyclo
	if u.Filters.MaxUploadFileSize == 0 {
		u.Filters.MaxUploadFileSize = filters.MaxUploadFileSize
	}
	if !u.IsTLSVerificationEnabled() {
		u.Filters.TLSUsername = filters.TLSUsername
	}
	if !u.Filters.Hooks.CheckPasswordDisabled {
		u.Filters.Hooks.CheckPasswordDisabled = filters.Hooks.CheckPasswordDisabled
	}
	if !u.Filters.Hooks.PreLoginDisabled {
		u.Filters.Hooks.PreLoginDisabled = filters.Hooks.PreLoginDisabled
	}
	if !u.Filters.Hooks.ExternalAuthDisabled {
		u.Filters.Hooks.ExternalAuthDisabled = filters.Hooks.ExternalAuthDisabled
	}
	if !u.Filters.DisableFsChecks {
		u.Filters.DisableFsChecks = filters.DisableFsChecks
	}
	if !u.Filters.AllowAPIKeyAuth {
		u.Filters.AllowAPIKeyAuth = filters.AllowAPIKeyAuth
	}
	if !u.Filters.IsAnonymous {
		u.Filters.IsAnonymous = filters.IsAnonymous
	}
	if u.Filters.ExternalAuthCacheTime == 0 {
		u.Filters.ExternalAuthCacheTime = filters.ExternalAuthCacheTime
	}
	if u.Filters.FTPSecurity == 0 {
		u.Filters.FTPSecurity = filters.FTPSecurity
	}
	if u.Filters.StartDirectory == "" {
		u.Filters.StartDirectory = u.replacePlaceholder(filters.StartDirectory, replacer)
	}
	if u.Filters.DefaultSharesExpiration == 0 {
		u.Filters.DefaultSharesExpiration = filters.DefaultSharesExpiration
	}
	if u.Filters.MaxSharesExpiration == 0 {
		u.Filters.MaxSharesExpiration = filters.MaxSharesExpiration
	}
	if u.Filters.PasswordExpiration == 0 {
		u.Filters.PasswordExpiration = filters.PasswordExpiration
	}
	if u.Filters.PasswordStrength == 0 {
		u.Filters.PasswordStrength = filters.PasswordStrength
	}
}

func (u *User) mergeAdditiveProperties(group *Group, groupType int, replacer *strings.Replacer) {
	u.mergeVirtualFolders(group, groupType, replacer)
	u.mergePermissions(group, groupType, replacer)
	u.mergeFilePatterns(group, groupType, replacer)
	u.Filters.BandwidthLimits = append(u.Filters.BandwidthLimits, group.UserSettings.Filters.BandwidthLimits...)
	u.Filters.AllowedIP = append(u.Filters.AllowedIP, group.UserSettings.Filters.AllowedIP...)
	u.Filters.DeniedIP = append(u.Filters.DeniedIP, group.UserSettings.Filters.DeniedIP...)
u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, group.UserSettings.Filters.DeniedLoginMethods...) u.Filters.DeniedProtocols = append(u.Filters.DeniedProtocols, group.UserSettings.Filters.DeniedProtocols...) u.Filters.WebClient = append(u.Filters.WebClient, group.UserSettings.Filters.WebClient...) u.Filters.TwoFactorAuthProtocols = append(u.Filters.TwoFactorAuthProtocols, group.UserSettings.Filters.TwoFactorAuthProtocols...) u.Filters.AccessTime = append(u.Filters.AccessTime, group.UserSettings.Filters.AccessTime...) } func (u *User) mergeVirtualFolders(group *Group, groupType int, replacer *strings.Replacer) { if len(group.VirtualFolders) > 0 { folderPaths := make(map[string]bool) for _, folder := range u.VirtualFolders { folderPaths[folder.VirtualPath] = true } for _, folder := range group.VirtualFolders { if folder.VirtualPath == "/" && groupType != sdk.GroupTypePrimary { continue } folder.VirtualPath = u.replacePlaceholder(folder.VirtualPath, replacer) if _, ok := folderPaths[folder.VirtualPath]; !ok { folder.MappedPath = u.replacePlaceholder(folder.MappedPath, replacer) folder.FsConfig = u.replaceFsConfigPlaceholders(folder.FsConfig, replacer) u.VirtualFolders = append(u.VirtualFolders, folder) } } } } func (u *User) mergePermissions(group *Group, groupType int, replacer *strings.Replacer) { if u.Permissions == nil { u.Permissions = make(map[string][]string) } for k, v := range group.UserSettings.Permissions { if k == "/" { if groupType == sdk.GroupTypePrimary { u.Permissions[k] = v } else { continue } } k = u.replacePlaceholder(k, replacer) if _, ok := u.Permissions[k]; !ok { u.Permissions[k] = v } } } func (u *User) mergeFilePatterns(group *Group, groupType int, replacer *strings.Replacer) { if len(group.UserSettings.Filters.FilePatterns) > 0 { patternPaths := make(map[string]bool) for _, pattern := range u.Filters.FilePatterns { patternPaths[pattern.Path] = true } for _, pattern := range group.UserSettings.Filters.FilePatterns { if 
pattern.Path == "/" && groupType != sdk.GroupTypePrimary { continue } pattern.Path = u.replacePlaceholder(pattern.Path, replacer) if _, ok := patternPaths[pattern.Path]; !ok { u.Filters.FilePatterns = append(u.Filters.FilePatterns, pattern) } } } } func (u *User) removeDuplicatesAfterGroupMerge() { u.Filters.AllowedIP = util.RemoveDuplicates(u.Filters.AllowedIP, false) u.Filters.DeniedIP = util.RemoveDuplicates(u.Filters.DeniedIP, false) u.Filters.DeniedLoginMethods = util.RemoveDuplicates(u.Filters.DeniedLoginMethods, false) u.Filters.DeniedProtocols = util.RemoveDuplicates(u.Filters.DeniedProtocols, false) u.Filters.WebClient = util.RemoveDuplicates(u.Filters.WebClient, false) u.Filters.TwoFactorAuthProtocols = util.RemoveDuplicates(u.Filters.TwoFactorAuthProtocols, false) u.SetEmptySecretsIfNil() u.groupSettingsApplied = true } func (u *User) hasRole(role string) bool { if role == "" { return true } return role == u.Role } func (u *User) applyNamingRules() { u.Username = config.convertName(u.Username) u.Role = config.convertName(u.Role) for idx := range u.Groups { u.Groups[idx].Name = config.convertName(u.Groups[idx].Name) } for idx := range u.VirtualFolders { u.VirtualFolders[idx].Name = config.convertName(u.VirtualFolders[idx].Name) } } func (u *User) getACopy() User { u.SetEmptySecretsIfNil() pubKeys := make([]string, len(u.PublicKeys)) copy(pubKeys, u.PublicKeys) virtualFolders := make([]vfs.VirtualFolder, 0, len(u.VirtualFolders)) for idx := range u.VirtualFolders { vfolder := u.VirtualFolders[idx].GetACopy() virtualFolders = append(virtualFolders, vfolder) } groups := make([]sdk.GroupMapping, 0, len(u.Groups)) for _, g := range u.Groups { groups = append(groups, sdk.GroupMapping{ Name: g.Name, Type: g.Type, }) } permissions := make(map[string][]string) for k, v := range u.Permissions { perms := make([]string, len(v)) copy(perms, v) permissions[k] = perms } filters := UserFilters{ BaseUserFilters: copyBaseUserFilters(u.Filters.BaseUserFilters), } 
filters.RequirePasswordChange = u.Filters.RequirePasswordChange filters.TOTPConfig.Enabled = u.Filters.TOTPConfig.Enabled filters.TOTPConfig.ConfigName = u.Filters.TOTPConfig.ConfigName filters.TOTPConfig.Secret = u.Filters.TOTPConfig.Secret.Clone() filters.TOTPConfig.Protocols = make([]string, len(u.Filters.TOTPConfig.Protocols)) copy(filters.TOTPConfig.Protocols, u.Filters.TOTPConfig.Protocols) filters.AdditionalEmails = make([]string, len(u.Filters.AdditionalEmails)) copy(filters.AdditionalEmails, u.Filters.AdditionalEmails) filters.RecoveryCodes = make([]RecoveryCode, 0, len(u.Filters.RecoveryCodes)) for _, code := range u.Filters.RecoveryCodes { if code.Secret == nil { code.Secret = kms.NewEmptySecret() } filters.RecoveryCodes = append(filters.RecoveryCodes, RecoveryCode{ Secret: code.Secret.Clone(), Used: code.Used, }) } return User{ BaseUser: sdk.BaseUser{ ID: u.ID, Username: u.Username, Email: u.Email, Password: u.Password, PublicKeys: pubKeys, HasPassword: u.HasPassword, HomeDir: u.HomeDir, UID: u.UID, GID: u.GID, MaxSessions: u.MaxSessions, QuotaSize: u.QuotaSize, QuotaFiles: u.QuotaFiles, Permissions: permissions, UsedQuotaSize: u.UsedQuotaSize, UsedQuotaFiles: u.UsedQuotaFiles, LastQuotaUpdate: u.LastQuotaUpdate, UploadBandwidth: u.UploadBandwidth, DownloadBandwidth: u.DownloadBandwidth, UploadDataTransfer: u.UploadDataTransfer, DownloadDataTransfer: u.DownloadDataTransfer, TotalDataTransfer: u.TotalDataTransfer, UsedUploadDataTransfer: u.UsedUploadDataTransfer, UsedDownloadDataTransfer: u.UsedDownloadDataTransfer, Status: u.Status, ExpirationDate: u.ExpirationDate, LastLogin: u.LastLogin, FirstDownload: u.FirstDownload, FirstUpload: u.FirstUpload, LastPasswordChange: u.LastPasswordChange, AdditionalInfo: u.AdditionalInfo, Description: u.Description, CreatedAt: u.CreatedAt, UpdatedAt: u.UpdatedAt, Role: u.Role, }, Filters: filters, VirtualFolders: virtualFolders, Groups: groups, FsConfig: u.FsConfig.GetACopy(), groupSettingsApplied: 
u.groupSettingsApplied, } } // GetEncryptionAdditionalData returns the additional data to use for AEAD func (u *User) GetEncryptionAdditionalData() string { return u.Username } ================================================ FILE: internal/ftpd/cryptfs_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package ftpd_test import ( "crypto/sha256" "fmt" "hash" "io" "net/http" "os" "path" "path/filepath" "testing" "time" "github.com/minio/sio" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" ) func TestBasicFTPHandlingCryptFs(t *testing.T) { u := getTestUserWithCryptFs() u.QuotaSize = 6553600 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { assert.Len(t, common.Connections.GetStats(""), 1) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) encryptedFileSize, err := getEncryptedFileSize(testFileSize) assert.NoError(t, err) expectedQuotaSize := encryptedFileSize expectedQuotaFiles := 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = 
ftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client, 0) assert.Error(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) // overwrite an existing file err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) info, err := os.Stat(localDownloadPath) if assert.NoError(t, err) { assert.Equal(t, testFileSize, info.Size()) } list, err := client.List(".") if assert.NoError(t, err) { if assert.Len(t, list, 1) { assert.Equal(t, testFileSize, int64(list[0].Size)) } } user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) err = client.Rename(testFileName, testFileName+"1") assert.NoError(t, err) err = client.Delete(testFileName) assert.Error(t, err) err = client.Delete(testFileName + "1") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize-encryptedFileSize, user.UsedQuotaSize) curDir, err := client.CurrentDir() if assert.NoError(t, err) { assert.Equal(t, "/", curDir) } testDir := "testDir" err = client.MakeDir(testDir) assert.NoError(t, err) err = client.ChangeDir(testDir) assert.NoError(t, err) curDir, err = client.CurrentDir() if assert.NoError(t, err) { assert.Equal(t, path.Join("/", testDir), curDir) } err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) size, err := client.FileSize(path.Join("/", testDir, testFileName)) assert.NoError(t, err) assert.Equal(t, testFileSize, size) err = client.ChangeDirToParent() assert.NoError(t, 
err) curDir, err = client.CurrentDir() if assert.NoError(t, err) { assert.Equal(t, "/", curDir) } err = client.Delete(path.Join("/", testDir, testFileName)) assert.NoError(t, err) err = client.Delete(testDir) assert.Error(t, err) err = client.RemoveDir(testDir) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestBufferedCryptFs(t *testing.T) { u := getTestUserWithCryptFs() u.FsConfig.CryptConfig.OSFsConfig = sdk.OSFsConfig{ ReadBufferSize: 1, WriteBufferSize: 1, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) // overwrite an existing file err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) info, err := os.Stat(localDownloadPath) if assert.NoError(t, err) { assert.Equal(t, testFileSize, info.Size()) } err = os.Remove(testFilePath) assert.NoError(t, err) err = 
os.Remove(localDownloadPath) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestZeroBytesTransfersCryptFs(t *testing.T) { u := getTestUserWithCryptFs() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { testFileName := "testfilename" err = checkBasicFTP(client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, "emptydownload") err = os.WriteFile(localDownloadPath, []byte(""), os.ModePerm) assert.NoError(t, err) err = ftpUploadFile(localDownloadPath, testFileName, 0, client, 0) assert.NoError(t, err) size, err := client.FileSize(testFileName) assert.NoError(t, err) assert.Equal(t, int64(0), size) err = os.Remove(localDownloadPath) assert.NoError(t, err) assert.NoFileExists(t, localDownloadPath) err = ftpDownloadFile(testFileName, localDownloadPath, 0, client, 0) assert.NoError(t, err) info, err := os.Stat(localDownloadPath) if assert.NoError(t, err) { assert.Equal(t, int64(0), info.Size()) } err = client.Quit() assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestResumeCryptFs(t *testing.T) { u := getTestUserWithCryptFs() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { testFilePath := 
filepath.Join(homeBasePath, testFileName) data := []byte("test data") err = os.WriteFile(testFilePath, data, os.ModePerm) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) assert.NoError(t, err) // resuming uploads is not supported err = ftpUploadFile(testFilePath, testFileName, int64(len(data)+5), client, 5) assert.Error(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpDownloadFile(testFileName, localDownloadPath, int64(4), client, 5) assert.NoError(t, err) readed, err := os.ReadFile(localDownloadPath) assert.NoError(t, err) assert.Equal(t, data[5:], readed) err = ftpDownloadFile(testFileName, localDownloadPath, int64(8), client, 1) assert.NoError(t, err) readed, err = os.ReadFile(localDownloadPath) assert.NoError(t, err) assert.Equal(t, data[1:], readed) err = ftpDownloadFile(testFileName, localDownloadPath, int64(0), client, 9) assert.NoError(t, err) err = client.Delete(testFileName) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) assert.NoError(t, err) // now append to a file srcFile, err := os.Open(testFilePath) if assert.NoError(t, err) { err = client.Append(testFileName, srcFile) assert.Error(t, err) err = srcFile.Close() assert.NoError(t, err) size, err := client.FileSize(testFileName) assert.NoError(t, err) assert.Equal(t, int64(len(data)), size) err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)), client, 0) assert.NoError(t, err) readed, err = os.ReadFile(localDownloadPath) assert.NoError(t, err) assert.Equal(t, data, readed) } // now test a download resume using a bigger file testFileSize := int64(655352) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) initialHash, err := computeHashForFile(sha256.New(), testFilePath) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, 
localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) downloadHash, err := computeHashForFile(sha256.New(), localDownloadPath) assert.NoError(t, err) assert.Equal(t, initialHash, downloadHash) err = os.Truncate(localDownloadPath, 32767) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath+"_partial", testFileSize-32767, client, 32767) //nolint:goconst assert.NoError(t, err) file, err := os.OpenFile(localDownloadPath, os.O_APPEND|os.O_WRONLY, os.ModePerm) assert.NoError(t, err) file1, err := os.Open(localDownloadPath + "_partial") //nolint:goconst assert.NoError(t, err) _, err = io.Copy(file, file1) assert.NoError(t, err) err = file.Close() assert.NoError(t, err) err = file1.Close() assert.NoError(t, err) downloadHash, err = computeHashForFile(sha256.New(), localDownloadPath) assert.NoError(t, err) assert.Equal(t, initialHash, downloadHash) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = os.Remove(localDownloadPath + "_partial") assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func getTestUserWithCryptFs() dataprovider.User { user := getTestUser() user.FsConfig.Provider = sdk.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase") return user } func getEncryptedFileSize(size int64) (int64, error) { encSize, err := sio.EncryptedSize(uint64(size)) return int64(encSize) + 33, err } func computeHashForFile(hasher hash.Hash, path string) (string, error) { hash := "" f, err := os.Open(path) if err != nil { return hash, err } defer f.Close() _, err = io.Copy(hasher, f) if err == nil { hash = fmt.Sprintf("%x", hasher.Sum(nil)) } return hash, err } ================================================ FILE: internal/ftpd/ftpd.go ================================================ 
// Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package ftpd implements the FTP protocol package ftpd import ( "context" "errors" "fmt" "log/slog" "net" "os" "path/filepath" "strings" "time" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( logSender = "ftpd" ) var ( /* certMgr is set by Initialize when at least one TLS key pair is configured */ certMgr *common.CertManager serviceStatus ServiceStatus ) // PassiveIPOverride defines an exception for the configured passive IP type PassiveIPOverride struct { /* client networks this override applies to, they cannot be empty */ Networks []string `json:"networks" mapstructure:"networks"` // if empty the local address will be returned IP string `json:"ip" mapstructure:"ip"` /* matchers for Networks, set by Binding.checkPassiveIP */ parsedNetworks []func(net.IP) bool } // GetNetworksAsString returns the configured networks as a comma separated string func (p *PassiveIPOverride) GetNetworksAsString() string { return strings.Join(p.Networks, ", ") } // Binding defines the configuration for a network listener type Binding struct { // The address to listen on. A blank value means listen on all available network interfaces. 
Address string `json:"address" mapstructure:"address"` // The port used for serving requests Port int `json:"port" mapstructure:"port"` // Apply the proxy configuration, if any, for this binding ApplyProxyConfig bool `json:"apply_proxy_config" mapstructure:"apply_proxy_config"` // Set to 1 to require TLS for both data and control connection. // Set to 2 to enable implicit TLS TLSMode int `json:"tls_mode" mapstructure:"tls_mode"` // Certificate and matching private key for this specific binding, if empty the global // ones will be used, if any CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` // Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2 MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"` // External IP address for passive connections. If set, PassiveHost is ignored, while // PassiveIPOverrides matching the client IP take precedence over it ForcePassiveIP string `json:"force_passive_ip" mapstructure:"force_passive_ip"` // PassiveIPOverrides allows to define different IP addresses for passive connections // based on the client IP address PassiveIPOverrides []PassiveIPOverride `json:"passive_ip_overrides" mapstructure:"passive_ip_overrides"` // Hostname for passive connections. This hostname will be resolved each time a passive // connection is requested and this can, depending on the DNS configuration, take a noticeable // amount of time. Enable this setting only if you have a dynamic IP address. // Only IPv4 addresses are looked up and the first one returned is used, if none is found // the local address is used PassiveHost string `json:"passive_host" mapstructure:"passive_host"` // Set to 1 to require client certificate authentication. // Set to 2 to require a client certificate and verify it if given. In this mode // the client is allowed not to send a certificate. // You need to define at least a certificate authority for this to work ClientAuthType int `json:"client_auth_type" mapstructure:"client_auth_type"` // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. 
// If TLSCipherSuites is nil/empty, a default list of secure cipher suites // is used, with a preference order based on hardware performance. // Note that TLS 1.3 ciphersuites are not configurable. // The supported ciphersuites names are defined here: // // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 // // Any invalid name will be silently ignored. // The order matters, the ciphers listed first will be the preferred ones. TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` // PassiveConnectionsSecurity defines the security checks for passive data connections. // Supported values: // - 0 require matching peer IP addresses of control and data connection. This is the default // - 1 disable any checks PassiveConnectionsSecurity int `json:"passive_connections_security" mapstructure:"passive_connections_security"` // ActiveConnectionsSecurity defines the security checks for active data connections. // The supported values are the same as described for PassiveConnectionsSecurity. // Please note that by disabling the security checks you will make the FTP service vulnerable to bounce attacks // on active data connections, so change the default value only if you are on a trusted/internal network ActiveConnectionsSecurity int `json:"active_connections_security" mapstructure:"active_connections_security"` // Debug enables the FTP debug mode. 
In debug mode, every FTP command will be logged Debug bool `json:"debug" mapstructure:"debug"` /* populated by setCiphers from TLSCipherSuites */ ciphers []uint16 } /* setCiphers converts the TLSCipherSuites names to cipher suite IDs using util.GetTLSCiphersFromNames */ func (b *Binding) setCiphers() { b.ciphers = util.GetTLSCiphersFromNames(b.TLSCipherSuites) } /* isMutualTLSEnabled returns true if ClientAuthType is 1 or 2 */ func (b *Binding) isMutualTLSEnabled() bool { return b.ClientAuthType == 1 || b.ClientAuthType == 2 } // GetAddress returns the binding address func (b *Binding) GetAddress() string { return fmt.Sprintf("%s:%d", b.Address, b.Port) } // IsValid returns true if the binding port is > 0 func (b *Binding) IsValid() bool { return b.Port > 0 } /* isTLSModeValid returns true if TLSMode is 0, 1 or 2 */ func (b *Binding) isTLSModeValid() bool { return b.TLSMode >= 0 && b.TLSMode <= 2 } /* checkSecuritySettings returns an error if the passive or active connections security is not 0 or 1 */ func (b *Binding) checkSecuritySettings() error { if b.PassiveConnectionsSecurity < 0 || b.PassiveConnectionsSecurity > 1 { return fmt.Errorf("invalid passive_connections_security: %v", b.PassiveConnectionsSecurity) } if b.ActiveConnectionsSecurity < 0 || b.ActiveConnectionsSecurity > 1 { return fmt.Errorf("invalid active_connections_security: %v", b.ActiveConnectionsSecurity) } return nil } /* checkPassiveIP validates ForcePassiveIP and the passive IP overrides, replacing the configured IPs with the values returned by parsePassiveIP and setting the parsed networks for each override. Overrides without networks are rejected */ func (b *Binding) checkPassiveIP() error { if b.ForcePassiveIP != "" { ip, err := parsePassiveIP(b.ForcePassiveIP) if err != nil { return err } b.ForcePassiveIP = ip } for idx, passiveOverride := range b.PassiveIPOverrides { var ip string if passiveOverride.IP != "" { var err error ip, err = parsePassiveIP(passiveOverride.IP) if err != nil { return err } } if len(passiveOverride.Networks) == 0 { return errors.New("passive IP networks override cannot be empty") } checkFuncs, err := util.ParseAllowedIPAndRanges(passiveOverride.Networks) if err != nil { return fmt.Errorf("invalid passive IP networks override %+v: %w", passiveOverride.Networks, err) } b.PassiveIPOverrides[idx].IP = ip b.PassiveIPOverrides[idx].parsedNetworks = checkFuncs } return nil } /* getPassiveIP returns ForcePassiveIP if set, otherwise the first IPv4 address resolved for PassiveHost, if any, falling back to the IP of the local address of the control connection */ func (b *Binding) getPassiveIP(cc ftpserver.ClientContext) (string, error) { if b.ForcePassiveIP != "" { return b.ForcePassiveIP, nil } if b.PassiveHost != "" { ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) defer cancel() addrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", b.PassiveHost) if err != nil { logger.Error(logSender, "", "unable to resolve hostname %q: %v", b.PassiveHost, err) return "", fmt.Errorf("unable to resolve hostname %q: %w", b.PassiveHost, err) } if len(addrs) > 0 { return addrs[0].String(), nil } } /* NOTE(review): splitting on the colon assumes an IPv4 local address, an IPv6 one would not be extracted correctly */ return strings.Split(cc.LocalAddr().String(), ":")[0], nil } /* passiveIPResolver returns, for the first override with a network matching the client IP, the override IP or, if it is empty, the IP of the local address. If no override matches, or the client IP cannot be parsed, it falls back to getPassiveIP */ func (b *Binding) passiveIPResolver(cc ftpserver.ClientContext) (string, error) { if len(b.PassiveIPOverrides) > 0 { clientIP := net.ParseIP(util.GetIPFromRemoteAddress(cc.RemoteAddr().String())) if clientIP != nil { for _, override := range b.PassiveIPOverrides { for _, fn := range override.parsedNetworks { if fn(clientIP) { if override.IP == "" { return strings.Split(cc.LocalAddr().String(), ":")[0], nil } return override.IP, nil } } } } } return b.getPassiveIP(cc) } // HasProxy returns true if the proxy protocol is active for this binding func (b *Binding) HasProxy() bool { return b.ApplyProxyConfig && common.Config.ProxyProtocol > 0 } // GetTLSDescription returns the TLS mode as string: disabled if no certificate manager is set, // explicit or implicit for TLS mode 1 and 2, mixed if a global or binding specific certificate // is available, disabled otherwise func (b *Binding) GetTLSDescription() string { if certMgr == nil { return util.I18nFTPTLSDisabled } switch b.TLSMode { case 1: return util.I18nFTPTLSExplicit case 2: return util.I18nFTPTLSImplicit } if certMgr.HasCertificate(common.DefaultTLSKeyPaidID) || certMgr.HasCertificate(b.GetAddress()) { return util.I18nFTPTLSMixed } return util.I18nFTPTLSDisabled } // PortRange defines a port range type PortRange struct { // Range start Start int `json:"start" mapstructure:"start"` // Range end End int `json:"end" mapstructure:"end"` } // ServiceStatus defines the service status type ServiceStatus struct { IsActive bool `json:"is_active"` Bindings []Binding `json:"bindings"` PassivePortRange PortRange `json:"passive_port_range"` } // Configuration defines the configuration for the ftp server type Configuration struct { // Addresses and ports to bind to 
Bindings []Binding `json:"bindings" mapstructure:"bindings"` // The contents of the specified file, if any, are displayed when someone connects to the server. BannerFile string `json:"banner_file" mapstructure:"banner_file"` // If files containing a certificate and matching private key for the server are provided, the server will accept // both plain FTP and explicit FTP over TLS. // Certificate and key files can be reloaded on demand by sending a "SIGHUP" signal on Unix based systems and a // "paramchange" request to the running service on Windows. CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` // CACertificates defines the set of root certificate authorities to be used to verify client certificates. CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` // CARevocationLists defines a set of revocation lists, one for each root CA, to be used to check // if a client certificate has been revoked CARevocationLists []string `json:"ca_revocation_lists" mapstructure:"ca_revocation_lists"` // Do not impose port 20 for active data transfer. Enabling this option allows to run SFTPGo with less privilege ActiveTransfersPortNon20 bool `json:"active_transfers_port_non_20" mapstructure:"active_transfers_port_non_20"` // Set to true to disable active FTP DisableActiveMode bool `json:"disable_active_mode" mapstructure:"disable_active_mode"` // Set to true to enable the FTP SITE command. // We support chmod and symlink if SITE support is enabled EnableSite bool `json:"enable_site" mapstructure:"enable_site"` // Set to 1 to enable FTP commands that allow to calculate the hash value of files. // These FTP commands will be enabled: HASH, XCRC, MD5/XMD5, XSHA/XSHA1, XSHA256, XSHA512. 
// Please keep in mind that to calculate the hash we need to read the whole file, for // remote backends this means downloading the file, for the encrypted backend this means // decrypting the file HASHSupport int `json:"hash_support" mapstructure:"hash_support"` // Set to 1 to enable support for the non standard "COMB" FTP command. // Combine is only supported for local filesystem, for cloud backends it has // no advantage as it will download the partial files and will upload the // combined one. Cloud backends natively support multipart uploads. CombineSupport int `json:"combine_support" mapstructure:"combine_support"` // Port Range for data connections. Random if not specified PassivePortRange PortRange `json:"passive_port_range" mapstructure:"passive_port_range"` /* acmeDomain is set by loadFromProvider when an ACME certificate for FTP is available */ acmeDomain string } // ShouldBind returns true if there is at least one valid binding func (c *Configuration) ShouldBind() bool { for _, binding := range c.Bindings { if binding.IsValid() { return true } } return false } /* getKeyPairs returns the key pairs for the bindings defining both a certificate and a key file, identified by the binding address, followed by the global key pair, identified by common.DefaultTLSKeyPaidID, that uses the ACME certificate if acmeDomain is set and the configured files otherwise. Configured paths are resolved using getConfigPath */ func (c *Configuration) getKeyPairs(configDir string) []common.TLSKeyPair { var keyPairs []common.TLSKeyPair for _, binding := range c.Bindings { certificateFile := getConfigPath(binding.CertificateFile, configDir) certificateKeyFile := getConfigPath(binding.CertificateKeyFile, configDir) if certificateFile != "" && certificateKeyFile != "" { keyPairs = append(keyPairs, common.TLSKeyPair{ Cert: certificateFile, Key: certificateKeyFile, ID: binding.GetAddress(), }) } } var certificateFile, certificateKeyFile string if c.acmeDomain != "" { certificateFile, certificateKeyFile = util.GetACMECertificateKeyPair(c.acmeDomain) } else { certificateFile = getConfigPath(c.CertificateFile, configDir) certificateKeyFile = getConfigPath(c.CertificateKeyFile, configDir) } if certificateFile != "" && certificateKeyFile != "" { keyPairs = append(keyPairs, common.TLSKeyPair{ Cert: certificateFile, Key: certificateKeyFile, ID: common.DefaultTLSKeyPaidID, }) } return keyPairs } /* loadFromProvider sets acmeDomain if the configs stored in the data provider define an ACME domain enabled for FTP and both the related certificate and key files exist. Only a failure loading the configs is returned as error, missing certificate files are only logged */ func (c *Configuration) 
loadFromProvider() error { configs, err := dataprovider.GetConfigs() if err != nil { return fmt.Errorf("unable to load config from provider: %w", err) } configs.SetNilsToEmpty() if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolFTP) { return nil } crt, key := util.GetACMECertificateKeyPair(configs.ACME.Domain) if crt != "" && key != "" { if _, err := os.Stat(crt); err != nil { logger.Error(logSender, "", "unable to load acme cert file %q: %v", crt, err) return nil } if _, err := os.Stat(key); err != nil { logger.Error(logSender, "", "unable to load acme key file %q: %v", key, err) return nil } c.acmeDomain = configs.ACME.Domain logger.Info(logSender, "", "acme domain set to %q", c.acmeDomain) return nil } return nil } // Initialize configures and starts the FTP server func (c *Configuration) Initialize(configDir string) error { if err := c.loadFromProvider(); err != nil { return err } logger.Info(logSender, "", "initializing FTP server with config %+v", *c) if !c.ShouldBind() { return common.ErrNoBinding } keyPairs := c.getKeyPairs(configDir) if len(keyPairs) > 0 { mgr, err := common.NewCertManager(keyPairs, configDir, logSender) if err != nil { return err } mgr.SetCACertificates(c.CACertificates) if err := mgr.LoadRootCAs(); err != nil { return err } mgr.SetCARevocationLists(c.CARevocationLists) if err := mgr.LoadCRLs(); err != nil { return err } certMgr = mgr } serviceStatus = ServiceStatus{ Bindings: nil, PassivePortRange: c.PassivePortRange, } exitChannel := make(chan error, 1) for idx, binding := range c.Bindings { if !binding.IsValid() { continue } server := NewServer(c, configDir, binding, idx) go func(s *Server) { ftpLogger := logger.NewSlogAdapter("ftpserverlib", []slog.Attr{ { Key: "server_id", Value: slog.StringValue(fmt.Sprintf("FTP_%d", s.ID)), }, }) ftpServer := ftpserver.NewFtpServer(s) ftpServer.Logger = slog.New(ftpLogger) logger.Info(logSender, "", "starting FTP serving, binding: %v", s.binding.GetAddress()) 
util.CheckTCP4Port(s.binding.Port) exitChannel <- ftpServer.ListenAndServe() }(server) serviceStatus.Bindings = append(serviceStatus.Bindings, binding) } serviceStatus.IsActive = true return <-exitChannel } // ReloadCertificateMgr reloads the certificate manager func ReloadCertificateMgr() error { if certMgr != nil { return certMgr.Reload() } return nil } // GetStatus returns the server status func GetStatus() ServiceStatus { return serviceStatus } func parsePassiveIP(passiveIP string) (string, error) { ip := net.ParseIP(passiveIP) if ip == nil { return "", fmt.Errorf("the provided passive IP %q is not valid", passiveIP) } ip = ip.To4() if ip == nil { return "", fmt.Errorf("the provided passive IP %q is not a valid IPv4 address", passiveIP) } return ip.String(), nil } func getConfigPath(name, configDir string) string { if !util.IsFileInputValid(name) { return "" } if name != "" && !filepath.IsAbs(name) { return filepath.Join(configDir, name) } return name } ================================================ FILE: internal/ftpd/ftpd_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package ftpd_test import ( "crypto/rand" "crypto/sha256" "crypto/tls" "encoding/hex" "encoding/json" "errors" "fmt" "io" "io/fs" "net" "net/http" "os" "os/exec" "path" "path/filepath" "runtime" "strconv" "testing" "time" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/jlaffaye/ftp" "github.com/pkg/sftp" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/rs/zerolog" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( logSender = "ftpdTesting" ftpServerAddr = "127.0.0.1:2121" sftpServerAddr = "127.0.0.1:2122" ftpSrvAddrTLS = "127.0.0.1:2124" // ftp server with implicit tls ftpSrvAddrTLSResumption = "127.0.0.1:2126" // ftp server with implicit tls defaultUsername = "test_user_ftp" defaultPassword = "test_password" osWindows = "windows" ftpsCert = `-----BEGIN CERTIFICATE----- MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG 
SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY /8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI dV4vKmHUzwK/eIx+8Ay3neE= -----END CERTIFICATE-----` ftpsKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= -----END EC PRIVATE KEY-----` caCRT = `-----BEGIN CERTIFICATE----- MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U 4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY /tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV 
oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w kK8u1YiiVenmoQ== -----END CERTIFICATE-----` caCRL = `-----BEGIN X509 CRL----- MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV +jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn w7bXJGfJcIMrrKs= -----END X509 CRL-----` client1Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC 
A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au MU3Bo0A= -----END CERTIFICATE-----` client1Key = `-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi 85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 
T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== -----END RSA PRIVATE KEY-----` // client 2 crt is revoked client2Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB 4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv 
/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR E0+r2+4= -----END CERTIFICATE-----` client2Key = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu 4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl /owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz 3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D 8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl -----END RSA PRIVATE KEY-----` testFileName = "test_file_ftp.dat" testDLFileName = "test_download_ftp.dat" tlsClient1Username = "client1" tlsClient2Username = "client2" 
httpFsPort = 23456 defaultHTTPFsUsername = "httpfs_ftp_user" emptyPwdPlaceholder = "empty" ) var ( configDir = filepath.Join(".", "..", "..") allPerms = []string{dataprovider.PermAny} homeBasePath string hookCmdPath string extAuthPath string preLoginPath string postConnectPath string preDownloadPath string preUploadPath string logFilePath string caCrtPath string caCRLPath string ) func TestMain(m *testing.M) { //nolint:gocyclo logFilePath = filepath.Join(configDir, "sftpgo_ftpd_test.log") bannerFileName := "banner_file" bannerFile := filepath.Join(configDir, bannerFileName) logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel) err := os.WriteFile(bannerFile, []byte("SFTPGo test ready\nsimple banner line\n"), os.ModePerm) if err != nil { logger.ErrorToConsole("error creating banner file: %v", err) os.Exit(1) } // we run the test cases with UploadMode atomic and resume support. The non atomic code path // simply does not execute some code so if it works in atomic mode will // work in non atomic mode too os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") err = config.LoadConfig(configDir, "") if err != nil { logger.ErrorToConsole("error loading configuration: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() logger.InfoToConsole("Starting FTPD tests, provider: %v", providerConf.Driver) commonConf := config.GetCommonConfig() homeBasePath = os.TempDir() if runtime.GOOS != osWindows { commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete"} commonConf.Actions.Hook = hookCmdPath hookCmdPath, err = exec.LookPath("true") if err != nil { logger.Warn(logSender, "", "unable to get hook command: %v", err) logger.WarnToConsole("unable to get hook command: %v", err) } } certPath := 
filepath.Join(os.TempDir(), "test_ftpd.crt") keyPath := filepath.Join(os.TempDir(), "test_ftpd.key") caCrtPath = filepath.Join(os.TempDir(), "test_ftpd_ca.crt") caCRLPath = filepath.Join(os.TempDir(), "test_ftpd_crl.crt") err = writeCerts(certPath, keyPath, caCrtPath, caCRLPath) if err != nil { os.Exit(1) } err = dataprovider.Initialize(providerConf, configDir, true) if err != nil { logger.ErrorToConsole("error initializing data provider: %v", err) os.Exit(1) } err = common.Initialize(commonConf, 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) } httpConfig := config.GetHTTPConfig() httpConfig.Initialize(configDir) //nolint:errcheck kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing kms: %v", err) os.Exit(1) } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing MFA: %v", err) os.Exit(1) } httpdConf := config.GetHTTPDConfig() httpdConf.Bindings[0].Port = 8079 httpdtest.SetBaseURL("http://127.0.0.1:8079") ftpdConf := config.GetFTPDConfig() ftpdConf.Bindings = []ftpd.Binding{ { Port: 2121, ClientAuthType: 2, CertificateFile: certPath, CertificateKeyFile: keyPath, }, } ftpdConf.PassivePortRange.Start = 0 ftpdConf.PassivePortRange.End = 0 ftpdConf.BannerFile = bannerFileName ftpdConf.CACertificates = []string{caCrtPath} ftpdConf.CARevocationLists = []string{caCRLPath} ftpdConf.EnableSite = true // required to test sftpfs sftpdConf := config.GetSFTPDConfig() sftpdConf.Bindings = []sftpd.Binding{ { Port: 2122, }, } hostKeyPath := filepath.Join(os.TempDir(), "id_ed25519") sftpdConf.HostKeys = []string{hostKeyPath} extAuthPath = filepath.Join(homeBasePath, "extauth.sh") preLoginPath = filepath.Join(homeBasePath, "prelogin.sh") postConnectPath = filepath.Join(homeBasePath, "postconnect.sh") preDownloadPath = filepath.Join(homeBasePath, "predownload.sh") preUploadPath = filepath.Join(homeBasePath, 
"preupload.sh") status := ftpd.GetStatus() if status.IsActive { logger.ErrorToConsole("ftpd is already active") os.Exit(1) } go func() { logger.Debug(logSender, "", "initializing FTP server with config %+v", ftpdConf) if err := ftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start FTP server: %v", err) os.Exit(1) } }() go func() { logger.Debug(logSender, "", "initializing SFTP server with config %+v", sftpdConf) if err := sftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server: %v", err) os.Exit(1) } }() go func() { if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } }() waitTCPListening(ftpdConf.Bindings[0].GetAddress()) waitTCPListening(httpdConf.Bindings[0].GetAddress()) waitTCPListening(sftpdConf.Bindings[0].GetAddress()) ftpd.ReloadCertificateMgr() //nolint:errcheck ftpdConf = config.GetFTPDConfig() ftpdConf.Bindings = []ftpd.Binding{ { Port: 2124, TLSMode: 2, }, } ftpdConf.CertificateFile = certPath ftpdConf.CertificateKeyFile = keyPath ftpdConf.CACertificates = []string{caCrtPath} ftpdConf.CARevocationLists = []string{caCRLPath} ftpdConf.EnableSite = false ftpdConf.DisableActiveMode = true ftpdConf.CombineSupport = 1 ftpdConf.HASHSupport = 1 go func() { logger.Debug(logSender, "", "initializing FTP server with config %+v", ftpdConf) if err := ftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start FTP server: %v", err) os.Exit(1) } }() waitTCPListening(ftpdConf.Bindings[0].GetAddress()) ftpdConf = config.GetFTPDConfig() ftpdConf.Bindings = []ftpd.Binding{ { Port: 2126, CertificateFile: certPath, CertificateKeyFile: keyPath, TLSMode: 1, ClientAuthType: 2, }, } ftpdConf.CACertificates = []string{caCrtPath} ftpdConf.CARevocationLists = []string{caCRLPath} go func() { logger.Debug(logSender, "", "initializing FTP server with config %+v", ftpdConf) if err := 
ftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start FTP server: %v", err) os.Exit(1) } }() waitTCPListening(ftpdConf.Bindings[0].GetAddress()) waitNoConnections() startHTTPFs() exitCode := m.Run() os.Remove(logFilePath) os.Remove(bannerFile) os.Remove(extAuthPath) os.Remove(preLoginPath) os.Remove(postConnectPath) os.Remove(preDownloadPath) os.Remove(preUploadPath) os.Remove(certPath) os.Remove(keyPath) os.Remove(caCrtPath) os.Remove(caCRLPath) os.Remove(hostKeyPath) os.Remove(hostKeyPath + ".pub") os.Exit(exitCode) } func TestInitializationFailure(t *testing.T) { ftpdConf := config.GetFTPDConfig() ftpdConf.Bindings = []ftpd.Binding{} ftpdConf.CertificateFile = filepath.Join(os.TempDir(), "test_ftpd.crt") ftpdConf.CertificateKeyFile = filepath.Join(os.TempDir(), "test_ftpd.key") err := ftpdConf.Initialize(configDir) require.EqualError(t, err, common.ErrNoBinding.Error()) ftpdConf.Bindings = []ftpd.Binding{ { Port: 0, }, { Port: 2121, }, } ftpdConf.BannerFile = "a-missing-file" err = ftpdConf.Initialize(configDir) require.Error(t, err) ftpdConf.BannerFile = "" ftpdConf.Bindings[1].TLSMode = 10 err = ftpdConf.Initialize(configDir) require.Error(t, err) ftpdConf.CertificateFile = "" ftpdConf.CertificateKeyFile = "" ftpdConf.Bindings[1].TLSMode = 1 err = ftpdConf.Initialize(configDir) require.Error(t, err) certPath := filepath.Join(os.TempDir(), "test_ftpd.crt") keyPath := filepath.Join(os.TempDir(), "test_ftpd.key") ftpdConf.CertificateFile = certPath ftpdConf.CertificateKeyFile = keyPath ftpdConf.CACertificates = []string{"invalid ca cert"} err = ftpdConf.Initialize(configDir) require.Error(t, err) ftpdConf.CACertificates = nil ftpdConf.CARevocationLists = []string{""} err = ftpdConf.Initialize(configDir) require.Error(t, err) ftpdConf.CACertificates = []string{caCrtPath} ftpdConf.CARevocationLists = []string{caCRLPath} ftpdConf.Bindings[1].ForcePassiveIP = "127001" err = ftpdConf.Initialize(configDir) require.Error(t, err) 
require.Contains(t, err.Error(), "the provided passive IP \"127001\" is not valid") ftpdConf.Bindings[1].ForcePassiveIP = "" err = ftpdConf.Initialize(configDir) require.Error(t, err) err = dataprovider.Close() assert.NoError(t, err) err = ftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to load config from provider") } err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } func TestBasicFTPHandling(t *testing.T) { u := getTestUser() u.QuotaSize = 6553600 localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser() u.QuotaSize = 6553600 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { if user.Username == defaultUsername { assert.Len(t, common.Connections.GetStats(""), 1) } else { assert.Len(t, common.Connections.GetStats(""), 2) } testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) expectedQuotaSize := testFileSize expectedQuotaFiles := 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client, 0) assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), user.FirstUpload) assert.Equal(t, int64(0), user.FirstDownload) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Greater(t, user.FirstUpload, int64(0)) assert.Equal(t, int64(0), user.FirstDownload) 
// overwrite an existing file err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) assert.Greater(t, user.FirstUpload, int64(0)) assert.Greater(t, user.FirstDownload, int64(0)) err = client.Rename(testFileName, testFileName+"1") assert.NoError(t, err) err = client.Delete(testFileName) assert.Error(t, err) err = client.Delete(testFileName + "1") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize) curDir, err := client.CurrentDir() if assert.NoError(t, err) { assert.Equal(t, "/", curDir) } testDir := "testDir" err = client.MakeDir(testDir) assert.NoError(t, err) err = client.ChangeDir(testDir) assert.NoError(t, err) curDir, err = client.CurrentDir() if assert.NoError(t, err) { assert.Equal(t, path.Join("/", testDir), curDir) } res, err := client.List(path.Join("/", testDir)) assert.NoError(t, err) assert.Len(t, res, 0) res, err = client.List(path.Join("/")) assert.NoError(t, err) if assert.Len(t, res, 1) { assert.Equal(t, testDir, res[0].Name) } err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) _, err = client.FileSize(path.Join("/", testDir)) assert.Error(t, err) size, err := client.FileSize(path.Join("/", testDir, testFileName)) assert.NoError(t, err) assert.Equal(t, testFileSize, size) err = client.ChangeDirToParent() assert.NoError(t, err) curDir, err = client.CurrentDir() if assert.NoError(t, err) { assert.Equal(t, 
"/", curDir) } err = client.Delete(path.Join("/", testDir, testFileName)) assert.NoError(t, err) err = client.Delete(testDir) assert.Error(t, err) err = client.RemoveDir(testDir) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestHTTPFs(t *testing.T) { u := getTestUserWithHTTPFs() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) // test a download resume data := []byte("test data") err = os.WriteFile(testFilePath, data, os.ModePerm) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)-5), client, 5) assert.NoError(t, err) readed, err := os.ReadFile(localDownloadPath) assert.NoError(t, 
err) assert.Equal(t, []byte("data"), readed, "readed data mismatch: %q", string(readed)) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestListDirWithWildcards(t *testing.T) { localUser, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) defer func() { _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) }() for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClient(user, true, nil, ftp.DialWithDisabledMLSD(true)) if assert.NoError(t, err) { dir1 := "test.dir" dir2 := "test.dir1" err = client.MakeDir(dir1) assert.NoError(t, err) err = client.MakeDir(dir2) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) fileName := "file[a-z]e.dat" err = ftpUploadFile(testFilePath, fileName, testFileSize, client, 0) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpDownloadFile(fileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) entries, err := client.List(fileName) 
require.NoError(t, err) require.Len(t, entries, 1) assert.Equal(t, fileName, entries[0].Name) nListEntries, err := client.NameList(fileName) require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, nListEntries, fileName) entries, err = client.List(".") require.NoError(t, err) require.Len(t, entries, 3) nListEntries, err = client.NameList(".") require.NoError(t, err) require.Len(t, nListEntries, 3) entries, err = client.List("/test.*") require.NoError(t, err) require.Len(t, entries, 2) found := 0 for _, e := range entries { switch e.Name { case dir1, dir2: found++ } } assert.Equal(t, 2, found) nListEntries, err = client.NameList("/test.*") require.NoError(t, err) require.Len(t, entries, 2) assert.Contains(t, nListEntries, dir1) assert.Contains(t, nListEntries, dir2) entries, err = client.List("/*.dir?") require.NoError(t, err) assert.Len(t, entries, 1) assert.Equal(t, dir2, entries[0].Name) nListEntries, err = client.NameList("/*.dir?") require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, nListEntries, dir2) entries, err = client.List("/test.???") require.NoError(t, err) require.Len(t, entries, 1) assert.Equal(t, dir1, entries[0].Name) nListEntries, err = client.NameList("/test.???") require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, nListEntries, dir1) _, err = client.NameList("/missingdir/test.*") assert.Error(t, err) _, err = client.List("/missingdir/test.*") assert.Error(t, err) _, err = client.NameList("test[-]") if assert.Error(t, err) { assert.Contains(t, err.Error(), path.ErrBadPattern.Error()) } _, err = client.List("test[-]") if assert.Error(t, err) { assert.Contains(t, err.Error(), path.ErrBadPattern.Error()) } subDir := path.Join(dir1, "sub.d") err = client.MakeDir(subDir) assert.NoError(t, err) err = client.ChangeDir(path.Dir(subDir)) assert.NoError(t, err) entries, err = client.List("sub.?") require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, path.Base(subDir), entries[0].Name) 
nListEntries, err = client.NameList("sub.?") require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, nListEntries, path.Base(subDir)) entries, err = client.List("../*.dir?") require.NoError(t, err) require.Len(t, entries, 1) assert.Equal(t, path.Join("../", dir2), entries[0].Name) nListEntries, err = client.NameList("../*.dir?") require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, nListEntries, path.Join("../", dir2)) err = client.ChangeDir("/") assert.NoError(t, err) entries, err = client.List(path.Join(dir1, "sub.*")) require.NoError(t, err) require.Len(t, entries, 1) assert.Equal(t, path.Join(dir1, "sub.d"), entries[0].Name) nListEntries, err = client.NameList(path.Join(dir1, "sub.*")) require.NoError(t, err) require.Len(t, entries, 1) assert.Contains(t, nListEntries, path.Join(dir1, "sub.d")) err = client.RemoveDir(subDir) assert.NoError(t, err) err = client.RemoveDir(dir1) assert.NoError(t, err) err = client.RemoveDir(dir2) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } } } func TestStartDirectory(t *testing.T) { startDir := "/start/dir" u := getTestUser() u.Filters.StartDirectory = startDir localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser() u.Filters.StartDirectory = startDir sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { currentDir, err := client.CurrentDir() assert.NoError(t, err) assert.Equal(t, startDir, currentDir) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) 
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
			err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
			assert.NoError(t, err)
			// "." is the start directory, which holds only the uploaded file
			entries, err := client.List(".")
			assert.NoError(t, err)
			if assert.Len(t, entries, 1) {
				assert.Equal(t, testFileName, entries[0].Name)
			}
			// the root only contains the first component of startDir
			entries, err = client.List("/")
			assert.NoError(t, err)
			if assert.Len(t, entries, 1) {
				assert.Equal(t, "start", entries[0].Name)
			}
			// the user can navigate above the start directory up to "/"
			err = client.ChangeDirToParent()
			assert.NoError(t, err)
			currentDir, err = client.CurrentDir()
			assert.NoError(t, err)
			assert.Equal(t, path.Dir(startDir), currentDir)
			err = client.ChangeDirToParent()
			assert.NoError(t, err)
			currentDir, err = client.CurrentDir()
			assert.NoError(t, err)
			assert.Equal(t, "/", currentDir)
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
			err = client.Quit()
			assert.NoError(t, err)
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginEmptyPassword checks that FTP login fails for a user added with an
// empty password when the client authenticates with emptyPwdPlaceholder.
func TestLoginEmptyPassword(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Password = emptyPwdPlaceholder
	_, err = getFTPClient(user, true, nil)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestAnonymousUser checks the restrictions applied to an anonymous user.
// The stored user only gets list/download permissions on "/", has SSH and HTTP
// among its denied protocols, and has the listed SSH and TLS-certificate login
// methods denied. Over FTP, uploads and directory creation fail with a
// permission error while downloads work.
func TestAnonymousUser(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	u.Filters.IsAnonymous = true
	// AddUser reports an error, yet the user exists afterwards (it is fetched
	// below); presumably httpdtest's comparison of the returned user fails because
	// the stored anonymous settings differ from the requested ones - TODO confirm
	_, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.Error(t, err)
	user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.IsAnonymous)
	assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"])
	assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols)
	assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate,
		dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods)

	user.Password = emptyPwdPlaceholder
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		// place the file directly in the home directory so the download can be tested
		err = os.Rename(testFilePath, filepath.Join(user.GetHomeDir(), testFileName))
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		err = client.MakeDir("adir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		err = client.Quit()
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestAnonymousGroupInheritance checks that a user inheriting IsAnonymous from
// its primary group is restricted like an anonymous user. This holds even
// though the group grants allPerms on "/" and "/testsub".
func TestAnonymousGroupInheritance(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.IsAnonymous = true
	g.UserSettings.Permissions = make(map[string][]string)
	g.UserSettings.Permissions["/"] = allPerms
	g.UserSettings.Permissions["/testsub"] = allPerms
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
	user.Password = emptyPwdPlaceholder
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// uploads and directory creation are denied everywhere, /testsub included
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		err = client.MakeDir("adir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		err = client.MakeDir("/testsub/adir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		// place the file directly in the home directory so the download can be tested
		err = os.Rename(testFilePath, filepath.Join(user.GetHomeDir(), testFileName))
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	// logging in with defaultPassword is accepted as well
	user.Password = defaultPassword
	client, err = getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestMultiFactorAuth enables TOTP for the FTP protocol and checks three
// things: the password alone is rejected, the password followed by a valid
// passcode is accepted, and the same passcode cannot be reused.
func TestMultiFactorAuth(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolFTP},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	user.Password = defaultPassword
	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error())
	}
	// over FTP the TOTP passcode is sent appended to the password
	passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1)
	assert.NoError(t, err)
	user.Password = defaultPassword + passcode
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	// reusing the same passcode should not work
	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error())
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMustChangePasswordRequirement checks that FTP login fails while
// RequirePasswordChange is set. Login works again once the password has been
// updated through dataprovider.UpdateUserPassword.
func TestMustChangePasswordRequirement(t *testing.T) {
	u := getTestUser()
	u.Filters.RequirePasswordChange = true
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	_, err = getFTPClient(user, true, nil)
	assert.Error(t, err)
	err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "")
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSecondFactorRequirement checks that, with FTP listed in
// TwoFactorAuthProtocols, login fails until TOTP is configured. After that,
// login with password+passcode succeeds.
func TestSecondFactorRequirement(t *testing.T) {
	u := getTestUser()
	u.Filters.TwoFactorAuthProtocols = []string{common.ProtocolFTP}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "second factor authentication is not set")
	}
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolFTP},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	// the passcode is sent appended to the password
	passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1)
	assert.NoError(t, err)
	user.Password = defaultPassword + passcode
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginInvalidCredentials checks that both a wrong username and a wrong
// password are rejected with dataprovider.ErrInvalidCredentials.
func TestLoginInvalidCredentials(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Username = "wrong username"
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error())
	}
	// the original username is restored here, so RemoveUser below targets the real user
	user.Username = u.Username
	user.Password = "wrong pwd"
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error())
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginNonExistentUser checks that logging in as a user that was never
// added to the data provider fails.
func TestLoginNonExistentUser(t *testing.T) {
	user := getTestUser()
	_, err := getFTPClient(user, false, nil)
	assert.Error(t, err)
}

// TestFTPSecurity sets FTPSecurity = 1 on the user. It checks that a login with
// TLS enabled on the client succeeds, while a login without it fails with
// "TLS is required".
func TestFTPSecurity(t *testing.T) {
	u := getTestUser()
	u.Filters.FTPSecurity = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "TLS is required")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestGroupFTPSecurity is like TestFTPSecurity, but FTPSecurity = 1 is
// inherited from the user's primary group.
func TestGroupFTPSecurity(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.FTPSecurity = 1
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "TLS is required")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginExternalAuth re-initializes the data provider with an external auth
// hook and checks three logins:
//   - a successful login;
//   - a login denied because the hook maps the user to a group that disallows FTP;
//   - a login with a username the generated hook does not accept.
//
// Finally it restores the data provider without the hook.
func TestLoginExternalAuth(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	u := getTestUser()
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm)
	assert.NoError(t, err)
	providerConf.ExternalAuthHook = extAuthPath
	providerConf.ExternalAuthScope = 0
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	g := getTestGroup()
	g.UserSettings.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(u, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	// the hook now maps the user to a group that denies the FTP protocol
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm)
	assert.NoError(t, err)
	// keep the returned client: if this login unexpectedly succeeds, the new
	// connection (not the one already closed above) is the one that must be quit
	client, err = getFTPClient(u, true, nil)
	if !assert.Error(t, err) {
		err := client.Quit()
		assert.NoError(t, err)
	} else {
		assert.Contains(t, err.Error(), "protocol FTP is not allowed")
	}
	u.Groups = nil
	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm)
	assert.NoError(t, err)
	// the hook script was generated before the username change, so logging in
	// with a different username is expected to fail
	u.Username = defaultUsername + "1"
	client, err = getFTPClient(u, true, nil)
	if !assert.Error(t, err) {
		err := client.Quit()
		assert.NoError(t, err)
	} else {
		assert.Contains(t, err.Error(), "invalid credentials")
	}
	// the user authenticated through the hook is now stored in the data provider
	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, defaultUsername, user.Username)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
	// restore the data provider without the external auth hook
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	err = os.Remove(extAuthPath)
	assert.NoError(t, err)
}

// TestPreLoginHook re-initializes the data provider with a pre-login hook and
// checks the following logins:
//   - a login that creates the user;
//   - a login with the now-existing user;
//   - a login for which the hook output makes authentication fail;
//   - a login where the hook returns a disabled user.
//
// The last check covers TLS enforcement when the returned user requires it.
func TestPreLoginHook(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	u := getTestUser()
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
	assert.NoError(t, err)
	providerConf.PreLoginHook = preLoginPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	// the user does not exist yet: the first login creates it through the hook
	_, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound)
	assert.NoError(t, err)
	client, err := getFTPClient(u, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	// test login with an existing user
	client, err = getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	// with the second argument set to true the hook output makes the login fail
	// (presumably invalid output - TODO confirm against getPreLoginScriptContent)
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm)
	assert.NoError(t, err)
	client, err = getFTPClient(u, false, nil)
	if !assert.Error(t, err) {
		err := client.Quit()
		assert.NoError(t, err)
	}
	user.Status = 0
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm)
	assert.NoError(t, err)
	client, err = getFTPClient(u, false, nil)
	if !assert.Error(t, err, "pre-login script returned a disabled user, login must fail") {
		err := client.Quit()
		assert.NoError(t, err)
	}
	// NOTE(review): Status is already 0 here, so the returned user is still disabled
	user.Status = 0
	user.Filters.FTPSecurity = 1
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm)
	assert.NoError(t, err)
	client, err = getFTPClient(u, true, nil)
	if !assert.Error(t, err) {
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "TLS is required")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// restore the data provider without the pre-login hook
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	err = os.Remove(preLoginPath)
	assert.NoError(t, err)
}

// TestPreLoginHookReturningAnonymousUser checks that an anonymous user returned
// by the pre-login hook is restricted to listing and downloading. This is
// checked both when the hook creates the user and when the user already exists.
func TestPreLoginHookReturningAnonymousUser(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	u := getTestUser()
	u.Filters.IsAnonymous = true
	u.Filters.DeniedProtocols = []string{common.ProtocolSSH}
	u.Password = ""
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
providerConf := config.GetProviderConf()
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
	assert.NoError(t, err)
	providerConf.PreLoginHook = preLoginPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	// the pre-login hook creates the anonymous user
	client, err := getFTPClient(u, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// directory creation and uploads are denied for anonymous users
		err = client.MakeDir("tdiranonymous")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		// place the file directly in the home directory so the download can be tested
		err = os.Rename(testFilePath, filepath.Join(u.GetHomeDir(), testFileName))
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
		// remove the local copy, as the other tests do
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	// the stored user has the anonymous restrictions applied: list/download only,
	// HTTP added to the denied protocols, SSH and TLS-certificate methods denied
	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.IsAnonymous)
	assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"])
	assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols)
	assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate,
		dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods)
	// now the same with an existing user
	client, err = getFTPClient(u, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission")
		}
		err = os.Rename(testFilePath, filepath.Join(u.GetHomeDir(), testFileName))
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
		// remove the local copy, as the other tests do
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// restore the data provider without the pre-login hook
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	err = os.Remove(preLoginPath)
	assert.NoError(t, err)
}

// TestPreDownloadHook configures a pre-download action hook script. It checks
// that downloads succeed when the script is built with
// getExitCodeScriptContent(0), and fail with "permission denied" when it is
// built with getExitCodeScriptContent(1).
func TestPreDownloadHook(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	// save the global action settings; they are restored at the end of the test
	oldExecuteOn := common.Config.Actions.ExecuteOn
	oldHook := common.Config.Actions.Hook
	common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload}
	common.Config.Actions.Hook = preDownloadPath
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(0), os.ModePerm)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	// now return an error from the pre-download hook
	err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(1), os.ModePerm)
	assert.NoError(t, err)
	client, err = getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		// uploads are not affected by the pre-download hook
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "permission denied")
		}
		err := client.Quit()
		assert.NoError(t, err)
		// NOTE(review): assumes ftpDownloadFile creates the local file even when
		// the download is denied - TODO confirm
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	// restore the global action settings changed at the top of the test
	common.Config.Actions.ExecuteOn = oldExecuteOn
	common.Config.Actions.Hook = oldHook
}

// TestPreUploadHook configures a pre-upload action hook script. With
// getExitCodeScriptContent(0) uploads succeed. With getExitCodeScriptContent(1)
// uploads are rejected with ftpserver.ErrFileNameNotAllowed, both when
// overwriting an existing file and when creating a new one.
func TestPreUploadHook(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	// save the global action settings; they are restored at the end of the test
	oldExecuteOn := common.Config.Actions.ExecuteOn
	oldHook := common.Config.Actions.Hook
	common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload}
	common.Config.Actions.Hook = preUploadPath
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	err = os.WriteFile(preUploadPath, getExitCodeScriptContent(0), os.ModePerm)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	// now return an error from the pre-upload hook
	err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm)
	assert.NoError(t, err)
	client, err = getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		// overwriting the existing file is rejected
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), ftpserver.ErrFileNameNotAllowed.Error())
		}
		// creating a new file is rejected too
		err = ftpUploadFile(testFilePath, testFileName+"1", testFileSize, client, 0)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), ftpserver.ErrFileNameNotAllowed.Error())
		}
		err := client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	// restore the global action settings changed at the top of the test
	common.Config.Actions.ExecuteOn = oldExecuteOn
	common.Config.Actions.Hook = oldHook
}

// TestPostConnectHook checks both a script and an HTTP URL configured as
// common.Config.PostConnectHook. A succeeding hook lets the FTP connection
// through; a failing hook (script built with getExitCodeScriptContent(1), or
// an HTTP endpoint that is not found) makes the connection fail.
func TestPostConnectHook(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	common.Config.PostConnectHook = postConnectPath
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err := client.Quit()
		assert.NoError(t, err)
	}
	// a failing post-connect script must make the next connection fail
	err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm)
	assert.NoError(t, err)
client, err = getFTPClient(user, true, nil) if !assert.Error(t, err) { err := client.Quit() assert.NoError(t, err) } common.Config.PostConnectHook = "http://127.0.0.1:8079/healthz" client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err := client.Quit() assert.NoError(t, err) } common.Config.PostConnectHook = "http://127.0.0.1:8079/notfound" client, err = getFTPClient(user, true, nil) if !assert.Error(t, err) { err := client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.PostConnectHook = "" } //nolint:dupl func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) user := getTestUser() err := dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) user.Password = "" client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) _, err = getFTPClient(user, false, nil) assert.Error(t, err) err = client.Quit() assert.NoError(t, err) } err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.MaxTotalConnections = oldValue } //nolint:dupl func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) user := getTestUser() err := dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) user.Password = "" client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { err = 
checkBasicFTP(client)
		assert.NoError(t, err)
		// second concurrent client: over the per-host limit of 1
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	common.Config.MaxPerHostConnections = oldValue
}

// TestMaxTransfers sets MaxPerHostConnections to 2 and keeps two SFTP file
// handles open with pending writes. It then verifies that an FTP upload and an
// FTP download both fail. The previous limit is restored and the local test
// file is removed at the end.
func TestMaxTransfers(t *testing.T) {
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 2

	// wait for the client connection count to reach zero before starting
	assert.Eventually(t, func() bool {
		return common.Connections.GetClientConnections() == 0
	}, 1000*time.Millisecond, 50*time.Millisecond)

	user := getTestUser()
	err := dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	user.Password = ""
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)

	conn, sftpClient, err := getSftpClient(user)
	// Only touch the SFTP clients when the connection was established. assert
	// does not stop the test, and on error they are presumably nil, so the
	// unguarded defers/Create calls would panic.
	if assert.NoError(t, err) {
		// deferred: the SFTP session stays open until the test returns
		defer conn.Close()
		defer sftpClient.Close()

		// two open uploads occupy the available transfer slots
		f1, err := sftpClient.Create("file1")
		assert.NoError(t, err)
		f2, err := sftpClient.Create("file2")
		assert.NoError(t, err)
		_, err = f1.Write([]byte(" "))
		assert.NoError(t, err)
		_, err = f2.Write([]byte(" "))
		assert.NoError(t, err)

		client, err := getFTPClient(user, true, nil)
		if assert.NoError(t, err) {
			err = checkBasicFTP(client)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
			assert.Error(t, err)
			localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
			err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
			assert.Error(t, err)
			err := client.Quit()
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
		}

		err = f1.Close()
		assert.NoError(t, err)
		err = f2.Close()
		assert.NoError(t, err)
	}

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the local source file was previously left behind; remove it like the other tests do
	err = os.Remove(testFilePath)
	assert.NoError(t, err)

	common.Config.MaxPerHostConnections = oldValue
}

// TestRateLimiter enables the defender (threshold 5, ScoreLimitExceeded 3) and
// an FTP rate limiter (Average 1, Burst 1) configured to generate defender
// events. The first login succeeds and the next two fail with "rate limit
// exceed". The following attempt fails with "banned client IP". The original
// common configuration is restored at the end.
func TestRateLimiter(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 5
	cfg.DefenderConfig.ScoreLimitExceeded = 3
	cfg.RateLimitersConfig = []common.RateLimiterConfig{
		{
			Average:                1,
			Period:                 1000,
			Burst:                  1,
			Type:                   2,
			Protocols:              []string{common.ProtocolFTP},
			GenerateDefenderEvents: true,
			EntriesSoftLimit:       100,
			EntriesHardLimit:       150,
		},
	}

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	client, err := getFTPClient(user, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}

	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "rate limit exceed")
	}

	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "rate limit exceed")
	}

	// presumably two limit violations (3 points each) exceed the threshold of 5
	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "banned client IP")
	}

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestDefender enables the defender (threshold 4) and checks how the host
// score grows with an unauthenticated connection and with failed logins. Once
// the client IP is banned, even valid credentials are refused. The original
// common configuration is restored at the end.
func TestDefender(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 4
	cfg.DefenderConfig.ScoreLimitExceeded = 2
	cfg.DefenderConfig.ScoreNoAuth = 1
	cfg.DefenderConfig.ScoreValid = 1

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	// just dial without login
	ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)}
client, err = ftp.Dial(ftpServerAddr, ftpOptions...)
	assert.NoError(t, err)
	err = client.Quit()
	assert.NoError(t, err)

	// the connection without login is expected to have scored 1 (ScoreNoAuth = 1)
	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		host := hosts[0]
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 1, host.Score)
	}
	// a failed login adds one more point (presumably ScoreValid, the score for
	// a wrong password on an existing user - confirm)
	user.Password = "wrong_pwd"
	_, err = getFTPClient(user, false, nil)
	assert.Error(t, err)
	hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		host := hosts[0]
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 2, host.Score)
	}

	// two more failed logins: presumably the score reaches the threshold (4)
	for i := 0; i < 2; i++ {
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
	}

	// the host is now banned, so even the correct password is refused
	user.Password = defaultPassword
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "banned client IP")
	}

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestMaxSessions checks that, with MaxSessions = 1, a second concurrent FTP
// login for the same user fails while the first session is open.
func TestMaxSessions(t *testing.T) {
	u := getTestUser()
	u.MaxSessions = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestZeroBytesTransfers uploads an empty file and downloads it back, once
// with TLS and once without. The remote size must be 0, and the download must
// recreate the (empty) local file.
func TestZeroBytesTransfers(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, useTLS := range []bool{true, false} {
		client, err := getFTPClient(user, useTLS, nil)
		if assert.NoError(t, err) {
			testFileName := "testfilename"
			err = checkBasicFTP(client)
			assert.NoError(t, err)
			localDownloadPath := filepath.Join(homeBasePath, "empty_download")
			err = os.WriteFile(localDownloadPath, []byte(""), os.ModePerm)
			assert.NoError(t, err)
			err = ftpUploadFile(localDownloadPath, testFileName, 0, client, 0)
			assert.NoError(t, err)
			size, err := client.FileSize(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, int64(0), size)
			// remove the local copy so the download below must create it again
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
			assert.NoFileExists(t, localDownloadPath)
			err = ftpDownloadFile(testFileName, localDownloadPath, 0, client, 0)
			assert.NoError(t, err)
			assert.FileExists(t, localDownloadPath)
			err = client.Quit()
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestDownloadErrors checks that the following downloads fail:
//   - a file in /sub1, which only has list permission
//   - files matching the denied patterns (*.jpg, *.zip) in /sub2
//   - a missing file
func TestDownloadErrors(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1
	subDir1 := "sub1"
	subDir2 := "sub2"
	u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems}
	u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermDownload}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/sub2",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{"*.jpg", "*.zip"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		// create the remote files directly inside the user's home directory
		testFilePath1 := filepath.Join(user.HomeDir, subDir1, "file.zip")
		testFilePath2 := filepath.Join(user.HomeDir, subDir2, "file.zip")
		testFilePath3 := filepath.Join(user.HomeDir, subDir2, "file.jpg")
		err = os.MkdirAll(filepath.Dir(testFilePath1), os.ModePerm)
		assert.NoError(t, err)
		err = os.MkdirAll(filepath.Dir(testFilePath2), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFilePath1, []byte("file1"), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFilePath2, []byte("file2"), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFilePath3, []byte("file3"), os.ModePerm)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(path.Join("/", subDir1, "file.zip"), localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		err = ftpDownloadFile(path.Join("/", subDir2, "file.zip"), localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		err = ftpDownloadFile(path.Join("/", subDir2, "file.jpg"), localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		err = ftpDownloadFile("/missing.zip", localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUploadErrors checks the following upload failures:
//   - missing upload permission (/sub1)
//   - a denied pattern (*.zip in /sub2)
//   - a target name that is an existing directory
//   - the size quota, which equals the test file size
func TestUploadErrors(t *testing.T) {
	u := getTestUser()
	u.QuotaSize = 65535
	subDir1 := "sub1"
	subDir2 := "sub2"
	u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems}
	u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/sub2",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{"*.zip"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		testFilePath := filepath.Join(homeBasePath, testFileName)
		// the test file is exactly as big as the whole size quota
		testFileSize := user.QuotaSize
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = client.MakeDir(subDir1)
		assert.NoError(t, err)
		err = client.MakeDir(subDir2)
		assert.NoError(t, err)
		// /sub1 only allows listing
		err = client.ChangeDir(subDir1)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.Error(t, err)
		err = client.ChangeDirToParent()
		assert.NoError(t, err)
		err = client.ChangeDir(subDir2)
		assert.NoError(t, err)
		// *.zip is a denied pattern in /sub2
		err = ftpUploadFile(testFilePath, testFileName+".zip", testFileSize, client, 0)
		assert.Error(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		// presumably rejected because the quota is now fully used
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.Error(t, err)
		err = client.ChangeDir("/")
		assert.NoError(t, err)
		// the target name is an existing directory
		err = ftpUploadFile(testFilePath, subDir1, testFileSize, client, 0)
		assert.Error(t, err)
		// overquota
		err = ftpUploadFile(testFilePath, testFileName+"1", testFileSize, client, 0)
		assert.Error(t, err)
		// free the quota, then fill it again from the root directory
		err = client.Delete(path.Join("/", subDir2, testFileName))
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSFTPBuffered uses an SFTP-backed user with a non-zero SFTP BufferSize.
// Uploads, overwrites, downloads (including downloads from an offset) and
// quota tracking work. Resumed uploads and appends fail with
// "operation unsupported".
func TestSFTPBuffered(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 100
	u.FsConfig.SFTPConfig.BufferSize = 2
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(sftpUser, true, nil)
	if assert.NoError(t, err) {
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		expectedQuotaSize := testFileSize
		expectedQuotaFiles := 1
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)

		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		// overwrite an existing file
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize,
client, 0)
		assert.NoError(t, err)
		// the overwrite must not be counted twice in the quota
		user, _, err := httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)

		data := []byte("test data")
		err = os.WriteFile(testFilePath, data, os.ModePerm)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0)
		assert.NoError(t, err)
		// resuming an upload at an offset is not supported with a buffered SFTP backend
		err = ftpUploadFile(testFilePath, testFileName, int64(len(data)+5), client, 5)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "operation unsupported")
		}
		// downloading from an offset still works: "test data"[5:] == "data"
		err = ftpDownloadFile(testFileName, localDownloadPath, int64(4), client, 5)
		assert.NoError(t, err)
		readed, err := os.ReadFile(localDownloadPath)
		assert.NoError(t, err)
		assert.Equal(t, []byte("data"), readed)
		// appending to an existing file is expected to fail with "operation unsupported"
		// and must leave the remote file unchanged
		srcFile, err := os.Open(testFilePath)
		if assert.NoError(t, err) {
			err = client.Append(testFileName, srcFile)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "operation unsupported")
			}
			err = srcFile.Close()
			assert.NoError(t, err)
			size, err := client.FileSize(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, int64(len(data)), size)
			err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)), client, 0)
			assert.NoError(t, err)
		}
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(sftpUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestResume checks resumed uploads, downloads from an offset, and appends
// (both to an existing file and to a new one). It runs for three users:
//   - a local user
//   - an SFTP-backed user
//   - a local user with non-zero OS read/write buffer sizes
//
// The default local user is deleted and recreated after its iteration.
func TestResume(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated)
	assert.NoError(t, err)
	u = getTestUser()
	u.FsConfig.OSConfig = sdk.OSFsConfig{
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	u.Username += "_buf"
	u.HomeDir += "_buf"
	bufferedUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser, bufferedUser} {
		client, err := getFTPClient(user, true, nil)
		if assert.NoError(t, err) {
			testFilePath := filepath.Join(homeBasePath, testFileName)
			data := []byte("test data")
			err = os.WriteFile(testFilePath, data, os.ModePerm)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0)
			assert.NoError(t, err)
			// resume at offset 5: the first 5 bytes ("test ") are kept and data follows them
			err = ftpUploadFile(testFilePath, testFileName, int64(len(data)+5), client, 5)
			assert.NoError(t, err)
			readed, err := os.ReadFile(filepath.Join(user.GetHomeDir(), testFileName))
			assert.NoError(t, err)
			assert.Equal(t, "test test data", string(readed))
			localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
			// downloading from offset 5 returns the resumed part, i.e. data again
			err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)), client, 5)
			assert.NoError(t, err)
			readed, err = os.ReadFile(localDownloadPath)
			assert.NoError(t, err)
			assert.Equal(t, data, readed)
			err = client.Delete(testFileName)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0)
			assert.NoError(t, err)
			// now append to a file
			srcFile, err := os.Open(testFilePath)
			if assert.NoError(t, err) {
				err = client.Append(testFileName, srcFile)
				assert.NoError(t, err)
				err = srcFile.Close()
				assert.NoError(t, err)
				size, err := client.FileSize(testFileName)
				assert.NoError(t, err)
				assert.Equal(t, int64(2*len(data)), size)
				err = ftpDownloadFile(testFileName, localDownloadPath, int64(2*len(data)), client, 0)
				assert.NoError(t, err)
				readed, err = os.ReadFile(localDownloadPath)
				assert.NoError(t, err)
				// Build the expected content in a fresh slice. append(data, data...)
				// could write into data's spare capacity and alias its backing array.
				expected := append(append(make([]byte, 0, 2*len(data)), data...), data...)
				assert.Equal(t, expected, readed)
			}
			// append to a new file
			srcFile, err = os.Open(testFilePath)
			if assert.NoError(t, err) {
				newFileName := testFileName + "_new"
				err = client.Append(newFileName, srcFile)
				assert.NoError(t, err)
				err = srcFile.Close()
				assert.NoError(t, err)
				size, err := client.FileSize(newFileName)
				assert.NoError(t, err)
				assert.Equal(t, int64(len(data)), size)
				err = ftpDownloadFile(newFileName, localDownloadPath, int64(len(data)), client, 0)
				assert.NoError(t, err)
				readed, err = os.ReadFile(localDownloadPath)
				assert.NoError(t, err)
				assert.Equal(t, data, readed)
			}
			err = client.Quit()
			assert.NoError(t, err)
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
			// recreate the default local user with an empty home directory
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(bufferedUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(bufferedUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestDeniedLoginMethod checks that denying password login blocks FTP logins.
// Denying only the SSH public-key based methods leaves FTP login allowed.
//
//nolint:dupl
func TestDeniedLoginMethod(t *testing.T) {
	u := getTestUser()
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	_, err = getFTPClient(user, false, nil)
	assert.Error(t, err)
	user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodKeyAndPassword}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if
assert.NoError(t, err) {
		assert.NoError(t, checkBasicFTP(client))
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestDeniedProtocols checks that denying the FTP protocol blocks FTP logins.
// Denying only SSH and WebDAV leaves FTP login allowed.
//
//nolint:dupl
func TestDeniedProtocols(t *testing.T) {
	u := getTestUser()
	u.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	_, err = getFTPClient(user, false, nil)
	assert.Error(t, err)
	user.Filters.DeniedProtocols = []string{common.ProtocolSSH, common.ProtocolWebDAV}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		assert.NoError(t, checkBasicFTP(client))
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestQuotaLimits checks the file-count quota, the size quota, and size
// enforcement while uploading or overwriting. It runs for a local user and an
// SFTP-backed user.
func TestQuotaLimits(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 1
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		testFileSize := int64(65535)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		testFileSize1 := int64(131072)
		testFileName1 := "test_file1.dat"
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		testFileSize2 := int64(32768)
		testFileName2 := "test_file2.dat"
		testFilePath2 := filepath.Join(homeBasePath, testFileName2)
		err = createTestFile(testFilePath2, testFileSize2)
		assert.NoError(t, err)
		// test quota files
		client, err := getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			err = ftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client, 0) //nolint:goconst
			assert.NoError(t, err)
			// QuotaFiles = 1: a second file is refused
			err = ftpUploadFile(testFilePath, testFileName+".quota1", testFileSize, client, 0)
			assert.Error(t, err)
			err = client.Rename(testFileName+".quota", testFileName)
			assert.NoError(t, err)
			err = client.Quit()
			assert.NoError(t, err)
		}
		// test quota size
		user.QuotaSize = testFileSize - 1
		user.QuotaFiles = 0
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		client, err = getFTPClient(user, true, nil)
		if assert.NoError(t, err) {
			err = ftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client, 0)
			assert.Error(t, err)
			err = client.Rename(testFileName, testFileName+".quota")
			assert.NoError(t, err)
			err = client.Quit()
			assert.NoError(t, err)
		}
		// now test quota limits while uploading the current file: 1 byte of quota remains
		user.QuotaSize = testFileSize + 1
		user.QuotaFiles = 0
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		client, err = getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			err = ftpUploadFile(testFilePath1, testFileName1, testFileSize1, client, 0)
			assert.Error(t, err)
			// the partial upload exceeding the quota must not be kept
			_, err = client.FileSize(testFileName1)
			assert.Error(t, err)
			err = client.Rename(testFileName+".quota", testFileName)
			assert.NoError(t, err)
			// overwriting an existing file works if the resulting size is less than or equal to the current one
			err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath2, testFileName, testFileSize2, client, 0)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath1, testFileName, testFileSize1, client, 0)
			assert.Error(t, err)
			err = ftpUploadFile(testFilePath1, testFileName, testFileSize1, client, 10)
			assert.Error(t, err)
			err = ftpUploadFile(testFilePath2, testFileName, testFileSize2, client, 0)
			assert.NoError(t, err)
			err = client.Quit()
			assert.NoError(t, err)
		}

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
		err = os.Remove(testFilePath2)
		assert.NoError(t, err)
		// recreate the default local user without quota limits and with an empty home directory
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			user.QuotaFiles = 0
			user.QuotaSize = 0
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestUploadMaxSize sets MaxUploadFileSize to testFileSize+1 and checks three
// cases:
//   - a bigger upload fails
//   - a testFileSize upload succeeds
//   - overwriting an existing bigger file with a too-large upload fails
//
// It runs for a local user and an SFTP-backed user.
func TestUploadMaxSize(t *testing.T) {
	testFileSize := int64(65535)
	u := getTestUser()
	u.Filters.MaxUploadFileSize = testFileSize + 1
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.Filters.MaxUploadFileSize = testFileSize + 1
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		testFileSize1 := int64(131072)
		testFileName1 := "test_file1.dat"
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		client, err := getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			err = ftpUploadFile(testFilePath1, testFileName1, testFileSize1, client, 0)
			assert.Error(t, err)
			err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
			assert.NoError(t, err)
			// now test overwrite an existing file with a size bigger than the allowed one
			err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath1, testFileName1, testFileSize1, client, 0)
			assert.Error(t, err)
			err = client.Quit()
			assert.NoError(t, err)
		}
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.Filters.MaxUploadFileSize = 65536000
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginWithIPilters checks that login fails for a user whose IP filters
// allow only 172.19.0.0/16. NOTE(review): this assumes the test client does
// not connect from that range (presumably it uses loopback).
func TestLoginWithIPilters(t *testing.T) {
	u := getTestUser()
	u.Filters.DeniedIP = []string{"192.167.0.0/24", "172.18.0.0/16"}
	u.Filters.AllowedIP = []string{"172.19.0.0/16"}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if !assert.Error(t, err) {
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginWithDatabaseCredentials adds a GCS user whose plain credentials are
// stored as a secretbox secret. It checks the stored secret's status and
// fields, and then checks that an FTP login succeeds.
func TestLoginWithDatabaseCredentials(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "test"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account", "private_key": " ", "client_email": "example@iam.gserviceaccount.com" }`)

	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.NotEmpty(t, user.FsConfig.GCSConfig.Credentials.GetPayload())
	assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData())
assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey())

	// login is expected to succeed even though the GCS credentials are not real
	client, err := getFTPClient(user, false, nil)
	if assert.NoError(t, err) {
		err = client.Quit()
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginInvalidFs checks that FTP login fails for a GCS user whose
// credentials are not valid JSON.
func TestLoginInvalidFs(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "test"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client, err := getFTPClient(user, false, nil)
	if !assert.Error(t, err) {
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestClientClose checks that closing the only active connection through
// common.Connections.Close removes it from the connection stats within one
// second.
func TestClientClose(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		stats := common.Connections.GetStats("")
		if assert.Len(t, stats, 1) {
			common.Connections.Close(stats[0].ConnectionID, "")
			assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
				1*time.Second, 50*time.Millisecond)
		}
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRename exercises FTP renames for a local user and an SFTP-backed user:
//   - renaming into a missing directory fails
//   - renaming into an existing directory works
//   - (not on Windows) renaming into a directory chmod-ed to 0001 fails
//   - renaming out of a directory with list-only permission fails
func TestRename(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		testDir := "adir"
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		client, err := getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			err = checkBasicFTP(client)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
			assert.NoError(t, err)
			err = client.MakeDir(testDir)
			assert.NoError(t, err)
			err = client.Rename(testFileName, path.Join("missing", testFileName))
			assert.Error(t, err)
			err = client.Rename(testFileName, path.Join(testDir, testFileName))
			assert.NoError(t, err)
			size, err := client.FileSize(path.Join(testDir, testFileName))
			assert.NoError(t, err)
			assert.Equal(t, testFileSize, size)
			if runtime.GOOS != osWindows {
				otherDir := "dir"
				err = client.MakeDir(otherDir)
				assert.NoError(t, err)
				err = client.MakeDir(path.Join(otherDir, testDir))
				assert.NoError(t, err)
				code, response, err := client.SendCommand("SITE CHMOD 0001 %v", otherDir)
				assert.NoError(t, err)
				assert.Equal(t, ftp.StatusCommandOK, code)
				assert.Equal(t, "SITE CHMOD command successful", response)
				err = client.Rename(testDir, path.Join(otherDir, testDir))
				assert.Error(t, err)

				// restore the permissions so the directory can be cleaned up
				code, response, err = client.SendCommand("SITE CHMOD 755 %v", otherDir)
				assert.NoError(t, err)
				assert.Equal(t, ftp.StatusCommandOK, code)
				assert.Equal(t, "SITE CHMOD command successful", response)
			}
			err = client.Quit()
			assert.NoError(t, err)
		}
		// NOTE(review): Permissions is a map, so this also mutates the map shared
		// with the user value taken from the range slice
		user.Permissions[path.Join("/", testDir)] = []string{dataprovider.PermListItems}
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		client, err = getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			err = client.Rename(path.Join(testDir, testFileName), testFileName)
			assert.Error(t, err)
			err := client.Quit()
			assert.NoError(t, err)
		}
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		// recreate the default local user with full permissions on "/" and an empty home directory
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Permissions = make(map[string][]string)
			user.Permissions["/"] = allPerms
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestSymlink checks "SITE SYMLINK" for a local user and an SFTP-backed user.
// Creating a link next to an uploaded file succeeds. On non-Windows systems,
// creating a link inside a directory chmod-ed to 0001 is answered with
// StatusFileUnavailable.
func TestSymlink(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		client, err := getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			err = checkBasicFTP(client)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
			assert.NoError(t, err)
			code, _, err := client.SendCommand("SITE SYMLINK %v %v", testFileName, testFileName+".link")
			assert.NoError(t, err)
			assert.Equal(t, ftp.StatusCommandOK, code)

			if runtime.GOOS != osWindows {
				testDir := "adir"
				otherDir := "dir"
				err = client.MakeDir(otherDir)
				assert.NoError(t, err)
				err = client.MakeDir(path.Join(otherDir, testDir))
				assert.NoError(t, err)
				code, response, err := client.SendCommand("SITE CHMOD 0001 %v", otherDir)
				assert.NoError(t, err)
				assert.Equal(t, ftp.StatusCommandOK, code)
				assert.Equal(t, "SITE CHMOD command successful", response)
				code, _, err = client.SendCommand("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir))
				assert.NoError(t, err)
				assert.Equal(t, ftp.StatusFileUnavailable, code)

				// restore the permissions so the directory can be cleaned up
				code, response, err = client.SendCommand("SITE CHMOD 755 %v", otherDir)
				assert.NoError(t, err)
				assert.Equal(t, ftp.StatusCommandOK, code)
				assert.Equal(t, "SITE CHMOD command successful", response)
			}
			err = client.Quit()
			assert.NoError(t, err)
			// recreate the default local user with an empty home directory
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestStat checks FileSize over FTP for a local user and an SFTP-backed user.
// It works for an uploaded file. It fails for a file under /subdir (the local
// user only has upload permission there) and for a missing file.
func TestStat(t *testing.T) {
	u := getTestUser()
	u.Permissions["/subdir"] = []string{dataprovider.PermUpload}
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated)
	assert.NoError(t, err)

	for _, user := range []dataprovider.User{localUser, sftpUser} {
		client, err := getFTPClient(user, false, nil)
		if assert.NoError(t, err) {
			subDir := "subdir"
			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = client.MakeDir(subDir)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
			assert.NoError(t, err)
			err = ftpUploadFile(testFilePath, path.Join("/", subDir, testFileName), testFileSize, client, 0)
			assert.Error(t, err)
			size, err := client.FileSize(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, testFileSize, size)
			_, err = client.FileSize(path.Join("/", subDir, testFileName))
			assert.Error(t, err)
			_, err = client.FileSize("missing file")
			assert.Error(t, err)
			err = client.Quit()
			assert.NoError(t, err)

			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user,
http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestUploadOverwriteVfolder(t *testing.T) { u := getTestUser() u.QuotaFiles = 1000 vdir := "/vdir" mappedPath := filepath.Join(os.TempDir(), "vdir") folderName := filepath.Base(mappedPath) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdir, QuotaSize: -1, QuotaFiles: -1, }) err = os.MkdirAll(mappedPath, os.ModePerm) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client, 0) assert.NoError(t, err) folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), folder.UsedQuotaSize) assert.Equal(t, 0, folder.UsedQuotaFiles) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize, user.UsedQuotaSize) assert.Equal(t, 1, user.UsedQuotaFiles) err = ftpUploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client, 0) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) 
assert.NoError(t, err) assert.Equal(t, int64(0), folder.UsedQuotaSize) assert.Equal(t, 0, folder.UsedQuotaFiles) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize, user.UsedQuotaSize) assert.Equal(t, 1, user.UsedQuotaFiles) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) } func TestTransferQuotaLimits(t *testing.T) { u := getTestUser() u.DownloadDataTransfer = 1 u.UploadDataTransfer = 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(524288) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), ftpserver.ErrStorageExceeded.Error()) } err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) } err = client.Quit() assert.NoError(t, err) } testFileSize = 
int64(600000) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) user.DownloadDataTransfer = 2 user.UploadDataTransfer = 2 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.Error(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.Error(t, err) err = client.Quit() assert.NoError(t, err) } err = os.Remove(localDownloadPath) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestAllocateAvailable(t *testing.T) { u := getTestUser() mappedPath := filepath.Join(os.TempDir(), "vdir") folderName := filepath.Base(mappedPath) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: "/vdir", QuotaSize: 110, }) err = os.MkdirAll(mappedPath, os.ModePerm) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("allo 2000000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) code, response, err = client.SendCommand("AVBL /vdir") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "110", response) code, _, err 
= client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) err = client.Quit() assert.NoError(t, err) } user.QuotaSize = 100 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := user.QuotaSize - 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) code, response, err := client.SendCommand("allo 1000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) code, response, err = client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } user.TotalDataTransfer = 1 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) err = client.Quit() assert.NoError(t, err) } user.TotalDataTransfer = 0 user.UploadDataTransfer = 5 user.QuotaSize = 6 * 1024 * 1024 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "5242880", response) err = client.Quit() assert.NoError(t, err) } user.TotalDataTransfer = 0 user.UploadDataTransfer = 5 user.QuotaSize = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if 
assert.NoError(t, err) { code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "5242880", response) err = client.Quit() assert.NoError(t, err) } user.Filters.MaxUploadFileSize = 100 user.QuotaSize = 0 user.TotalDataTransfer = 0 user.UploadDataTransfer = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("allo 10000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) code, response, err = client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "100", response) err = client.Quit() assert.NoError(t, err) } user.QuotaSize = 50 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "0", response) } user.QuotaSize = 1000 user.Filters.MaxUploadFileSize = 1 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("AVBL") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) } func TestAvailableSFTPFs(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := 
httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(sftpUser, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("AVBL /") assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) avblSize, err := strconv.ParseInt(response, 10, 64) assert.NoError(t, err) assert.Greater(t, avblSize, int64(0)) err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestChtimes(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) mtime := time.Now().Format("20060102150405") code, response, err := client.SendCommand("MFMT %v %v", mtime, testFileName) assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, fmt.Sprintf("Modify=%v; %v", mtime, testFileName), response) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) 
assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestMODEType(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("MODE s") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotImplementedParameter, code) assert.Equal(t, "Unsupported mode", response) code, response, err = client.SendCommand("MODE S") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Using stream mode", response) code, _, err = client.SendCommand("MODE Z") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotImplementedParameter, code) code, _, err = client.SendCommand("MODE SS") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotImplementedParameter, code) err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSTAT(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(131072) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) testDir := "testdir" err = client.MakeDir(testDir) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join(testDir, testFileName+"_1"), testFileSize, client, 0) assert.NoError(t, err) code, response, err := client.SendCommand("STAT %s", testDir) assert.NoError(t, err) 
assert.Equal(t, ftp.StatusDirectory, code) assert.Contains(t, response, fmt.Sprintf("STAT %s", testDir)) assert.Contains(t, response, testFileName) assert.Contains(t, response, testFileName+"_1") assert.Contains(t, response, "End") err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestChown(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("chown is not supported on Windows") } user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(131072) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) code, response, err := client.SendCommand("SITE CHOWN 1000:1000 %v", testFileName) assert.NoError(t, err) assert.Equal(t, ftp.StatusFileUnavailable, code) assert.Equal(t, "Couldn't chown: operation unsupported", response) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestChmod(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("chmod is partially supported on Windows") } u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(131072) err = 
createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) code, response, err := client.SendCommand("SITE CHMOD 600 %v", testFileName) assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "SITE CHMOD command successful", response) fi, err := os.Stat(filepath.Join(user.HomeDir, testFileName)) if assert.NoError(t, err) { assert.Equal(t, os.FileMode(0600), fi.Mode().Perm()) } err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestCombineDisabled(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) code, response, err := client.SendCommand("COMB file file.1 file.2") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotImplemented, code) assert.Equal(t, "COMB support is disabled", response) err = client.Quit() assert.NoError(t, err) } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = 
httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestActiveModeDisabled(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClientImplicitTLS(user) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) code, response, err := client.SendCommand("PORT 10,2,0,2,4,31") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotAvailable, code) assert.Equal(t, "PORT command is disabled", response) code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|") assert.NoError(t, err) assert.Equal(t, ftp.StatusNotAvailable, code) assert.Equal(t, "EPRT command is disabled", response) err = client.Quit() assert.NoError(t, err) } client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { code, response, err := client.SendCommand("PORT 10,2,0,2,4,31") assert.NoError(t, err) assert.Equal(t, ftp.StatusBadArguments, code) assert.Equal(t, "Your request does not meet the configured security requirements", response) code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|") assert.NoError(t, err) assert.Equal(t, ftp.StatusBadArguments, code) assert.Equal(t, "Your request does not meet the configured security requirements", response) err = client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSITEDisabled(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClientImplicitTLS(user) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) code, response, err := client.SendCommand("SITE CHMOD 600 afile.txt") assert.NoError(t, err) assert.Equal(t, ftp.StatusBadCommand, code) assert.Equal(t, "SITE support is disabled", response) err 
= client.Quit() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestHASH(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) u = getTestUserWithCryptFs() u.Username += "_crypt" cryptUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser, cryptUser} { client, err := getFTPClientImplicitTLS(user) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(131072) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) h := sha256.New() f, err := os.Open(testFilePath) assert.NoError(t, err) _, err = io.Copy(h, f) assert.NoError(t, err) hash := hex.EncodeToString(h.Sum(nil)) err = f.Close() assert.NoError(t, err) code, response, err := client.SendCommand("XSHA256 %v", testFileName) assert.NoError(t, err) assert.Equal(t, ftp.StatusRequestedFileActionOK, code) assert.Contains(t, response, hash) code, response, err = client.SendCommand("HASH %v", testFileName) assert.NoError(t, err) assert.Equal(t, ftp.StatusFile, code) assert.Contains(t, response, hash) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = 
httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(cryptUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(cryptUser.GetHomeDir()) assert.NoError(t, err) } func TestCombine(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { client, err := getFTPClientImplicitTLS(user) if assert.NoError(t, err) { testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(131072) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = checkBasicFTP(client) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName+".1", testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, testFileName+".2", testFileSize, client, 0) assert.NoError(t, err) code, response, err := client.SendCommand("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2") assert.NoError(t, err) if user.Username == defaultUsername { assert.Equal(t, ftp.StatusRequestedFileActionOK, code) assert.Equal(t, "COMB succeeded!", response) } else { assert.Equal(t, ftp.StatusFileUnavailable, code) assert.Contains(t, response, "COMB is not supported for this filesystem") } err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = 
httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestClientCertificateAuthRevokedCert(t *testing.T) { u := getTestUser() u.Username = tlsClient2Username u.Filters.TLSUsername = sdk.TLSUsernameCN user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, ClientSessionCache: tls.NewLRUClientSessionCache(0), } tlsCert, err := tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) _, err = getFTPClientWithSessionReuse(user, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "bad certificate") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestClientCertificateAuth(t *testing.T) { u := getTestUser() u.Username = tlsClient1Username u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) // TLS username is not enabled, mutual TLS should fail _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "login method password is not allowed") } user.Filters.TLSUsername = sdk.TLSUsernameCN user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err := getFTPClient(user, true, tlsConfig) if 
assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } // now use a valid certificate with a CN different from username u = getTestUser() u.Username = tlsClient2Username u.Filters.TLSUsername = sdk.TLSUsernameCN u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} user2, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) _, err = getFTPClient(user2, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "does not match username") } // add the certs to the user user2.Filters.TLSUsername = sdk.TLSUsernameNone user2.Filters.TLSCerts = []string{client2Crt, client1Crt} user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user2, true, tlsConfig) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } user2.Filters.TLSCerts = []string{client2Crt} user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") assert.NoError(t, err) _, err = getFTPClient(user2, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "TLS certificate is not valid") } // now disable certificate authentication user.Filters.DeniedLoginMethods = append(user.Filters.DeniedLoginMethods, dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodTLSCertificateAndPwd) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "login method TLSCertificate+password is not allowed") } // disable FTP protocol user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} user.Filters.DeniedProtocols = append(user.Filters.DeniedProtocols, common.ProtocolFTP) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) _, err = getFTPClient(user, true, tlsConfig) if 
assert.Error(t, err) { assert.Contains(t, err.Error(), "protocol FTP is not allowed") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user2, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user2.GetHomeDir()) assert.NoError(t, err) _, err = getFTPClient(user, true, tlsConfig) assert.Error(t, err) } func TestClientCertificateAndPwdAuth(t *testing.T) { u := getTestUser() u.Username = tlsClient1Username u.Filters.TLSUsername = sdk.TLSUsernameCN u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) client, err := getFTPClient(user, true, tlsConfig) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err = client.Quit() assert.NoError(t, err) } _, err = getFTPClient(user, true, nil) if assert.Error(t, err) { assert.Contains(t, err.Error(), "login method password is not allowed") } user.Password = defaultPassword + "1" _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid credentials") } tlsCert, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) assert.NoError(t, err) tlsConfig.Certificates = []tls.Certificate{tlsCert} _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "bad certificate") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestExternalAuthWithClientCert(t 
*testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() u.Username = tlsClient1Username u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, dataprovider.LoginMethodPassword) u.Filters.TLSUsername = sdk.TLSUsernameCN err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 8 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) // external auth not called, auth scope is 8 _, err = getFTPClient(u, true, nil) assert.Error(t, err) _, _, err = httpdtest.GetUserByUsername(u.Username, http.StatusNotFound) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) client, err := getFTPClient(u, true, tlsConfig) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err := client.Quit() assert.NoError(t, err) } user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, u.Username, user.Username) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) u.Username = tlsClient2Username _, err = getFTPClient(u, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid credentials") } err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, 
configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestPreLoginHookWithClientCert(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() u.Username = tlsClient1Username u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, dataprovider.LoginMethodPassword) u.Filters.TLSUsername = sdk.TLSUsernameCN err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) providerConf.PreLoginHook = preLoginPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername(tlsClient1Username, http.StatusNotFound) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) client, err := getFTPClient(u, true, tlsConfig) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err := client.Quit() assert.NoError(t, err) } user, _, err := httpdtest.GetUserByUsername(tlsClient1Username, http.StatusOK) assert.NoError(t, err) // test login with an existing user client, err = getFTPClient(user, true, tlsConfig) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) err := client.Quit() assert.NoError(t, err) } u.Username = tlsClient2Username err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) _, err = getFTPClient(u, true, tlsConfig) if assert.Error(t, err) { assert.Contains(t, err.Error(), "does not match username") } user2, _, err := 
httpdtest.GetUserByUsername(tlsClient2Username, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user2, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user2.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(preLoginPath) assert.NoError(t, err) } func TestNestedVirtualFolders(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser() mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") folderNameCrypt := filepath.Base(mappedPathCrypt) vdirCryptPath := "/vdir/crypt" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderNameCrypt, }, VirtualPath: vdirCryptPath, }) mappedPath := filepath.Join(os.TempDir(), "local") folderName := filepath.Base(mappedPath) vdirPath := "/vdir/local" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdirPath, }) mappedPathNested := filepath.Join(os.TempDir(), "nested") folderNameNested := filepath.Base(mappedPathNested) vdirNestedPath := "/vdir/crypt/nested" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderNameNested, }, VirtualPath: vdirNestedPath, QuotaFiles: -1, QuotaSize: -1, }) f1 := vfs.BaseVirtualFolder{ Name: folderNameCrypt, FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(defaultPassword), }, }, MappedPath: mappedPathCrypt, } _, _, err = 
httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) f3 := vfs.BaseVirtualFolder{ Name: folderNameNested, MappedPath: mappedPathNested, } _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client, err := getFTPClient(sftpUser, false, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client, 0) assert.NoError(t, err) err = ftpUploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client, 0) assert.NoError(t, err) err = ftpDownloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client, 0) assert.NoError(t, 
err) err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(mappedPathCrypt) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) err = os.RemoveAll(mappedPathNested) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func checkBasicFTP(client *ftp.ServerConn) error { _, err := client.CurrentDir() if err != nil { return err } err = client.NoOp() if err != nil { return err } _, err = client.List(".") if err != nil { return err } return nil } func ftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *ftp.ServerConn, offset uint64) error { srcFile, err := os.Open(localSourcePath) if err != nil { return err } defer srcFile.Close() if offset > 0 { err = client.StorFrom(remoteDestPath, srcFile, offset) } else { err = client.Stor(remoteDestPath, srcFile) } if err != nil { return err } if expectedSize > 0 { size, err := client.FileSize(remoteDestPath) if err != nil { return err } if size != expectedSize { return fmt.Errorf("uploaded file size does 
not match, actual: %v, expected: %v", size, expectedSize) } } return nil } func ftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *ftp.ServerConn, offset uint64) error { downloadDest, err := os.Create(localDestPath) if err != nil { return err } defer downloadDest.Close() var r *ftp.Response if offset > 0 { r, err = client.RetrFrom(remoteSourcePath, offset) } else { r, err = client.Retr(remoteSourcePath) } if err != nil { return err } defer r.Close() written, err := io.Copy(downloadDest, r) if err != nil { return err } if written != expectedSize { return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", written, expectedSize) } return nil } func getFTPClientImplicitTLS(user dataprovider.User) (*ftp.ServerConn, error) { ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } ftpOptions = append(ftpOptions, ftp.DialWithTLS(tlsConfig)) ftpOptions = append(ftpOptions, ftp.DialWithDisabledEPSV(true)) client, err := ftp.Dial(ftpSrvAddrTLS, ftpOptions...) if err != nil { return nil, err } pwd := defaultPassword if user.Password != "" { pwd = user.Password } err = client.Login(user.Username, pwd) if err != nil { return nil, err } return client, err } func getFTPClientWithSessionReuse(user dataprovider.User, tlsConfig *tls.Config, dialOptions ...ftp.DialOption, ) (*ftp.ServerConn, error) { ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} ftpOptions = append(ftpOptions, dialOptions...) if tlsConfig == nil { tlsConfig = &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, ClientSessionCache: tls.NewLRUClientSessionCache(0), } } ftpOptions = append(ftpOptions, ftp.DialWithExplicitTLS(tlsConfig)) client, err := ftp.Dial(ftpSrvAddrTLSResumption, ftpOptions...) 
if err != nil { return nil, err } pwd := defaultPassword if user.Password != "" { if user.Password == emptyPwdPlaceholder { pwd = "" } else { pwd = user.Password } } err = client.Login(user.Username, pwd) if err != nil { return nil, err } return client, err } func getFTPClient(user dataprovider.User, useTLS bool, tlsConfig *tls.Config, dialOptions ...ftp.DialOption, ) (*ftp.ServerConn, error) { ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} ftpOptions = append(ftpOptions, dialOptions...) if useTLS { if tlsConfig == nil { tlsConfig = &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } } ftpOptions = append(ftpOptions, ftp.DialWithExplicitTLS(tlsConfig)) } client, err := ftp.Dial(ftpServerAddr, ftpOptions...) if err != nil { return nil, err } pwd := defaultPassword if user.Password != "" { if user.Password == emptyPwdPlaceholder { pwd = "" } else { pwd = user.Password } } err = client.Login(user.Username, pwd) if err != nil { return nil, err } return client, err } func waitTCPListening(address string) { for { conn, err := net.Dial("tcp", address) if err != nil { logger.WarnToConsole("tcp server %v not listening: %v", address, err) time.Sleep(100 * time.Millisecond) continue } logger.InfoToConsole("tcp server %v now listening", address) conn.Close() break } } func waitNoConnections() { time.Sleep(50 * time.Millisecond) for len(common.Connections.GetStats("")) > 0 { time.Sleep(50 * time.Millisecond) } } func getTestGroup() dataprovider.Group { return dataprovider.Group{ BaseGroup: sdk.BaseGroup{ Name: "test_group", Description: "test group description", }, } } func getTestUser() dataprovider.User { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: defaultUsername, Password: defaultPassword, HomeDir: filepath.Join(homeBasePath, defaultUsername), Status: 1, ExpirationDate: 0, }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = allPerms return user 
} func getTestSFTPUser() dataprovider.User { u := getTestUser() u.Username = u.Username + "_sftp" u.FsConfig.Provider = sdk.SFTPFilesystemProvider u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr u.FsConfig.SFTPConfig.Username = defaultUsername u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) return u } func getTestUserWithHTTPFs() dataprovider.User { u := getTestUser() u.FsConfig.Provider = sdk.HTTPFilesystemProvider u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort), Username: defaultHTTPFsUsername, }, } return u } func getExtAuthScriptContent(user dataprovider.User) []byte { extAuthContent := []byte("#!/bin/sh\n\n") extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%v\"; then\n", user.Username))...) u, _ := json.Marshal(user) extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) extAuthContent = append(extAuthContent, []byte("else\n")...) extAuthContent = append(extAuthContent, []byte("echo '{\"username\":\"\"}'\n")...) extAuthContent = append(extAuthContent, []byte("fi\n")...) return extAuthContent } func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte { content := []byte("#!/bin/sh\n\n") if nonJSONResponse { content = append(content, []byte("echo 'text response'\n")...) return content } if len(user.Username) > 0 { u, _ := json.Marshal(user) content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) 
} return content } func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { var sftpClient *sftp.Client config := &ssh.ClientConfig{ User: user.Username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 5 * time.Second, } if user.Password != "" { config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} } else { config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} } conn, err := ssh.Dial("tcp", sftpServerAddr, config) if err != nil { return conn, sftpClient, err } sftpClient, err = sftp.NewClient(conn) if err != nil { conn.Close() } return conn, sftpClient, err } func getExitCodeScriptContent(exitCode int) []byte { content := []byte("#!/bin/sh\n\n") content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) return content } func createTestFile(path string, size int64) error { baseDir := filepath.Dir(path) if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { err = os.MkdirAll(baseDir, os.ModePerm) if err != nil { return err } } content := make([]byte, size) _, err := rand.Read(content) if err != nil { return err } return os.WriteFile(path, content, os.ModePerm) } func writeCerts(certPath, keyPath, caCrtPath, caCRLPath string) error { err := os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing FTPS certificate: %v", err) return err } err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing FTPS private key: %v", err) return err } err = os.WriteFile(caCrtPath, []byte(caCRT), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing FTPS CA crt: %v", err) return err } err = os.WriteFile(caCRLPath, []byte(caCRL), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing FTPS CRL: %v", err) return err } return nil } func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) { return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{ Period: 30, Skew: 1, 
Digits: otp.DigitsSix, Algorithm: algo, }) } func startHTTPFs() { go func() { if err := httpdtest.StartTestHTTPFs(httpFsPort, nil); err != nil { logger.ErrorToConsole("could not start HTTPfs test server: %v", err) os.Exit(1) } }() waitTCPListening(fmt.Sprintf(":%d", httpFsPort)) } ================================================ FILE: internal/ftpd/handler.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package ftpd import ( "errors" "fmt" "io" "os" "path" "strings" "time" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/spf13/afero" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( errNotImplemented = errors.New("not implemented") errCOMBNotSupported = errors.New("COMB is not supported for this filesystem") ) // Connection details for an FTP connection. 
// It implements common.ActiveConnection and ftpserver.ClientDriver interfaces
type Connection struct {
	*common.BaseConnection
	// clientContext is the ftpserverlib context of this client session;
	// getFTPMode guards against it being nil, the other accessors do not
	clientContext ftpserver.ClientContext
	// doWildcardListDir is set by Stat when a LIST/NLST target containing
	// wildcards cannot be stat'ed; ReadDir consumes and resets it
	doWildcardListDir bool
}

// getFTPMode returns "active" or "passive" according to the last data channel
// used by the client, or an empty string if there is no client context or the
// channel type is not one of those two.
func (c *Connection) getFTPMode() string {
	if c.clientContext == nil {
		return ""
	}
	switch c.clientContext.GetLastDataChannel() {
	case ftpserver.DataChannelActive:
		return "active"
	case ftpserver.DataChannelPassive:
		return "passive"
	}
	return ""
}

// GetClientVersion returns the connected client's version.
// It returns "Unknown" if the client does not advertise its
// version
func (c *Connection) GetClientVersion() string {
	version := c.clientContext.GetClientVersion()
	if len(version) > 0 {
		return version
	}
	return "Unknown"
}

// GetLocalAddress returns local connection address
func (c *Connection) GetLocalAddress() string {
	return c.clientContext.LocalAddr().String()
}

// GetRemoteAddress returns the connected client's address
func (c *Connection) GetRemoteAddress() string {
	return c.clientContext.RemoteAddr().String()
}

// Disconnect disconnects the client
func (c *Connection) Disconnect() error {
	return c.clientContext.Close()
}

// GetCommand returns the last received FTP command
func (c *Connection) GetCommand() string {
	return c.clientContext.GetLastCommand()
}

// Create is not implemented, we use ClientDriverExtentionFileTransfer
// (interface name spelled as in ftpserverlib)
func (c *Connection) Create(_ string) (afero.File, error) {
	return nil, errNotImplemented
}

// Mkdir creates a directory using the connection filesystem
func (c *Connection) Mkdir(name string, _ os.FileMode) error {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	return c.CreateDir(name, true)
}

// MkdirAll is not implemented, we don't need it
func (c *Connection) MkdirAll(_ string, _ os.FileMode) error {
	return errNotImplemented
}

// Open is not implemented, we use ClientDriverExtentionFileTransfer and ClientDriverExtensionFileList
func (c *Connection) Open(_ string) (afero.File, error) {
	return nil, errNotImplemented
}

// OpenFile is not implemented, we use ClientDriverExtentionFileTransfer
func (c *Connection) OpenFile(_ string, _ int, _ os.FileMode) (afero.File, error) {
	return nil, errNotImplemented
}

// Remove removes a file or a symlink.
// Directories are handled by RemoveDir (ClientDriverExtensionRemoveDir) and
// are rejected here with a generic error.
func (c *Connection) Remove(name string) error {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	fs, p, err := c.GetFsAndResolvedPath(name)
	if err != nil {
		return err
	}

	var fi os.FileInfo
	// Lstat so that a symlink pointing to a directory is still removable
	if fi, err = fs.Lstat(p); err != nil {
		c.Log(logger.LevelError, "failed to remove file %q: stat error: %+v", p, err)
		return c.GetFsError(fs, err)
	}

	if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 {
		c.Log(logger.LevelError, "cannot remove %q is not a file/symlink", p)
		return c.GetGenericError(nil)
	}
	return c.RemoveFile(fs, p, name, fi)
}

// RemoveAll is not implemented, we don't need it
func (c *Connection) RemoveAll(_ string) error {
	return errNotImplemented
}

// Rename renames a file or a directory
func (c *Connection) Rename(oldname, newname string) error {
	c.UpdateLastActivity()

	oldname = util.CleanPath(oldname)
	newname = util.CleanPath(newname)

	return c.BaseConnection.Rename(oldname, newname)
}

// Stat returns a FileInfo describing the named file/directory, or an error,
// if any happens.
// If the path cannot be stat'ed but its last element contains wildcards and the
// last command is LIST or NLST, a fake directory FileInfo is returned and
// doWildcardListDir is set so that the following ReadDir filters by pattern.
func (c *Connection) Stat(name string) (os.FileInfo, error) {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	c.doWildcardListDir = false

	if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) {
		return nil, c.GetPermissionDeniedError()
	}

	fi, err := c.DoStat(name, 0, true)
	if err != nil {
		if c.isListDirWithWildcards(path.Base(name)) {
			c.doWildcardListDir = true
			return vfs.NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil
		}
		return nil, err
	}
	return fi, nil
}

// Name returns the name of this connection
func (c *Connection) Name() string {
	return c.GetID()
}

// Chown changes the uid and gid of the named file.
// It is currently unsupported and always returns common.ErrOpUnsupported; the
// commented out code below is a previous implementation.
func (c *Connection) Chown(_ string, _, _ int) error {
	c.UpdateLastActivity()

	return common.ErrOpUnsupported
	/*p, err := c.Fs.ResolvePath(name)
	if err != nil {
		return c.GetFsError(err)
	}
	attrs := common.StatAttributes{
		Flags: common.StatAttrUIDGID,
		UID:   uid,
		GID:   gid,
	}

	return c.SetStat(p, name, &attrs)*/
}

// Chmod changes the mode of the named file/directory
func (c *Connection) Chmod(name string, mode os.FileMode) error {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	attrs := common.StatAttributes{
		Flags: common.StatAttrPerms,
		Mode:  mode,
	}
	return c.SetStat(name, &attrs)
}

// Chtimes changes the access and modification times of the named file
func (c *Connection) Chtimes(name string, atime time.Time, mtime time.Time) error {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	attrs := common.StatAttributes{
		Flags: common.StatAttrTimes,
		Atime: atime,
		Mtime: mtime,
	}
	return c.SetStat(name, &attrs)
}

// GetAvailableSpace implements ClientDriverExtensionAvailableSpace interface.
// It returns 0 if the disk or transfer quota is exhausted. Without quota
// restrictions it returns MaxUploadFileSize, if set, or the free space reported
// by the filesystem. Otherwise it returns the smallest of the disk quota,
// the upload/total transfer quota and MaxUploadFileSize.
func (c *Connection) GetAvailableSpace(dirName string) (int64, error) {
	c.UpdateLastActivity()

	dirName = util.CleanPath(dirName)
	// a fake file name inside dirName is used so that HasSpace evaluates the
	// quota of the folder containing it
	diskQuota, transferQuota := c.HasSpace(false, false, path.Join(dirName, "fakefile.txt"))
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		return 0, nil
	}

	if diskQuota.AllowedSize == 0 && transferQuota.AllowedULSize == 0 && transferQuota.AllowedTotalSize == 0 {
		// no quota restrictions
		if c.User.Filters.MaxUploadFileSize > 0 {
			return c.User.Filters.MaxUploadFileSize, nil
		}

		fs, p, err := c.GetFsAndResolvedPath(dirName)
		if err != nil {
			return 0, err
		}

		statVFS, err := fs.GetAvailableDiskSize(p)
		if err != nil {
			return 0, c.GetFsError(fs, err)
		}
		return int64(statVFS.FreeSpace()), nil
	}

	allowedDiskSize := diskQuota.AllowedSize
	allowedUploadSize := transferQuota.AllowedULSize
	// a total (upload + download) transfer limit takes precedence over the
	// upload-only one
	if transferQuota.AllowedTotalSize > 0 {
		allowedUploadSize = transferQuota.AllowedTotalSize
	}
	allowedSize := allowedDiskSize
	if allowedSize == 0 {
		allowedSize = allowedUploadSize
	} else {
		if allowedUploadSize > 0 && allowedUploadSize < allowedSize {
			allowedSize = allowedUploadSize
		}
	}
	// the available space is the minimum between MaxUploadFileSize, if set,
	// and quota allowed size
	if c.User.Filters.MaxUploadFileSize > 0 {
		if c.User.Filters.MaxUploadFileSize < allowedSize {
			return c.User.Filters.MaxUploadFileSize, nil
		}
	}

	return allowedSize, nil
}

// AllocateSpace implements ClientDriverExtensionAllocate interface
func (c *Connection) AllocateSpace(_ int) error {
	c.UpdateLastActivity()
	// we treat ALLO as NOOP see RFC 959
	return nil
}

// RemoveDir implements ClientDriverExtensionRemoveDir
func (c *Connection) RemoveDir(name string) error {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	return c.BaseConnection.RemoveDir(name)
}

// Symlink implements ClientDriverExtensionSymlink
func (c *Connection) Symlink(oldname, newname string) error {
	c.UpdateLastActivity()

	oldname = util.CleanPath(oldname)
	newname = util.CleanPath(newname)
	return c.CreateSymlink(oldname, newname)
}

// ReadDir implements ClientDriverExtensionFilelist.
// When the preceding Stat flagged a wildcard listing, the parent directory is
// listed instead and its entries are filtered by the last path element used as
// a path.Match pattern.
func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	if c.doWildcardListDir {
		c.doWildcardListDir = false
		baseName := path.Base(name)
		// we only support wildcards for the last path level, for example:
		// - *.xml is supported
		// - dir*/*.xml is not supported
		name = path.Dir(name)
		c.clientContext.SetListPath(name)
		lister, err := c.ListDir(name)
		if err != nil {
			return nil, err
		}
		patternLister := &patternDirLister{
			DirLister:      lister,
			pattern:        baseName,
			lastCommand:    c.clientContext.GetLastCommand(),
			dirName:        name,
			connectionPath: util.CleanPath(c.clientContext.Path()),
		}
		return consumeDirLister(patternLister)
	}

	lister, err := c.ListDir(name)
	if err != nil {
		return nil, err
	}
	return consumeDirLister(lister)
}

// GetHandle implements ClientDriverExtentionFileTransfer.
// Requests whose flags include os.O_WRONLY are uploads, everything else is a
// download starting at offset. COMB is only accepted on the local OS filesystem.
func (c *Connection) GetHandle(name string, flags int, offset int64) (ftpserver.FileTransfer, error) {
	c.UpdateLastActivity()

	name = util.CleanPath(name)
	fs, p, err := c.GetFsAndResolvedPath(name)
	if err != nil {
		return nil, err
	}

	if c.GetCommand() == "COMB" && !vfs.IsLocalOsFs(fs) {
		return nil, errCOMBNotSupported
	}

	if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
		c.Log(logger.LevelInfo, "denying transfer due to count limits")
		return nil, c.GetPermissionDeniedError()
	}

	if flags&os.O_WRONLY != 0 {
		return c.uploadFile(fs, p, name, flags)
	}
	return c.downloadFile(fs, p, name, offset)
}

// downloadFile checks the download permission, the transfer quota, the file
// filters and the pre-download action. It then opens fsPath at offset and
// returns it wrapped in a download transfer.
func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int64) (ftpserver.FileTransfer, error) {
	if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(ftpPath)) {
		return nil, c.GetPermissionDeniedError()
	}
	transferQuota := c.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		c.Log(logger.LevelInfo, "denying file read due to quota limits")
		return nil, c.GetReadQuotaExceededError()
	}

	if ok, policy := c.User.IsFileAllowed(ftpPath); !ok {
		c.Log(logger.LevelWarn, "reading file %q is not allowed", ftpPath)
		return nil, c.GetErrorForDeniedFile(policy)
	}

	if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreDownload, fsPath, ftpPath, 0, 0); err != nil {
		c.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", ftpPath, err)
		return nil, c.GetPermissionDeniedError()
	}

	file, r, cancelFn, err := fs.Open(fsPath, offset)
	if err != nil {
		c.Log(logger.LevelError, "could not open file %q for reading: %+v", fsPath, err)
		return nil, c.GetFsError(fs, err)
	}

	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath,
		common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota)
	baseTransfer.SetFtpMode(c.getFTPMode())
	t := newTransfer(baseTransfer, nil, r, offset)

	return t, nil
}

// uploadFile checks the file filters and picks the atomic upload path when
// atomic uploads are enabled and supported. A missing target or a symlink
// target is handled as a new file and needs the upload permission. An existing
// regular file needs the overwrite permission. Directories are rejected.
func (c *Connection) uploadFile(fs vfs.Fs, fsPath, ftpPath string, flags int) (ftpserver.FileTransfer, error) {
	if ok, _ := c.User.IsFileAllowed(ftpPath); !ok {
		c.Log(logger.LevelWarn, "writing file %q is not allowed", ftpPath)
		return nil, ftpserver.ErrFileNameNotAllowed
	}

	filePath := fsPath
	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
		filePath = fs.GetAtomicUploadPath(fsPath)
	}

	stat, statErr := fs.Lstat(fsPath)
	if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
		if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(ftpPath)) {
			return nil, fmt.Errorf("%w, no upload permission", ftpserver.ErrFileNameNotAllowed)
		}
		return c.handleFTPUploadToNewFile(fs, flags, fsPath, filePath, ftpPath)
	}

	if statErr != nil {
		c.Log(logger.LevelError, "error performing file stat %q: %+v", fsPath, statErr)
		return nil, c.GetFsError(fs, statErr)
	}

	// This happens if we upload a file that has the same name as an existing directory
	if stat.IsDir() {
		c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", fsPath)
		return nil, c.GetOpUnsupportedError()
	}

	if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(ftpPath)) {
		return nil, fmt.Errorf("%w, no overwrite permission", ftpserver.ErrFileNameNotAllowed)
	}

	return c.handleFTPUploadToExistingFile(fs, flags, fsPath, filePath, stat.Size(), ftpPath)
}

// handleFTPUploadToNewFile checks the quota and the pre-upload action, creates
// filePath and returns an upload transfer. resolvedPath is the final filesystem
// path, filePath the (possibly temporary) path written to, and requestPath the
// FTP path.
func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, flags int, resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) {
	diskQuota, transferQuota := c.HasSpace(true, false, requestPath)
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		c.Log(logger.LevelInfo, "denying file write due to quota limits")
		return nil, ftpserver.ErrStorageExceeded
	}
	if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, 0, 0); err != nil {
		c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
		return nil, ftpserver.ErrFileNameNotAllowed
	}
	file, w, cancelFn, err := fs.Create(filePath, flags, c.GetCreateChecks(requestPath, true, false))
	if err != nil {
		c.Log(logger.LevelError, "error creating file %q, flags %v: %+v", resolvedPath, flags, err)
		return nil, c.GetFsError(fs, err)
	}

	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())

	// we can get an error only for resume
	maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported())

	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota)
	baseTransfer.SetFtpMode(c.getFTPMode())
	t := newTransfer(baseTransfer, w, nil, 0)

	return t, nil
}

// handleFTPUploadToExistingFile handles an upload over an existing file of
// fileSize bytes.
// Without O_TRUNC in flags, the upload resumes at fileSize.
// Otherwise the file is overwritten: if the filesystem supports truncation, the
// used quota of the user (or of the containing virtual folder) is decreased by
// fileSize up front; if not, the old size is tracked as truncatedSize.
// With atomic uploads enabled and supported, the existing file is first renamed
// to filePath.
func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolvedPath, filePath string, fileSize int64, requestPath string) (ftpserver.FileTransfer, error) {
	var err error
	diskQuota, transferQuota := c.HasSpace(false, false, requestPath)
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		c.Log(logger.LevelInfo, "denying file write due to quota limits")
		return nil, ftpserver.ErrStorageExceeded
	}
	minWriteOffset := int64(0)
	// ftpserverlib sets:
	// - os.O_WRONLY | os.O_APPEND for APPE and COMB
	// - os.O_WRONLY | os.O_CREATE for REST.
	// - os.O_WRONLY | os.O_CREATE | os.O_TRUNC if the command is not APPE and REST = 0
	// so if we don't have O_TRUNC it is a resume.
	isResume := flags&os.O_TRUNC == 0
	// if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace
	// will return false in this case and we deny the upload before
	maxWriteSize, err := c.GetMaxWriteSize(diskQuota, isResume, fileSize, vfs.IsUploadResumeSupported(fs, fileSize))
	if err != nil {
		c.Log(logger.LevelDebug, "unable to get max write size: %v", err)
		return nil, err
	}

	if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, flags); err != nil {
		c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
		return nil, ftpserver.ErrFileNameNotAllowed
	}

	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
		_, _, err = fs.Rename(resolvedPath, filePath, 0)
		if err != nil {
			c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v",
				resolvedPath, filePath, err)
			return nil, c.GetFsError(fs, err)
		}
	}

	file, w, cancelFn, err := fs.Create(filePath, flags, c.GetCreateChecks(requestPath, false, isResume))
	if err != nil {
		c.Log(logger.LevelError, "error opening existing file, flags: %v, source: %q, err: %+v", flags, filePath, err)
		return nil, c.GetFsError(fs, err)
	}

	initialSize := int64(0)
	truncatedSize := int64(0) // bytes truncated and not included in quota
	if isResume {
		c.Log(logger.LevelDebug, "resuming upload requested, file path: %q initial size: %v", filePath, fileSize)
		minWriteOffset = fileSize
		initialSize = fileSize
		if vfs.IsSFTPFs(fs) && fs.IsUploadResumeSupported() {
			// we need this since we don't allow resume with wrong offset, we should fix this in pkg/sftp
			file.Seek(initialSize, io.SeekStart) //nolint:errcheck // for sftp seek simply set the offset
		}
	} else {
		if vfs.HasTruncateSupport(fs) {
			// the old content is discarded now, so remove its size from the used quota
			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
			if err == nil {
				dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
			} else {
				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
			}
		} else {
			initialSize = fileSize
			truncatedSize = fileSize
		}
	}

	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())

	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota)
	baseTransfer.SetFtpMode(c.getFTPMode())
	t := newTransfer(baseTransfer, w, nil, minWriteOffset)

	return t, nil
}

// isListDirWithWildcards reports whether name contains any of the characters
// "*?[]^" and the last received command is LIST or NLST.
func (c *Connection) isListDirWithWildcards(name string) bool {
	if strings.ContainsAny(name, "*?[]^") {
		lastCommand := c.clientContext.GetLastCommand()
		return lastCommand == "LIST" || lastCommand == "NLST"
	}
	return false
}

// getPathRelativeTo returns target expressed relative to base, prepending one
// "../" for each level it has to walk up from base. If the walk reaches "/" or
// "./" without base becoming a prefix of target, target is returned unchanged.
// Equal paths yield "".
func getPathRelativeTo(base, target string) string {
	var sb strings.Builder
	for {
		if base == target {
			return sb.String()
		}
		if !strings.HasSuffix(base, "/") {
			base += "/"
		}
		if strings.HasPrefix(target, base) {
			sb.WriteString(strings.TrimPrefix(target, base))
			return sb.String()
		}
		if base == "/" || base == "./" {
			return target
		}
		sb.WriteString("../")
		base = path.Dir(path.Clean(base))
	}
}

// patternDirLister wraps a vfs.DirLister and only returns the entries whose
// name matches pattern (path.Match syntax). For commands other than NLST,
// names are prefixed with dirName relative to connectionPath.
type patternDirLister struct {
	vfs.DirLister
	pattern        string
	lastCommand    string
	dirName        string
	connectionPath string
}

// Next returns the next batch of matching entries. Batches from the wrapped
// lister that contain no match are skipped until a match, an error or the end
// of the listing is reached. An invalid pattern returns path.Match's error.
func (l *patternDirLister) Next(limit int) ([]os.FileInfo, error) {
	for {
		files, err := l.DirLister.Next(limit)
		if len(files) == 0 {
			return files, err
		}
		validIdx := 0
		var relativeBase string
		if l.lastCommand != "NLST" {
			relativeBase = getPathRelativeTo(l.connectionPath, l.dirName)
		}
		// filter in place, reusing the slice returned by the wrapped lister
		for _, fi := range files {
			match, errMatch := path.Match(l.pattern, fi.Name())
			if errMatch != nil {
				return nil, errMatch
			}
			if match {
				files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), fi.Size(),
					fi.ModTime(), true)
				validIdx++
			}
		}
		files = files[:validIdx]
		if err != nil || len(files) > 0 {
			return files, err
		}
	}
}

func consumeDirLister(lister vfs.DirLister) ([]os.FileInfo, error) {
	defer lister.Close()

	var results
[]os.FileInfo for { files, err := lister.Next(vfs.ListerBatchSize) finished := errors.Is(err, io.EOF) results = append(results, files...) if err != nil && !finished { return results, err } if finished { lister.Close() break } } return results, nil } ================================================ FILE: internal/ftpd/internal_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package ftpd import ( "crypto/tls" "crypto/x509" "errors" "fmt" "io/fs" "net" "os" "path/filepath" "runtime" "testing" "time" "github.com/eikenb/pipeat" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/pires/go-proxyproto" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( ftpsCert = `-----BEGIN CERTIFICATE----- MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA 
IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY /8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI dV4vKmHUzwK/eIx+8Ay3neE= -----END CERTIFICATE-----` ftpsKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= -----END EC PRIVATE KEY-----` caCRT = `-----BEGIN CERTIFICATE----- MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U 4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY /tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki 
pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w kK8u1YiiVenmoQ== -----END CERTIFICATE-----` caKey = `-----BEGIN RSA PRIVATE KEY----- MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj 7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY 00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz +465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc 9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ 
u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM 0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN +jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 /hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz 1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN 38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ 2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== -----END RSA PRIVATE KEY-----` caCRL = `-----BEGIN X509 CRL----- 
MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV +jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn w7bXJGfJcIMrrKs= -----END X509 CRL-----` client1Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb 
BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au MU3Bo0A= -----END CERTIFICATE-----` client1Key = `-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi 85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls 
ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== -----END RSA PRIVATE KEY-----` // client 2 crt is revoked client2Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB 4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv /ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR E0+r2+4= -----END CERTIFICATE-----` client2Key = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK 
h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P
QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu
4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq
op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM
qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt
Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL
VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO
qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy
VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl
NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7
4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl
/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd
aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu
khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz
3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/
eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i
vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB
GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb
Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D
8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl
RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA
ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl
-----END RSA PRIVATE KEY-----`
)

var (
	// configDir points two directories up from the package directory and is
	// passed as the config dir to Initialize, NewServer and NewCertManager.
	configDir = filepath.Join(".", "..", "..")
)

// mockFTPClientContext is a minimal test double for ftpserver.ClientContext.
// Only the data channel, the remote/local IPs and the extra value are
// configurable; every other method returns a fixed value or does nothing.
type mockFTPClientContext struct {
	// value returned by GetLastDataChannel
	lastDataChannel ftpserver.DataChannel
	// IP reported by RemoteAddr; "127.0.0.1" when empty
	remoteIP string
	// IP reported by LocalAddr; "127.0.0.1" when empty
	localIP string
	// value stored by SetExtra and returned by Extra
	extra any
}

func (cc *mockFTPClientContext) Path() string {
	return ""
}

func (cc *mockFTPClientContext) SetPath(_ string) {}

func (cc *mockFTPClientContext) SetListPath(_ string) {}

func (cc *mockFTPClientContext) SetDebug(_ bool) {}

func (cc *mockFTPClientContext) Debug() bool {
	return false
}

// ID always returns 1; tests derive connection IDs from it.
func (cc *mockFTPClientContext) ID() uint32 {
	return 1
}

// RemoteAddr returns cc.remoteIP, defaulting to 127.0.0.1.
func (cc *mockFTPClientContext) RemoteAddr() net.Addr {
	ip := "127.0.0.1"
	if cc.remoteIP != "" {
		ip = cc.remoteIP
	}
	return &net.IPAddr{IP: net.ParseIP(ip)}
}

// LocalAddr returns cc.localIP, defaulting to 127.0.0.1.
func (cc *mockFTPClientContext) LocalAddr() net.Addr {
	ip := "127.0.0.1"
	if cc.localIP != "" {
		ip = cc.localIP
	}
	return &net.IPAddr{IP: net.ParseIP(ip)}
}

// GetClientVersion returns a fixed string that TestClientVersion asserts on.
func (cc *mockFTPClientContext) GetClientVersion() string {
	return "mock version"
}

func (cc *mockFTPClientContext) Close() error {
	return nil
}

func (cc *mockFTPClientContext) HasTLSForControl() bool {
	return false
}

func (cc *mockFTPClientContext) HasTLSForTransfers() bool {
	return false
}

func (cc *mockFTPClientContext) SetTLSRequirement(_ ftpserver.TLSRequirement) error {
	return nil
}

// GetLastCommand always returns "", so wildcard LIST/NLST filtering is never
// triggered through this mock.
func (cc *mockFTPClientContext) GetLastCommand() string {
	return ""
}

func (cc *mockFTPClientContext) GetLastDataChannel() ftpserver.DataChannel {
	return cc.lastDataChannel
}

func (cc *mockFTPClientContext) SetExtra(extra any) {
	cc.extra = extra
}

func (cc *mockFTPClientContext) Extra() any {
	return cc.extra
}

// MockOsFs mockable OsFs
// It embeds a real vfs.Fs. err, when set, is returned by Remove and Rename;
// statErr, when set, is returned by Stat and Lstat. Upload resume is always
// reported as unsupported.
type MockOsFs struct {
	vfs.Fs
	err                     error
	statErr                 error
	isAtomicUploadSupported bool
}

// Name returns the name for the Fs implementation
func (fs MockOsFs) Name() string {
	return "mockOsFs"
}

// IsUploadResumeSupported returns true if resuming uploads is supported
func (MockOsFs) IsUploadResumeSupported() bool {
	return false
}

// IsConditionalUploadResumeSupported returns if resuming uploads is supported
// for the specified size
func (MockOsFs) IsConditionalUploadResumeSupported(_ int64) bool {
	return false
}

// IsAtomicUploadSupported returns true if atomic upload is supported
func (fs MockOsFs) IsAtomicUploadSupported() bool {
	return fs.isAtomicUploadSupported
}

// Stat returns a FileInfo describing the named file
func (fs MockOsFs) Stat(name string) (os.FileInfo, error) {
	if fs.statErr != nil {
		return nil, fs.statErr
	}
	return os.Stat(name)
}

// Lstat returns a FileInfo describing the named file
func (fs MockOsFs) Lstat(name string) (os.FileInfo, error) {
	if fs.statErr != nil {
		return nil, fs.statErr
	}
	return os.Lstat(name)
}

// Remove removes the named file or (empty) directory.
func (fs MockOsFs) Remove(name string, _ bool) error {
	if fs.err != nil {
		return fs.err
	}
	return os.Remove(name)
}

// Rename renames (moves) source to target
// It always reports -1, -1 for the two numeric results.
func (fs MockOsFs) Rename(source, target string, _ int) (int, int64, error) {
	if fs.err != nil {
		return -1, -1, fs.err
	}
	err := os.Rename(source, target)
	return -1, -1, err
}

// newMockOsFs builds a MockOsFs backed by a real OsFs rooted at rootDir,
// with the given injected errors and atomic upload support flag.
func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
	return &MockOsFs{
		Fs:                      vfs.NewOsFs(connectionID, rootDir, "", nil),
		err:                     err,
		statErr:                 statErr,
		isAtomicUploadSupported: atomicUpload,
	}
}

// TestInitialization checks configuration/binding validation errors and TLS
// settings. It temporarily replaces the package-level certMgr and restores it
// at the end.
// NOTE(review): it assumes test_ftpd.crt/test_ftpd.key already exist in
// os.TempDir() (require.NoError on NewCertManager) without writing them
// itself; certMgr is also not restored if a require fails - consider defer.
func TestInitialization(t *testing.T) {
	oldMgr := certMgr
	certMgr = nil

	binding := Binding{
		Port: 2121,
	}
	c := &Configuration{
		Bindings:           []Binding{binding},
		CertificateFile:    "acert",
		CertificateKeyFile: "akey",
	}
	assert.False(t, binding.HasProxy())
	assert.Equal(t, util.I18nFTPTLSDisabled, binding.GetTLSDescription())
	err := c.Initialize(configDir)
	assert.Error(t, err)
	c.CertificateFile = ""
	c.CertificateKeyFile = ""
	c.BannerFile = "afile"
	server := NewServer(c, configDir, binding, 0)
	assert.Equal(t, version.GetServerVersion("_", false), server.initialMsg)
	_, err = server.GetTLSConfig()
	assert.Error(t, err)

	binding.TLSMode = 1
	server = NewServer(c, configDir, binding, 0)
	_, err = server.GetSettings()
	assert.Error(t, err)

	// out of range security values must be rejected, passive first
	binding.PassiveConnectionsSecurity = 100
	binding.ActiveConnectionsSecurity = 100
	server = NewServer(c, configDir, binding, 0)
	_, err = server.GetSettings()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid passive_connections_security")
	}
	binding.PassiveConnectionsSecurity = 1
	server = NewServer(c, configDir, binding, 0)
	_, err = server.GetSettings()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid active_connections_security")
	}
	binding = Binding{
		Port:           2121,
		ForcePassiveIP: "192.168.1",
	}
	server = NewServer(c, configDir, binding, 0)
	_, err = server.GetSettings()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is not valid")
	}
	// an IPv4-mapped IPv6 address is accepted and normalized to plain IPv4
	binding.ForcePassiveIP = "::ffff:192.168.89.9"
	err = binding.checkPassiveIP()
	assert.NoError(t, err)
	assert.Equal(t, "192.168.89.9", binding.ForcePassiveIP)

	binding.ForcePassiveIP = "::1"
	err = binding.checkPassiveIP()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is not a valid IPv4 address")
	}

	err = ReloadCertificateMgr()
	assert.NoError(t, err)

	binding = Binding{
		Port:           2121,
		ClientAuthType: 1,
	}
	assert.Equal(t, util.I18nFTPTLSDisabled, binding.GetTLSDescription())
	certPath := filepath.Join(os.TempDir(), "test_ftpd.crt")
	keyPath := filepath.Join(os.TempDir(), "test_ftpd.key")
	binding.CertificateFile = certPath
	binding.CertificateKeyFile = keyPath
	keyPairs := []common.TLSKeyPair{
		{
			Cert: certPath,
			Key:  keyPath,
			ID:   binding.GetAddress(),
		},
	}
	certMgr, err = common.NewCertManager(keyPairs, configDir, "")
	require.NoError(t, err)
	assert.Equal(t, util.I18nFTPTLSMixed, binding.GetTLSDescription())
	server = NewServer(c, configDir, binding, 0)
	cfg, err := server.GetTLSConfig()
	require.NoError(t, err)
	assert.Equal(t, tls.RequireAndVerifyClientCert, cfg.ClientAuth)

	certMgr = oldMgr
}

// TestServerGetSettings checks the passive port range, proxy protocol and TLS
// description handling. It mutates the global common.Config and certMgr and
// restores both at the end.
// NOTE(review): the certificate/key files written to os.TempDir() are not
// removed by this test.
func TestServerGetSettings(t *testing.T) {
	oldConfig := common.Config
	oldMgr := certMgr

	binding := Binding{
		Port:             2121,
		ApplyProxyConfig: true,
	}
	c := &Configuration{
		Bindings: []Binding{binding},
		PassivePortRange: PortRange{
			Start: 10000,
			End:   10000,
		},
	}
	assert.False(t, binding.HasProxy())
	server := NewServer(c, configDir, binding, 0)
	settings, err := server.GetSettings()
	assert.NoError(t, err)
	if ranger, ok := settings.PassiveTransferPortRange.(*ftpserver.PortRange); ok {
		assert.Equal(t, 10000, ranger.Start)
		assert.Equal(t, 10000, ranger.End)
	}

	c.PassivePortRange.End = 11000
	settings, err = server.GetSettings()
	assert.NoError(t, err)
	if ranger, ok := settings.PassiveTransferPortRange.(*ftpserver.PortRange); ok {
		assert.Equal(t, 10000, ranger.Start)
		assert.Equal(t, 11000, ranger.End)
	}

	common.Config.ProxyProtocol = 1
	_, err = server.GetSettings()
	assert.Error(t, err)
	server.binding.Port = 8021

	assert.Equal(t, util.I18nFTPTLSDisabled, binding.GetTLSDescription())
	_, err = server.GetTLSConfig()
	assert.Error(t, err) // TLS configured but cert manager has no certificate

	binding.TLSMode = 1
	assert.Equal(t, util.I18nFTPTLSExplicit, binding.GetTLSDescription())

	binding.TLSMode = 2
	assert.Equal(t, util.I18nFTPTLSImplicit, binding.GetTLSDescription())

	certPath := filepath.Join(os.TempDir(), "test_ftpd.crt")
	keyPath := filepath.Join(os.TempDir(), "test_ftpd.key")
	err = os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm)
	assert.NoError(t, err)
	keyPairs := []common.TLSKeyPair{
		{
			Cert: certPath,
			Key:  keyPath,
			ID:   common.DefaultTLSKeyPaidID,
		},
	}
	certMgr, err = common.NewCertManager(keyPairs, configDir, "")
	require.NoError(t, err)

	common.Config.ProxyAllowed = nil
	c.CertificateFile = certPath
	c.CertificateKeyFile = keyPath
	server = NewServer(c, configDir, binding, 0)
	server.binding.Port = 9021
	settings, err = server.GetSettings()
	assert.NoError(t, err)
	assert.NotNil(t, settings.Listener)

	// with ProxyProtocol enabled the passive listener is wrapped by proxyproto
	listener, err := net.Listen("tcp", ":0")
	assert.NoError(t, err)
	listener, err = server.WrapPassiveListener(listener)
	assert.NoError(t, err)

	_, ok := listener.(*proxyproto.Listener)
	assert.True(t, ok)

	common.Config = oldConfig
	certMgr = oldMgr
}

// TestUserInvalidParams expects validateUser to fail for an invalid home dir,
// for two virtual folders whose mapped paths are nested (mappedPath2 is inside
// mappedPath1), and for the same user with no virtual folders.
func TestUserInvalidParams(t *testing.T) {
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			HomeDir: "invalid",
		},
	}
	binding := Binding{
		Port: 2121,
	}
	c := &Configuration{
		Bindings: []Binding{binding},
		PassivePortRange: PortRange{
			Start: 10000,
			End:   11000,
		},
	}
	server := NewServer(c, configDir, binding, 3)
	_, err := server.validateUser(u, &mockFTPClientContext{}, dataprovider.LoginMethodPassword)
	assert.Error(t, err)

	u.Username = "a"
	u.HomeDir = filepath.Clean(os.TempDir())
	subDir := "subdir"
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath1,
		},
		VirtualPath: vdirPath1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath2,
		},
		VirtualPath: vdirPath2,
	})
	_, err = server.validateUser(u, &mockFTPClientContext{}, dataprovider.LoginMethodPassword)
	assert.Error(t, err)
	u.VirtualFolders = nil
	_, err = server.validateUser(u, &mockFTPClientContext{}, dataprovider.LoginMethodPassword)
	assert.Error(t, err)
}

// TestFTPMode checks the mode string reported for each data channel type,
// including a nil client context and an unknown (zero) channel.
func TestFTPMode(t *testing.T) {
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolFTP, "", "", dataprovider.User{}),
	}
	assert.Empty(t, connection.getFTPMode())
	connection.clientContext = &mockFTPClientContext{lastDataChannel: ftpserver.DataChannelActive}
	assert.Equal(t, "active", connection.getFTPMode())
	connection.clientContext = &mockFTPClientContext{lastDataChannel: ftpserver.DataChannelPassive}
	assert.Equal(t, "passive", connection.getFTPMode())
	connection.clientContext = &mockFTPClientContext{lastDataChannel: 0}
	assert.Empty(t, connection.getFTPMode())
}

// TestClientVersion checks that connection stats expose the client version
// reported by the client context. It adds to and removes from the global
// common.Connections registry.
func TestClientVersion(t *testing.T) {
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("2_%v", mockCC.ID())
	user := dataprovider.User{}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
		clientContext:  mockCC,
	}
	err := common.Connections.Add(connection)
	assert.NoError(t, err)
	stats := common.Connections.GetStats("")
	if assert.Len(t, stats, 1) {
		assert.Equal(t, "mock version", stats[0].ClientVersion)
		common.Connections.Remove(connection.GetID())
	}
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestDriverMethodsNotImplemented checks that the driver methods called below
// return errNotImplemented.
func TestDriverMethodsNotImplemented(t *testing.T) {
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("2_%v", mockCC.ID())
	user := dataprovider.User{}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
		clientContext:  mockCC,
	}
	_, err := connection.Create("")
	assert.EqualError(t, err, errNotImplemented.Error())
	err = connection.MkdirAll("", os.ModePerm)
	assert.EqualError(t, err, errNotImplemented.Error())
	_, err = connection.Open("")
	assert.EqualError(t, err, errNotImplemented.Error())
	_, err = connection.OpenFile("", 0, os.ModePerm)
	assert.EqualError(t, err, errNotImplemented.Error())
	err = connection.RemoveAll("")
	assert.EqualError(t, err, errNotImplemented.Error())
	assert.Equal(t, connection.GetID(), connection.Name())
}

// TestExtraData checks that a *tlsState stored via SetExtra round-trips
// through Extra.
func TestExtraData(t *testing.T) {
	mockCC := mockFTPClientContext{}
	_, ok := mockCC.Extra().(*tlsState)
	require.False(t, ok)

	mockCC.SetExtra(&tlsState{
		LoginWithMutualTLS: false,
		Version:            tls.VersionName(tls.VersionTLS13),
		Cipher:             tls.CipherSuiteName(tls.TLS_AES_128_GCM_SHA256),
		KEX:                tls.X25519MLKEM768.String(),
	})
	state, ok := mockCC.Extra().(*tlsState)
	require.True(t, ok)
	require.False(t, state.LoginWithMutualTLS)
	require.Equal(t, tls.VersionName(tls.VersionTLS13), state.Version)
	require.Equal(t, tls.CipherSuiteName(tls.TLS_AES_128_GCM_SHA256), state.Cipher)
	require.Equal(t, tls.X25519MLKEM768.String(), state.KEX)
	mockCC.SetExtra(&tlsState{
		LoginWithMutualTLS: true,
	})
	state, ok = mockCC.Extra().(*tlsState)
	require.True(t, ok)
	require.True(t, state.LoginWithMutualTLS)
}

// TestResolvePathErrors uses a user whose home dir is "invalid" and expects
// ErrGenericFailure from path-based operations. Rename is the exception and
// must return ErrOpUnsupported.
func TestResolvePathErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			HomeDir: "invalid",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("%v", mockCC.ID())
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
		clientContext:  mockCC,
	}
	err := connection.Mkdir("", os.ModePerm)
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	err = connection.Remove("")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	err = connection.RemoveDir("")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	err = connection.Rename("", "")
	assert.ErrorIs(t, err, common.ErrOpUnsupported)
	err = connection.Symlink("", "")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	_, err = connection.Stat("")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	err = connection.Chmod("", os.ModePerm)
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	err = connection.Chtimes("", time.Now(), time.Now())
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	_, err = connection.ReadDir("")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	_, err = connection.GetHandle("", 0, 0)
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	_, err = connection.GetAvailableSpace("")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
}

// TestUploadFileStatError makes the parent directory non-readable (mode 0001)
// and expects uploadFile to fail. It is skipped on Windows.
func TestUploadFileStatError(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("this test is not available on Windows")
	}
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user",
			HomeDir:  filepath.Clean(os.TempDir()),
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("%v", mockCC.ID())
	fs := vfs.NewOsFs(connID, user.HomeDir, "", nil)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
		clientContext:  mockCC,
	}
	testFile := filepath.Join(user.HomeDir, "test", "testfile")
	err := os.MkdirAll(filepath.Dir(testFile), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(testFile, []byte("data"), os.ModePerm)
	assert.NoError(t, err)
	err = os.Chmod(filepath.Dir(testFile), 0001)
	assert.NoError(t, err)
	_, err = connection.uploadFile(fs, testFile, "test", 0)
	assert.Error(t, err)
	// restore permissions so the directory can be removed
	err = os.Chmod(filepath.Dir(testFile), os.ModePerm)
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Dir(testFile))
	assert.NoError(t, err)
}

// TestAVBLErrors checks GetAvailableSpace for an existing path and for a
// missing one, which must wrap fs.ErrNotExist.
func TestAVBLErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user",
			HomeDir:  filepath.Clean(os.TempDir()),
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("%v", mockCC.ID())
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
		clientContext:  mockCC,
	}
	_, err := connection.GetAvailableSpace("/")
	assert.NoError(t, err)
	_, err = connection.GetAvailableSpace("/missing-path")
	assert.Error(t, err)
	assert.True(t, errors.Is(err, fs.ErrNotExist))
}

// TestUploadOverwriteErrors exercises handleFTPUploadToExistingFile:
//   - an append (resume) on a fs without resume support fails with
//     ErrOpUnsupported;
//   - a truncating overwrite registers a transfer with the given initial size;
//   - opening under a missing directory, or missing paths on a real OsFs, fails.
func TestUploadOverwriteErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user",
			HomeDir:  filepath.Clean(os.TempDir()),
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("%v", mockCC.ID())
	fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
		clientContext:  mockCC,
	}
	flags := 0
	flags |= os.O_APPEND
	_, err := connection.handleFTPUploadToExistingFile(fs, flags, "", "", 0, "")
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrOpUnsupported.Error())
	}

	f, err := os.CreateTemp("", "temp")
	assert.NoError(t, err)
	err = f.Close()
	assert.NoError(t, err)
	flags = 0
	flags |= os.O_CREATE
	flags |= os.O_TRUNC
	tr, err := connection.handleFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name())
	if assert.NoError(t, err) {
		transfer := tr.(*transfer)
		transfers := connection.GetTransfers()
		if assert.Equal(t, 1, len(transfers)) {
			assert.Equal(t, transfers[0].ID, transfer.GetID())
			assert.Equal(t, int64(123), transfer.InitialSize)
			err = transfer.Close()
			assert.NoError(t, err)
			assert.Equal(t, 0, len(connection.GetTransfers()))
		}
	}
	err = os.Remove(f.Name())
	assert.NoError(t, err)

	_, err = connection.handleFTPUploadToExistingFile(fs, os.O_TRUNC, filepath.Join(os.TempDir(), "sub", "file"),
		filepath.Join(os.TempDir(), "sub", "file1"), 0, "/sub/file1")
	assert.Error(t, err)
	fs = vfs.NewOsFs(connID, user.GetHomeDir(), "", nil)
	_, err = connection.handleFTPUploadToExistingFile(fs, 0, "missing1", "missing2", 0, "missing")
	assert.Error(t, err)
}

// TestTransferErrors checks transfer behavior after Close, seeking on a pipe
// reader, and propagation of an upload error through a pipe writer.
// It creates "testfile" in the current working directory and removes it at
// the end.
func TestTransferErrors(t *testing.T) {
	testfile := "testfile"
	file, err := os.Create(testfile)
	assert.NoError(t, err)
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user",
			HomeDir:  filepath.Clean(os.TempDir()),
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	mockCC := &mockFTPClientContext{}
	connID := fmt.Sprintf("%v", mockCC.ID())
	fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
	}
	baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile,
		common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
	tr := newTransfer(baseTransfer, nil, nil, 0)
	err = tr.Close()
	assert.NoError(t, err)
	// every operation on a closed transfer must fail
	_, err = tr.Seek(10, 0)
	assert.Error(t, err)
	buf := make([]byte, 64)
	_, err = tr.Read(buf)
	assert.Error(t, err)
	err = tr.Close()
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrTransferClosed.Error())
	}
	assert.Len(t, connection.GetTransfers(), 0)

	// seeking to the expected offset on a pipe reader succeeds
	r, _, err := pipeat.Pipe()
	assert.NoError(t, err)
	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
		common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
	tr = newTransfer(baseTransfer, nil, vfs.NewPipeReader(r), 10)
	pos, err := tr.Seek(10, 0)
	assert.NoError(t, err)
	assert.Equal(t, pos, tr.expectedOffset)
	err = tr.closeIO()
	assert.NoError(t, err)

	r, w, err := pipeat.Pipe()
	assert.NoError(t, err)
	pipeWriter := vfs.NewPipeWriter(w)
	baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
		common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
	tr.Connection.RemoveTransfer(tr)
	tr = newTransfer(baseTransfer, pipeWriter, nil, 0)

	err = r.Close()
	assert.NoError(t, err)
	errFake := fmt.Errorf("fake upload error")
	// complete the pipe writer with errFake after a short delay; closeIO below
	// is expected to return that error
	go func() {
		time.Sleep(100 * time.Millisecond)
		pipeWriter.Done(errFake)
	}()
	err = tr.closeIO()
	assert.EqualError(t, err, errFake.Error())
	_, err = tr.Seek(1, 0)
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrOpUnsupported.Error())
	}
	tr.Connection.RemoveTransfer(tr)
	err = os.Remove(testfile)
	assert.NoError(t, err)
}

// TestVerifyTLSConnection writes a CRL, server certificate and key to
// os.TempDir(), builds a cert manager with the CRL loaded and parses client1's
// certificate (the rest of the test continues beyond this chunk).
func TestVerifyTLSConnection(t *testing.T) {
	oldCertMgr := certMgr

	caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt")
	certPath := filepath.Join(os.TempDir(), "test.crt")
	keyPath := filepath.Join(os.TempDir(), "test.key")
	err := os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm)
	assert.NoError(t, err)
	keyPairs := []common.TLSKeyPair{
		{
			Cert: certPath,
			Key:  keyPath,
			ID:   common.DefaultTLSKeyPaidID,
		},
	}
	certMgr, err = common.NewCertManager(keyPairs, "", "ftp_test")
	assert.NoError(t, err)
	certMgr.SetCARevocationLists([]string{caCrlPath})
	err = certMgr.LoadCRLs()
	assert.NoError(t, err)

	crt, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key))
	assert.NoError(t, err)
	x509crt, err := x509.ParseCertificate(crt.Certificate[0])
	assert.NoError(t, err)
server := Server{} state := tls.ConnectionState{ PeerCertificates: []*x509.Certificate{x509crt}, } err = server.verifyTLSConnection(state) assert.Error(t, err) // no verified certification chain err = server.VerifyTLSConnectionState(nil, state) assert.NoError(t, err) server.binding.ClientAuthType = 1 err = server.VerifyTLSConnectionState(nil, state) assert.Error(t, err) crt, err = tls.X509KeyPair([]byte(caCRT), []byte(caKey)) assert.NoError(t, err) x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crt, x509CAcrt}) err = server.verifyTLSConnection(state) assert.NoError(t, err) crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) assert.NoError(t, err) x509crtRevoked, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crtRevoked, x509CAcrt}) state.PeerCertificates = []*x509.Certificate{x509crtRevoked} err = server.verifyTLSConnection(state) assert.EqualError(t, err, common.ErrCrtRevoked.Error()) err = os.Remove(caCrlPath) assert.NoError(t, err) err = os.Remove(certPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) certMgr = oldCertMgr } func TestCiphers(t *testing.T) { b := Binding{ TLSCipherSuites: []string{}, } b.setCiphers() require.Equal(t, util.GetTLSCiphersFromNames(nil), b.ciphers) b.TLSCipherSuites = []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"} b.setCiphers() require.Len(t, b.ciphers, 2) require.Equal(t, []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384}, b.ciphers) } func TestPassiveIPResolver(t *testing.T) { b := Binding{ PassiveIPOverrides: []PassiveIPOverride{ {}, }, } err := b.checkPassiveIP() assert.Error(t, err) assert.Contains(t, err.Error(), "passive IP networks override cannot be empty") b = Binding{ PassiveIPOverrides: []PassiveIPOverride{ { IP: "invalid ip", }, }, } err = 
b.checkPassiveIP() assert.Error(t, err) assert.Contains(t, err.Error(), "is not valid") b = Binding{ PassiveIPOverrides: []PassiveIPOverride{ { IP: "192.168.1.1", Networks: []string{"192.168.1.0/24", "invalid cidr"}, }, }, } err = b.checkPassiveIP() assert.Error(t, err) assert.Contains(t, err.Error(), "invalid passive IP networks override") b = Binding{ ForcePassiveIP: "192.168.2.1", PassiveIPOverrides: []PassiveIPOverride{ { IP: "::ffff:192.168.1.1", Networks: []string{"192.168.1.0/24"}, }, }, } err = b.checkPassiveIP() assert.NoError(t, err) assert.NotEmpty(t, b.PassiveIPOverrides[0].GetNetworksAsString()) assert.Equal(t, "192.168.1.1", b.PassiveIPOverrides[0].IP) require.Len(t, b.PassiveIPOverrides[0].parsedNetworks, 1) ip := net.ParseIP("192.168.1.2") assert.True(t, b.PassiveIPOverrides[0].parsedNetworks[0](ip)) ip = net.ParseIP("192.168.0.2") assert.False(t, b.PassiveIPOverrides[0].parsedNetworks[0](ip)) mockCC := &mockFTPClientContext{ remoteIP: "192.168.1.10", localIP: "192.168.1.3", } passiveIP, err := b.passiveIPResolver(mockCC) assert.NoError(t, err) assert.Equal(t, "192.168.1.1", passiveIP) b.PassiveIPOverrides[0].IP = "" passiveIP, err = b.passiveIPResolver(mockCC) assert.NoError(t, err) assert.Equal(t, "192.168.1.3", passiveIP) mockCC.remoteIP = "172.16.2.3" passiveIP, err = b.passiveIPResolver(mockCC) assert.NoError(t, err) assert.Equal(t, b.ForcePassiveIP, passiveIP) } func TestRelativePath(t *testing.T) { rel := getPathRelativeTo("/testpath", "/testpath") assert.Empty(t, rel) rel = getPathRelativeTo("/", "/") assert.Empty(t, rel) rel = getPathRelativeTo("/", "/dir/sub") assert.Equal(t, "dir/sub", rel) rel = getPathRelativeTo("./", "/dir/sub") assert.Equal(t, "/dir/sub", rel) rel = getPathRelativeTo("/sub", "/dir/sub") assert.Equal(t, "../dir/sub", rel) rel = getPathRelativeTo("/dir", "/dir/sub") assert.Equal(t, "sub", rel) rel = getPathRelativeTo("/dir/sub", "/dir") assert.Equal(t, "../", rel) rel = getPathRelativeTo("dir", "/dir1") assert.Equal(t, 
"/dir1", rel) rel = getPathRelativeTo("", "/dir2") assert.Equal(t, "dir2", rel) rel = getPathRelativeTo(".", "/dir2") assert.Equal(t, "/dir2", rel) rel = getPathRelativeTo("/dir3", "dir3") assert.Equal(t, "dir3", rel) } func TestConfigsFromProvider(t *testing.T) { err := dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) c := Configuration{} err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) configs := dataprovider.Configs{ ACME: &dataprovider.ACMEConfigs{ Domain: "domain.com", Email: "info@domain.com", HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: 80}, Protocols: 2, }, } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) util.CertsBasePath = "" // crt and key empty err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) util.CertsBasePath = filepath.Clean(os.TempDir()) // crt not found err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs := c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) crtPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".crt") err = os.WriteFile(crtPath, nil, 0666) assert.NoError(t, err) // key not found err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) keyPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".key") err = os.WriteFile(keyPath, nil, 0666) assert.NoError(t, err) // acme cert used err = c.loadFromProvider() assert.NoError(t, err) assert.Equal(t, configs.ACME.Domain, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 1) // protocols does not match configs.ACME.Protocols = 5 err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) c.acmeDomain = "" err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) err 
= os.Remove(crtPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) util.CertsBasePath = "" err = dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) } func TestPassiveHost(t *testing.T) { b := Binding{ PassiveHost: "invalid hostname", } _, err := b.getPassiveIP(nil) assert.Error(t, err) b.PassiveHost = "localhost" ip, err := b.getPassiveIP(nil) assert.NoError(t, err, ip) assert.Equal(t, "127.0.0.1", ip) } ================================================ FILE: internal/ftpd/server.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package ftpd import ( "crypto/tls" "crypto/x509" "errors" "fmt" "net" "os" "path/filepath" "slices" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) // tlsState tracks TLS connection state for a client type tlsState struct { // LoginWithMutualTLS indicates whether the user logged in using TLS certificate authentication LoginWithMutualTLS bool Version string Cipher string KEX string } // Server implements the ftpserverlib MainDriver interface type Server struct { ID int config *Configuration initialMsg string statusBanner string binding Binding tlsConfig *tls.Config } // NewServer returns a new FTP server driver func NewServer(config *Configuration, configDir string, binding Binding, id int) *Server { binding.setCiphers() vers := version.GetServerVersion("_", false) server := &Server{ config: config, initialMsg: vers, statusBanner: fmt.Sprintf("%s FTP Server", vers), binding: binding, ID: id, } if config.BannerFile != "" { bannerFilePath := config.BannerFile if !filepath.IsAbs(bannerFilePath) { bannerFilePath = filepath.Join(configDir, bannerFilePath) } bannerContent, err := os.ReadFile(bannerFilePath) if err == nil { server.initialMsg = util.BytesToString(bannerContent) } else { logger.WarnToConsole("unable to read FTPD banner file: %v", err) logger.Warn(logSender, "", "unable to read banner file: %v", err) } } server.buildTLSConfig() return server } // GetSettings returns FTP server settings func (s *Server) GetSettings() (*ftpserver.Settings, error) { if err := s.binding.checkPassiveIP(); err != nil { return nil, err } if err := s.binding.checkSecuritySettings(); err != nil { return nil, err } var 
portRange *ftpserver.PortRange if s.config.PassivePortRange.Start > 0 && s.config.PassivePortRange.End >= s.config.PassivePortRange.Start { portRange = &ftpserver.PortRange{ Start: s.config.PassivePortRange.Start, End: s.config.PassivePortRange.End, } } var ftpListener net.Listener if s.binding.HasProxy() { listener, err := net.Listen("tcp", s.binding.GetAddress()) if err != nil { logger.Warn(logSender, "", "error starting listener on address %v: %v", s.binding.GetAddress(), err) return nil, err } ftpListener, err = common.Config.GetProxyListener(listener) if err != nil { logger.Warn(logSender, "", "error enabling proxy listener: %v", err) return nil, err } if s.binding.TLSMode == 2 && s.tlsConfig != nil { ftpListener = tls.NewListener(ftpListener, s.tlsConfig) } } if !s.binding.isTLSModeValid() { return nil, fmt.Errorf("unsupported TLS mode: %d", s.binding.TLSMode) } if s.binding.TLSMode > 0 && certMgr == nil { return nil, errors.New("to enable TLS you need to provide a certificate") } settings := &ftpserver.Settings{ Listener: ftpListener, ListenAddr: s.binding.GetAddress(), PublicIPResolver: s.binding.passiveIPResolver, ActiveTransferPortNon20: s.config.ActiveTransfersPortNon20, IdleTimeout: -1, ConnectionTimeout: 20, Banner: s.statusBanner, TLSRequired: ftpserver.TLSRequirement(s.binding.TLSMode), DisableSite: !s.config.EnableSite, DisableActiveMode: s.config.DisableActiveMode, EnableHASH: s.config.HASHSupport > 0, EnableCOMB: s.config.CombineSupport > 0, DefaultTransferType: ftpserver.TransferTypeBinary, ActiveConnectionsCheck: ftpserver.DataConnectionRequirement(s.binding.ActiveConnectionsSecurity), PasvConnectionsCheck: ftpserver.DataConnectionRequirement(s.binding.PassiveConnectionsSecurity), } if portRange != nil { settings.PassiveTransferPortRange = portRange } return settings, nil } // ClientConnected is called to send the very first welcome message func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { 
cc.SetDebug(s.binding.Debug) ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String()) common.Connections.AddClientConnection(ipAddr) if common.IsBanned(ipAddr, common.ProtocolFTP) { logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %q is banned", ipAddr) return "Access denied: banned client IP", common.ErrConnectionDenied } if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolFTP); err != nil { logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection not allowed from ip %q: %v", ipAddr, err) return "Access denied", err } _, err := common.LimitRate(common.ProtocolFTP, ipAddr) if err != nil { return fmt.Sprintf("Access denied: %v", err.Error()), err } if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil { return "Access denied", err } connID := fmt.Sprintf("%v_%v", s.ID, cc.ID()) user := dataprovider.User{} connection := &Connection{ BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, cc.LocalAddr().String(), cc.RemoteAddr().String(), user), clientContext: cc, } err = common.Connections.Add(connection) return s.initialMsg, err } // ClientDisconnected is called when the user disconnects, even if he never authenticated func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) { connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()) common.Connections.Remove(connID) common.Connections.RemoveClientConnection(util.GetIPFromRemoteAddress(cc.RemoteAddr().String())) } // AuthUser authenticates the user and selects an handling driver func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) (ftpserver.ClientDriver, error) { loginMethod := dataprovider.LoginMethodPassword tlsState, ok := cc.Extra().(*tlsState) if ok && tlsState != nil && tlsState.LoginWithMutualTLS { loginMethod = dataprovider.LoginMethodTLSCertificateAndPwd } ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String()) user, err := 
dataprovider.CheckUserAndPass(username, password, ipAddr, common.ProtocolFTP) if err != nil { user.Username = username updateLoginMetrics(&user, ipAddr, loginMethod, err, nil) return nil, dataprovider.ErrInvalidCredentials } connection, err := s.validateUser(user, cc, loginMethod) defer updateLoginMetrics(&user, ipAddr, loginMethod, err, connection) if err != nil { return nil, err } setStartDirectory(user.Filters.StartDirectory, cc) dataprovider.UpdateLastLogin(&user) return connection, nil } // PreAuthUser implements the MainDriverExtensionUserVerifier interface func (s *Server) PreAuthUser(cc ftpserver.ClientContext, username string) error { if s.binding.TLSMode == 0 && s.tlsConfig != nil { user, err := dataprovider.GetFTPPreAuthUser(username, util.GetIPFromRemoteAddress(cc.RemoteAddr().String())) if err == nil { if user.Filters.FTPSecurity == 1 { return cc.SetTLSRequirement(ftpserver.MandatoryEncryption) } return nil } if !errors.Is(err, util.ErrNotFound) { logger.Error(logSender, fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()), "unable to get user on pre auth: %v", err) return common.ErrInternalFailure } } return nil } // WrapPassiveListener implements the MainDriverExtensionPassiveWrapper interface func (s *Server) WrapPassiveListener(listener net.Listener) (net.Listener, error) { if s.binding.HasProxy() { return common.Config.GetProxyListener(listener) } return listener, nil } // VerifyConnection checks whether a user should be authenticated using a client certificate without prompting for a password func (s *Server) VerifyConnection(cc ftpserver.ClientContext, user string, tlsConn *tls.Conn) (ftpserver.ClientDriver, error) { if tlsConn == nil { return nil, nil } state := tlsConn.ConnectionState() cc.SetExtra(&tlsState{ LoginWithMutualTLS: false, Cipher: tls.CipherSuiteName(state.CipherSuite), Version: tls.VersionName(state.Version), KEX: state.CurveID.String(), }) if !s.binding.isMutualTLSEnabled() { return nil, nil } if 
len(state.PeerCertificates) > 0 { ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String()) dbUser, err := dataprovider.CheckUserBeforeTLSAuth(user, ipAddr, common.ProtocolFTP, state.PeerCertificates[0]) if err != nil { dbUser.Username = user updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, nil) return nil, dataprovider.ErrInvalidCredentials } if dbUser.IsTLSVerificationEnabled() { dbUser, err = dataprovider.CheckUserAndTLSCert(user, ipAddr, common.ProtocolFTP, state.PeerCertificates[0]) if err != nil { return nil, err } cc.SetExtra(&tlsState{ LoginWithMutualTLS: true, Cipher: tls.CipherSuiteName(state.CipherSuite), Version: tls.VersionName(state.Version), KEX: state.CurveID.String(), }) if dbUser.IsLoginMethodAllowed(dataprovider.LoginMethodTLSCertificate, common.ProtocolFTP) { connection, err := s.validateUser(dbUser, cc, dataprovider.LoginMethodTLSCertificate) defer updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, connection) if err != nil { return nil, err } setStartDirectory(dbUser.Filters.StartDirectory, cc) dataprovider.UpdateLastLogin(&dbUser) return connection, nil } } } return nil, nil } func (s *Server) buildTLSConfig() { if certMgr != nil { certID := common.DefaultTLSKeyPaidID if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" { certID = s.binding.GetAddress() } if !certMgr.HasCertificate(certID) { return } s.tlsConfig = &tls.Config{ GetCertificate: certMgr.GetCertificateFunc(certID), MinVersion: util.GetTLSVersion(s.binding.MinTLSVersion), CipherSuites: s.binding.ciphers, } logger.Debug(logSender, "", "configured TLS cipher suites for binding %q: %v, certID: %v", s.binding.GetAddress(), s.binding.ciphers, certID) if s.binding.isMutualTLSEnabled() { s.tlsConfig.ClientCAs = certMgr.GetRootCAs() s.tlsConfig.VerifyConnection = s.verifyTLSConnection switch s.binding.ClientAuthType { case 1: s.tlsConfig.ClientAuth = 
tls.RequireAndVerifyClientCert case 2: s.tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven } } } } // GetTLSConfig returns the TLS configuration for this server func (s *Server) GetTLSConfig() (*tls.Config, error) { if s.tlsConfig != nil { return s.tlsConfig, nil } return nil, errors.New("no TLS certificate configured") } // VerifyTLSConnectionState implements the MainDriverExtensionTLSConnectionStateVerifier extension func (s *Server) VerifyTLSConnectionState(_ ftpserver.ClientContext, cs tls.ConnectionState) error { if !s.binding.isMutualTLSEnabled() { return nil } return s.verifyTLSConnection(cs) } func (s *Server) verifyTLSConnection(state tls.ConnectionState) error { if certMgr != nil { var clientCrt *x509.Certificate var clientCrtName string if len(state.PeerCertificates) > 0 { clientCrt = state.PeerCertificates[0] clientCrtName = clientCrt.Subject.String() } if len(state.VerifiedChains) == 0 { if s.binding.ClientAuthType == 2 { return nil } logger.Warn(logSender, "", "TLS connection cannot be verified: unable to get verification chain") return errors.New("TLS connection cannot be verified: unable to get verification chain") } for _, verifiedChain := range state.VerifiedChains { var caCrt *x509.Certificate if len(verifiedChain) > 0 { caCrt = verifiedChain[len(verifiedChain)-1] } if certMgr.IsRevoked(clientCrt, caCrt) { logger.Debug(logSender, "", "tls handshake error, client certificate %q has beed revoked", clientCrtName) return common.ErrCrtRevoked } } } return nil } func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext, loginMethod string) (*Connection, error) { connectionID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()) if !filepath.IsAbs(user.HomeDir) { logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. 
Home dir must be an absolute path, login not allowed", user.Username, user.HomeDir) return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir) } if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) { logger.Info(logSender, connectionID, "cannot login user %q, protocol FTP is not allowed", user.Username) return nil, fmt.Errorf("protocol FTP is not allowed for user %q", user.Username) } if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolFTP) { logger.Info(logSender, connectionID, "cannot login user %q, %v login method is not allowed", user.Username, loginMethod) return nil, fmt.Errorf("login method %v is not allowed for user %q", loginMethod, user.Username) } if user.MustSetSecondFactorForProtocol(common.ProtocolFTP) { logger.Info(logSender, connectionID, "cannot login user %q, second factor authentication is not set", user.Username) return nil, fmt.Errorf("second factor authentication is not set for user %q", user.Username) } if user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) if activeSessions >= user.MaxSessions { logger.Info(logSender, connectionID, "authentication refused for user: %q, too many open sessions: %v/%v", user.Username, activeSessions, user.MaxSessions) return nil, fmt.Errorf("too many open sessions: %v", activeSessions) } } remoteAddr := cc.RemoteAddr().String() if !user.IsLoginFromAddrAllowed(remoteAddr) { logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", user.Username, remoteAddr) return nil, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, remoteAddr) } err := user.CheckFsRoot(connectionID) if err != nil { errClose := user.CloseFs() logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) return nil, common.ErrInternalFailure } connection := &Connection{ BaseConnection: common.NewBaseConnection(fmt.Sprintf("%v_%v", s.ID, cc.ID()), 
common.ProtocolFTP, cc.LocalAddr().String(), remoteAddr, user), clientContext: cc, } err = common.Connections.Swap(connection) if err != nil { errClose := user.CloseFs() logger.Warn(logSender, connectionID, "unable to swap connection: %v, close fs error: %v", err, errClose) return nil, err } return connection, nil } func setStartDirectory(startDirectory string, cc ftpserver.ClientContext) { if startDirectory == "" { return } cc.SetPath(startDirectory) } func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error, c *Connection) { metric.AddLoginAttempt(loginMethod) if err == nil { info := "" if tlsState, ok := c.clientContext.Extra().(*tlsState); ok && tlsState != nil { info = fmt.Sprintf("%s - %s - %s", tlsState.Version, tlsState.Cipher, tlsState.KEX) } logger.LoginLog(user.Username, ip, loginMethod, common.ProtocolFTP, c.ID, c.GetClientVersion(), c.clientContext.HasTLSForControl(), info) plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolFTP, user.Username, ip, "", nil) common.DelayLogin(nil) } else if err != common.ErrInternalFailure { logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolFTP, err.Error()) event := common.HostEventLoginFailed logEv := notifier.LogEventTypeLoginFailed if errors.Is(err, util.ErrNotFound) { event = common.HostEventUserNotFound logEv = notifier.LogEventTypeLoginNoUser } common.AddDefenderEvent(ip, common.ProtocolFTP, event) plugin.Handler.NotifyLogEvent(logEv, common.ProtocolFTP, user.Username, ip, "", err) if loginMethod != dataprovider.LoginMethodTLSCertificate { common.DelayLogin(err) } } metric.AddLoginResult(loginMethod, err) dataprovider.ExecutePostLoginHook(user, loginMethod, ip, common.ProtocolFTP, err) } ================================================ FILE: internal/ftpd/transfer.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the 
terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package ftpd import ( "errors" "io" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // transfer contains the transfer details for an upload or a download. // It implements the ftpserver.FileTransfer interface to handle files downloads and uploads type transfer struct { *common.BaseTransfer writer io.WriteCloser reader io.ReadCloser isFinished bool expectedOffset int64 } func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader, expectedOffset int64) *transfer { var writer io.WriteCloser var reader io.ReadCloser if baseTransfer.File != nil { writer = baseTransfer.File reader = baseTransfer.File } else if pipeWriter != nil { writer = pipeWriter } else if pipeReader != nil { reader = pipeReader } return &transfer{ BaseTransfer: baseTransfer, writer: writer, reader: reader, isFinished: false, expectedOffset: expectedOffset, } } // Read reads the contents to downloads. func (t *transfer) Read(p []byte) (n int, err error) { t.Connection.UpdateLastActivity() n, err = t.reader.Read(p) t.BytesSent.Add(int64(n)) if err == nil { err = t.CheckRead() } if err != nil && err != io.EOF { t.TransferError(err) err = t.ConvertError(err) return } t.HandleThrottle() return } // Write writes the uploaded contents. 
func (t *transfer) Write(p []byte) (n int, err error) { t.Connection.UpdateLastActivity() n, err = t.writer.Write(p) t.BytesReceived.Add(int64(n)) if err == nil { err = t.CheckWrite() } if err != nil { t.TransferError(err) err = t.ConvertError(err) return } t.HandleThrottle() return } // Seek sets the offset to resume an upload or a download func (t *transfer) Seek(offset int64, whence int) (int64, error) { t.Connection.UpdateLastActivity() if t.File != nil { ret, err := t.File.Seek(offset, whence) if err != nil { t.TransferError(err) } return ret, err } if (t.reader != nil || t.writer != nil) && t.expectedOffset == offset && whence == io.SeekStart { return offset, nil } t.TransferError(errors.New("seek is unsupported for this transfer")) return 0, common.ErrOpUnsupported } // Close it is called when the transfer is completed. func (t *transfer) Close() error { if err := t.setFinished(); err != nil { return err } err := t.closeIO() errBaseClose := t.BaseTransfer.Close() if errBaseClose != nil { err = errBaseClose } return t.Connection.GetFsError(t.Fs, err) } func (t *transfer) closeIO() error { var err error if t.File != nil { err = t.File.Close() } else if t.writer != nil { err = t.writer.Close() t.Lock() // we set ErrTransfer here so quota is not updated, in this case the uploads are atomic if err != nil && t.ErrTransfer == nil { t.ErrTransfer = err } t.Unlock() } else if t.reader != nil { err = t.reader.Close() if metadater, ok := t.reader.(vfs.Metadater); ok { t.SetMetadata(metadater.Metadata()) } } return err } func (t *transfer) setFinished() error { t.Lock() defer t.Unlock() if t.isFinished { return common.ErrTransferClosed } t.isFinished = true return nil } ================================================ FILE: internal/httpclient/httpclient.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public 
License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package httpclient provides HTTP client configuration for SFTPGo hooks package httpclient import ( "crypto/tls" "crypto/x509" "fmt" "io" "net/http" "os" "path/filepath" "strings" "time" "github.com/hashicorp/go-retryablehttp" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // TLSKeyPair defines the paths for a TLS key pair type TLSKeyPair struct { Cert string `json:"cert" mapstructure:"cert"` Key string `json:"key" mapstructure:"key"` } // Header defines an HTTP header. // If the URL is not empty, the header is added only if the // requested URL starts with the one specified type Header struct { Key string `json:"key" mapstructure:"key"` Value string `json:"value" mapstructure:"value"` URL string `json:"url" mapstructure:"url"` } // Config defines the configuration for HTTP clients. 
// HTTP clients are used for executing hooks such as the ones used for // custom actions, external authentication and pre-login user modifications type Config struct { // Timeout specifies a time limit, in seconds, for a request Timeout float64 `json:"timeout" mapstructure:"timeout"` // RetryWaitMin defines the minimum waiting time between attempts in seconds RetryWaitMin int `json:"retry_wait_min" mapstructure:"retry_wait_min"` // RetryWaitMax defines the minimum waiting time between attempts in seconds RetryWaitMax int `json:"retry_wait_max" mapstructure:"retry_wait_max"` // RetryMax defines the maximum number of attempts RetryMax int `json:"retry_max" mapstructure:"retry_max"` // CACertificates defines extra CA certificates to trust. // The paths can be absolute or relative to the config dir. // Adding trusted CA certificates is a convenient way to use self-signed // certificates without defeating the purpose of using TLS CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` // Certificates defines the certificates to use for mutual TLS Certificates []TLSKeyPair `json:"certificates" mapstructure:"certificates"` // if enabled the HTTP client accepts any TLS certificate presented by // the server and any host name in that certificate. // In this mode, TLS is susceptible to man-in-the-middle attacks. // This should be used only for testing. 
SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"` // Headers defines a list of http headers to add to each request Headers []Header `json:"headers" mapstructure:"headers"` customTransport *http.Transport } const logSender = "httpclient" var httpConfig Config // Initialize configures HTTP clients func (c *Config) Initialize(configDir string) error { if c.Timeout <= 0 { return fmt.Errorf("invalid timeout: %v", c.Timeout) } rootCAs, err := c.loadCACerts(configDir) if err != nil { return err } customTransport := http.DefaultTransport.(*http.Transport).Clone() if customTransport.TLSClientConfig != nil { customTransport.TLSClientConfig.RootCAs = rootCAs } else { customTransport.TLSClientConfig = &tls.Config{ RootCAs: rootCAs, } } customTransport.TLSClientConfig.InsecureSkipVerify = c.SkipTLSVerify c.customTransport = customTransport err = c.loadCertificates(configDir) if err != nil { return err } var headers []Header for _, h := range c.Headers { if h.Key != "" && h.Value != "" { headers = append(headers, h) } } c.Headers = headers httpConfig = *c return nil } // loadCACerts returns system cert pools and try to add the configured // CA certificates to it func (c *Config) loadCACerts(configDir string) (*x509.CertPool, error) { if len(c.CACertificates) == 0 { return nil, nil } rootCAs, err := x509.SystemCertPool() if err != nil { rootCAs = x509.NewCertPool() } for _, ca := range c.CACertificates { if !util.IsFileInputValid(ca) { return nil, fmt.Errorf("unable to load invalid CA certificate: %q", ca) } if !filepath.IsAbs(ca) { ca = filepath.Join(configDir, ca) } certs, err := os.ReadFile(ca) if err != nil { return nil, fmt.Errorf("unable to load CA certificate: %v", err) } if rootCAs.AppendCertsFromPEM(certs) { logger.Debug(logSender, "", "CA certificate %q added to the trusted certificates", ca) } else { return nil, fmt.Errorf("unable to add CA certificate %q to the trusted cetificates", ca) } } return rootCAs, nil } func (c *Config) 
loadCertificates(configDir string) error { if len(c.Certificates) == 0 { return nil } for _, keyPair := range c.Certificates { cert := keyPair.Cert key := keyPair.Key if !util.IsFileInputValid(cert) { return fmt.Errorf("unable to load invalid certificate: %q", cert) } if !util.IsFileInputValid(key) { return fmt.Errorf("unable to load invalid key: %q", key) } if !filepath.IsAbs(cert) { cert = filepath.Join(configDir, cert) } if !filepath.IsAbs(key) { key = filepath.Join(configDir, key) } tlsCert, err := tls.LoadX509KeyPair(cert, key) if err != nil { return fmt.Errorf("unable to load key pair %q, %q: %v", cert, key, err) } x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) if err == nil { logger.Debug(logSender, "", "adding leaf certificate for key pair %q, %q", cert, key) tlsCert.Leaf = x509Cert } logger.Debug(logSender, "", "client certificate %q and key %q successfully loaded", cert, key) c.customTransport.TLSClientConfig.Certificates = append(c.customTransport.TLSClientConfig.Certificates, tlsCert) } return nil } // GetHTTPClient returns a new HTTP client with the configured parameters func GetHTTPClient() *http.Client { return &http.Client{ Timeout: time.Duration(httpConfig.Timeout * float64(time.Second)), Transport: httpConfig.customTransport, } } // GetRetraybleHTTPClient returns an HTTP client that retry a request on error. 
// It uses the configured retry parameters func GetRetraybleHTTPClient() *retryablehttp.Client { client := retryablehttp.NewClient() client.HTTPClient.Timeout = time.Duration(httpConfig.Timeout * float64(time.Second)) client.HTTPClient.Transport.(*http.Transport).TLSClientConfig = httpConfig.customTransport.TLSClientConfig client.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"} client.RetryWaitMin = time.Duration(httpConfig.RetryWaitMin) * time.Second client.RetryWaitMax = time.Duration(httpConfig.RetryWaitMax) * time.Second client.RetryMax = httpConfig.RetryMax return client } // Get issues a GET to the specified URL func Get(url string) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err } addHeaders(req, url) client := GetHTTPClient() defer client.CloseIdleConnections() return client.Do(req) } // Post issues a POST to the specified URL func Post(url string, contentType string, body io.Reader) (*http.Response, error) { req, err := http.NewRequest(http.MethodPost, url, body) if err != nil { return nil, err } req.Header.Set("Content-Type", contentType) addHeaders(req, url) client := GetHTTPClient() defer client.CloseIdleConnections() return client.Do(req) } // RetryableGet issues a GET to the specified URL using the retryable client func RetryableGet(url string) (*http.Response, error) { req, err := retryablehttp.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err } addHeadersToRetryableReq(req, url) client := GetRetraybleHTTPClient() defer client.HTTPClient.CloseIdleConnections() return client.Do(req) } // RetryablePost issues a POST to the specified URL using the retryable client func RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) { req, err := retryablehttp.NewRequest(http.MethodPost, url, body) if err != nil { return nil, err } req.Header.Set("Content-Type", contentType) addHeadersToRetryableReq(req, url) client := 
GetRetraybleHTTPClient() defer client.HTTPClient.CloseIdleConnections() return client.Do(req) } func addHeaders(req *http.Request, url string) { for idx := range httpConfig.Headers { h := &httpConfig.Headers[idx] if h.URL == "" || strings.HasPrefix(url, h.URL) { req.Header.Set(h.Key, h.Value) } } } func addHeadersToRetryableReq(req *retryablehttp.Request, url string) { for idx := range httpConfig.Headers { h := &httpConfig.Headers[idx] if h.URL == "" || strings.HasPrefix(url, h.URL) { req.Header.Set(h.Key, h.Value) } } } ================================================ FILE: internal/httpd/api_admin.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getAdmins writes the admins, paginated and ordered according to the query
// parameters, as JSON.
func getAdmins(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		// NOTE(review): nothing is written here; getSearchFilters receives w
		// and presumably sends the error response itself
		return
	}

	admins, err := dataprovider.GetAdmins(limit, offset, order)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, admins)
}

// getAdminByUsername writes the admin identified by the "username" URL
// parameter as JSON.
func getAdminByUsername(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	renderAdmin(w, r, username, http.StatusOK)
}

// renderAdmin loads the admin with the given username, hides its confidential
// data and writes it as JSON. A status other than http.StatusOK is used as the
// response status code (e.g. http.StatusCreated after an add).
func renderAdmin(w http.ResponseWriter, r *http.Request, username string, status int) {
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	admin.HideConfidentialData()

	if status != http.StatusOK {
		// use the requested status rather than a hard-coded 201, as the
		// event action/rule renderers do
		ctx := context.WithValue(r.Context(), render.StatusCtxKey, status)
		render.JSON(w, r.WithContext(ctx), admin)
	} else {
		render.JSON(w, r, admin)
	}
}

// addAdmin creates the admin decoded from the request body on behalf of the
// admin in the token claims, sets the Location header and renders the new
// admin with status 201.
func addAdmin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin := dataprovider.Admin{}
	err = render.DecodeJSON(r.Body, &admin)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	err = dataprovider.AddAdmin(&admin, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", adminPath, url.PathEscape(admin.Username)))
	renderAdmin(w, r, admin.Username, http.StatusCreated)
}

// disableAdmin2FA removes the TOTP configuration and the recovery codes of the
// admin identified by the "username" URL parameter. An admin cannot disable
// its own 2FA when two-factor authentication is required for it.
func disableAdmin2FA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(getURLParam(r, "username"))
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if !admin.Filters.TOTPConfig.Enabled {
		sendAPIResponse(w, r, nil, "two-factor authentication is not enabled", http.StatusBadRequest)
		return
	}
	if admin.Username == claims.Username {
		if admin.Filters.RequireTwoFactor {
			err := util.NewValidationError("two-factor authentication must be enabled")
			sendAPIResponse(w, r, err, "", getRespStatus(err))
			return
		}
	}
	admin.Filters.RecoveryCodes = nil
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled: false,
	}
	if err := dataprovider.UpdateAdmin(&admin, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "2FA disabled", http.StatusOK)
}

// updateAdmin replaces the admin identified by the "username" URL parameter
// with the one decoded from the request body. ID, username, 2FA settings and,
// when the new password is empty, the password are kept from the stored admin.
// An admin updating itself cannot use an API key, change its permissions or
// role, disable itself or alter its password-change/2FA requirements.
func updateAdmin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	username := getURLParam(r, "username")
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedAdmin dataprovider.Admin
	err = render.DecodeJSON(r.Body, &updatedAdmin)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	if username == claims.Username {
		if claims.APIKeyID != "" {
			sendAPIResponse(w, r, errors.New("updating the admin impersonated with an API key is not allowed"), "",
				http.StatusBadRequest)
			return
		}
		if !util.SlicesEqual(admin.Permissions, updatedAdmin.Permissions) {
			sendAPIResponse(w, r, errors.New("you cannot change your permissions"), "", http.StatusBadRequest)
			return
		}
		if updatedAdmin.Status == 0 {
			sendAPIResponse(w, r, errors.New("you cannot disable yourself"), "", http.StatusBadRequest)
			return
		}
		if updatedAdmin.Role != claims.Role {
			sendAPIResponse(w, r, errors.New("you cannot add/change your role"), "", http.StatusBadRequest)
			return
		}
		updatedAdmin.Filters.RequirePasswordChange = admin.Filters.RequirePasswordChange
		updatedAdmin.Filters.RequireTwoFactor = admin.Filters.RequireTwoFactor
	}
	updatedAdmin.ID = admin.ID
	updatedAdmin.Username = admin.Username
	if updatedAdmin.Password == "" {
		updatedAdmin.Password = admin.Password
	}
	updatedAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig
	updatedAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes
	err = dataprovider.UpdateAdmin(&updatedAdmin, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Admin updated", http.StatusOK)
}

// deleteAdmin deletes the admin identified by the "username" URL parameter.
// Admins cannot delete themselves.
func deleteAdmin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	if username == claims.Username {
		sendAPIResponse(w, r, errors.New("you cannot delete yourself"), "", http.StatusBadRequest)
		return
	}

	err = dataprovider.DeleteAdmin(username, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Admin deleted", http.StatusOK)
}

// getAdminProfile writes the editable profile fields of the logged-in admin
// as JSON.
func getAdminProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	resp := adminProfile{
		baseProfile: baseProfile{
			Email:           admin.Email,
			Description:     admin.Description,
			AllowAPIKeyAuth: admin.Filters.AllowAPIKeyAuth,
		},
	}

	render.JSON(w, r, resp)
}

// updateAdminProfile updates email, description and API key authentication
// setting of the logged-in admin from the request body.
func updateAdminProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	var req adminProfile
	err = render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	admin.Email = req.Email
	admin.Description = req.Description
	admin.Filters.AllowAPIKeyAuth = req.AllowAPIKeyAuth
	if err := dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Profile updated", http.StatusOK)
}

// forgotAdminPassword starts the password reset procedure for the admin
// identified by the "username" URL parameter. It requires SMTP to be enabled.
func forgotAdminPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	if !smtp.IsEnabled() {
		sendAPIResponse(w, r, nil, "No SMTP configuration", http.StatusBadRequest)
		return
	}

	err := handleForgotPassword(r, getURLParam(r, "username"), true)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	sendAPIResponse(w, r, err, "Check your email for the confirmation code", http.StatusOK)
}

// resetAdminPassword completes the password reset using the confirmation
// code and the new password from the request body.
func resetAdminPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req pwdReset
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	_, _, err = handleResetPassword(r, req.Code, req.Password, req.Password, true)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Password reset successful", http.StatusOK)
}

// changeAdminPassword changes the password of the logged-in admin and, on
// success, invalidates the token used for the request.
func changeAdminPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var pwd pwdChange
	err := render.DecodeJSON(r.Body, &pwd)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	err = doChangeAdminPassword(r, pwd.CurrentPassword, pwd.NewPassword, pwd.NewPassword)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	invalidateToken(r)
	sendAPIResponse(w, r, err, "Password updated", http.StatusOK)
}

// doChangeAdminPassword validates the password change request, verifies the
// current password of the admin in the token claims and stores the new one,
// clearing the RequirePasswordChange flag.
func doChangeAdminPassword(r *http.Request, currentPassword, newPassword, confirmNewPassword string) error {
	if currentPassword == "" || newPassword == "" || confirmNewPassword == "" {
		return util.NewI18nError(
			util.NewValidationError("please provide the current password and the new one two times"),
			util.I18nErrorChangePwdRequiredFields,
		)
	}
	if newPassword != confirmNewPassword {
		return util.NewI18nError(util.NewValidationError("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)
	}
	if currentPassword == newPassword {
		return util.NewI18nError(
			util.NewValidationError("the new password must be different from the current one"),
			util.I18nErrorChangePwdNoDifferent,
		)
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		return err
	}
	match, err := admin.CheckPassword(currentPassword)
	if !match || err != nil {
		return util.NewI18nError(util.NewValidationError("current password does not match"), util.I18nErrorChangePwdCurrentNoMatch)
	}

	admin.Password = newPassword
	admin.Filters.RequirePasswordChange = false

	return dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
}

================================================
FILE: internal/httpd/api_configs.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package httpd

import (
	"net/http"

	"github.com/go-chi/render"
	"github.com/rs/xid"
	"golang.org/x/oauth2"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// smtpTestRequest is an SMTP configuration to test plus the recipient of the
// test email.
type smtpTestRequest struct {
	smtp.Config
	Recipient string `json:"recipient"`
}

// hasRedactedSecret reports whether any secret in the request is the redacted
// placeholder, in which case the stored value must be used instead.
func (r *smtpTestRequest) hasRedactedSecret() bool {
	return r.Password == redactedSecret || r.OAuth2.ClientSecret == redactedSecret || r.OAuth2.RefreshToken == redactedSecret
}

// testSMTPConfig sends a test email using the SMTP configuration from the
// request body. Redacted secrets are replaced with the stored ones when the
// stored configuration can be decrypted; otherwise they are left unchanged.
func testSMTPConfig(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req smtpTestRequest
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if req.hasRedactedSecret() {
		configs, err := dataprovider.GetConfigs()
		if err != nil {
			sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
			return
		}
		configs.SetNilsToEmpty()
		if err := configs.SMTP.TryDecrypt(); err == nil {
			if req.Password == redactedSecret {
				req.Password = configs.SMTP.Password.GetPayload()
			}
			if req.OAuth2.ClientSecret == redactedSecret {
				req.OAuth2.ClientSecret = configs.SMTP.OAuth2.ClientSecret.GetPayload()
			}
			if req.OAuth2.RefreshToken == redactedSecret {
				req.OAuth2.RefreshToken = configs.SMTP.OAuth2.RefreshToken.GetPayload()
			}
		}
	}
	// NOTE(review): AuthType 3 presumably identifies OAuth2 authentication,
	// given that the OAuth2 configuration is validated; confirm in package smtp
	if req.AuthType == 3 {
		if err := req.OAuth2.Validate(); err != nil {
			sendAPIResponse(w, r, err, "", http.StatusBadRequest)
			return
		}
	}
	if err := req.SendEmail([]string{req.Recipient}, nil, "SFTPGo - Testing Email Settings",
		"It appears your SFTPGo email is setup correctly!", smtp.EmailContentTypeTextPlain); err != nil {
		logger.Info(logSender, "", "unable to send test email: %v", err)
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
		return
	}
	sendAPIResponse(w, r, nil, "SMTP connection OK", http.StatusOK)
}

// oauth2TokenRequest is the OAuth2 configuration used to request an SMTP
// token plus the base URL the provider redirects back to.
type oauth2TokenRequest struct {
	smtp.OAuth2Config
	BaseRedirectURL string `json:"base_redirect_url"`
}

// handleSMTPOAuth2TokenRequestPost starts an OAuth2 authorization code flow
// for SMTP. It registers a pending auth (using a PKCE S256 challenge derived
// from its verifier), creates a state token with the client IP and returns
// the authorization URL as the response message. A redacted client secret is
// replaced with the stored one when it can be decrypted.
func (s *httpdServer) handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req oauth2TokenRequest
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if req.BaseRedirectURL == "" {
		sendAPIResponse(w, r, nil, "base redirect url is required", http.StatusBadRequest)
		return
	}
	if req.ClientSecret == redactedSecret {
		configs, err := dataprovider.GetConfigs()
		if err != nil {
			sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
			return
		}
		configs.SetNilsToEmpty()
		if err := configs.SMTP.TryDecrypt(); err == nil {
			req.ClientSecret = configs.SMTP.OAuth2.ClientSecret.GetPayload()
		}
	}
	cfg := req.GetOAuth2()
	cfg.RedirectURL = req.BaseRedirectURL + webOAuth2RedirectPath
	clientSecret := kms.NewPlainSecret(cfg.ClientSecret)
	clientSecret.SetAdditionalData(xid.New().String())
	pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret)
	oauth2Mgr.addPendingAuth(pendingAuth)
	stateToken := createOAuth2Token(s.csrfTokenAuth, pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
	if stateToken == "" {
		sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError)
		return
	}
	u := cfg.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(pendingAuth.Verifier))
	sendAPIResponse(w, r, nil, u, http.StatusOK)
}

================================================
FILE: internal/httpd/api_defender.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd

import (
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"net/http"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getDefenderHosts writes the hosts tracked by the defender as JSON, always
// as an array: an empty one is sent when there are no hosts.
func getDefenderHosts(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	hosts, err := common.GetDefenderHosts()
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hosts != nil {
		render.JSON(w, r, hosts)
		return
	}
	render.JSON(w, r, make([]dataprovider.DefenderEntry, 0))
}

// getDefenderHostByID writes the defender entry for the host whose IP is
// hex-encoded in the "id" URL parameter.
func getDefenderHostByID(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	hostIP, err := getIPFromID(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	entry, err := common.GetDefenderHost(hostIP)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, entry)
}

// deleteDefenderHostByID removes the host whose IP is hex-encoded in the "id"
// URL parameter from the defender, answering 404 if it is not tracked.
func deleteDefenderHostByID(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	hostIP, err := getIPFromID(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if removed := common.DeleteDefenderHost(hostIP); !removed {
		sendAPIResponse(w, r, nil, "Not found", http.StatusNotFound)
		return
	}
	sendAPIResponse(w, r, nil, "OK", http.StatusOK)
}

// getIPFromID hex-decodes the "id" URL parameter and returns it as an IP
// address string, failing if it is not valid hex or not a valid IP.
func getIPFromID(r *http.Request) (string, error) {
	rawID, err := hex.DecodeString(getURLParam(r, "id"))
	if err != nil {
		return "", errors.New("invalid host id")
	}
	hostIP := util.BytesToString(rawID)
	if err := validateIPAddress(hostIP); err != nil {
		return "", err
	}
	return hostIP, nil
}

// validateIPAddress returns an error if ip cannot be parsed as an IP address.
func validateIPAddress(ip string) error {
	if parsed := net.ParseIP(ip); parsed != nil {
		return nil
	}
	return fmt.Errorf("ip address %q is not valid", ip)
}

================================================
FILE: internal/httpd/api_eventrule.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package httpd

import (
	"context"
	"fmt"
	"net/http"
	"net/url"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getEventActions writes the event actions, paginated and ordered according
// to the query parameters, as JSON.
func getEventActions(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}
	actions, err := dataprovider.GetEventActions(limit, offset, order, false)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
		return
	}
	render.JSON(w, r, actions)
}

// renderEventAction loads the named event action and writes it as JSON,
// hiding confidential data when hideConfidentialData says so. A status other
// than http.StatusOK is used as the response status code.
func renderEventAction(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
	action, err := dataprovider.EventActionExists(name)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hideConfidentialData(claims, r) {
		action.PrepareForRendering()
	}
	if status == http.StatusOK {
		render.JSON(w, r, action)
		return
	}
	statusCtx := context.WithValue(r.Context(), render.StatusCtxKey, status)
	render.JSON(w, r.WithContext(statusCtx), action)
}

// getEventActionByName writes the event action named by the "name" URL
// parameter as JSON.
func getEventActionByName(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	renderEventAction(w, r, getURLParam(r, "name"), claims, http.StatusOK)
}

// addEventAction creates the event action decoded from the request body,
// sets the Location header and renders the new action with status 201.
func addEventAction(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	var newAction dataprovider.BaseEventAction
	if err = render.DecodeJSON(r.Body, &newAction); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.AddEventAction(&newAction, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", eventActionsPath, url.PathEscape(newAction.Name)))
	renderEventAction(w, r, newAction.Name, claims, http.StatusCreated)
}

// updateEventAction replaces the event action named by the "name" URL
// parameter with the one decoded from the request body, keeping the stored
// ID and name. For HTTP actions a password that is neither plain nor empty
// is replaced by the stored one.
func updateEventAction(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	actionName := getURLParam(r, "name")
	existing, err := dataprovider.EventActionExists(actionName)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedAction dataprovider.BaseEventAction
	if err = render.DecodeJSON(r.Body, &updatedAction); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	updatedAction.ID = existing.ID
	updatedAction.Name = existing.Name
	updatedAction.Options.SetEmptySecretsIfNil()
	if updatedAction.Type == dataprovider.ActionTypeHTTP &&
		updatedAction.Options.HTTPConfig.Password.IsNotPlainAndNotEmpty() {
		updatedAction.Options.HTTPConfig.Password = existing.Options.HTTPConfig.Password
	}

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.UpdateEventAction(&updatedAction, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event action updated", http.StatusOK)
}

// deleteEventAction deletes the event action named by the "name" URL parameter.
func deleteEventAction(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	actionName := getURLParam(r, "name")
	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.DeleteEventAction(actionName, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event action deleted", http.StatusOK)
}

// getEventRules writes the event rules, paginated and ordered according to
// the query parameters, as JSON.
func getEventRules(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}
	rules, err := dataprovider.GetEventRules(limit, offset, order)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
		return
	}
	render.JSON(w, r, rules)
}

// renderEventRule loads the named event rule and writes it as JSON, hiding
// confidential data when hideConfidentialData says so. A status other than
// http.StatusOK is used as the response status code.
func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
	rule, err := dataprovider.EventRuleExists(name)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hideConfidentialData(claims, r) {
		rule.PrepareForRendering()
	}
	if status == http.StatusOK {
		render.JSON(w, r, rule)
		return
	}
	statusCtx := context.WithValue(r.Context(), render.StatusCtxKey, status)
	render.JSON(w, r.WithContext(statusCtx), rule)
}

// getEventRuleByName writes the event rule named by the "name" URL parameter
// as JSON.
func getEventRuleByName(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	renderEventRule(w, r, getURLParam(r, "name"), claims, http.StatusOK)
}

// addEventRule creates the event rule decoded from the request body, sets the
// Location header and renders the new rule with status 201.
func addEventRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	var newRule dataprovider.EventRule
	if err = render.DecodeJSON(r.Body, &newRule); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.AddEventRule(&newRule, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", eventRulesPath, url.PathEscape(newRule.Name)))
	renderEventRule(w, r, newRule.Name, claims, http.StatusCreated)
}

// updateEventRule replaces the event rule named by the "name" URL parameter
// with the one decoded from the request body, keeping the stored ID and name.
func updateEventRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	ruleName := getURLParam(r, "name")
	existing, err := dataprovider.EventRuleExists(ruleName)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedRule dataprovider.EventRule
	if err = render.DecodeJSON(r.Body, &updatedRule); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	updatedRule.ID = existing.ID
	updatedRule.Name = existing.Name

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.UpdateEventRule(&updatedRule, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event rules updated", http.StatusOK)
}

// deleteEventRule deletes the event rule named by the "name" URL parameter.
func deleteEventRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	ruleName := getURLParam(r, "name")
	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.DeleteEventRule(ruleName, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event rule deleted", http.StatusOK)
}

// runOnDemandRule starts the on-demand event rule named by the "name" URL
// parameter and answers 202 once it has been started.
func runOnDemandRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	ruleName := getURLParam(r, "name")
	err := common.RunOnDemandRule(ruleName)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event rule started", http.StatusAccepted)
}

================================================
FILE: internal/httpd/api_events.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd import ( "encoding/csv" "encoding/json" "fmt" "net/http" "strconv" "strings" "time" "github.com/sftpgo/sdk/plugin/eventsearcher" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" ) func getCommonSearchParamsFromRequest(r *http.Request) (eventsearcher.CommonSearchParams, error) { c := eventsearcher.CommonSearchParams{} c.Limit = 100 if _, ok := r.URL.Query()["limit"]; ok { limit, err := strconv.Atoi(r.URL.Query().Get("limit")) if err != nil { return c, util.NewValidationError(fmt.Sprintf("invalid limit: %v", err)) } if limit < 1 || limit > 1000 { return c, util.NewValidationError(fmt.Sprintf("limit is out of the 1-1000 range: %v", limit)) } c.Limit = limit } if _, ok := r.URL.Query()["order"]; ok { order := r.URL.Query().Get("order") if order != dataprovider.OrderASC && order != dataprovider.OrderDESC { return c, util.NewValidationError(fmt.Sprintf("invalid order %q", order)) } if order == dataprovider.OrderASC { c.Order = 1 } } if _, ok := r.URL.Query()["start_timestamp"]; ok { ts, err := strconv.ParseInt(r.URL.Query().Get("start_timestamp"), 10, 64) if err != nil { return c, util.NewValidationError(fmt.Sprintf("invalid start_timestamp: %v", err)) } c.StartTimestamp = ts } if _, ok := r.URL.Query()["end_timestamp"]; ok { ts, err := strconv.ParseInt(r.URL.Query().Get("end_timestamp"), 10, 64) if err != nil { return c, util.NewValidationError(fmt.Sprintf("invalid end_timestamp: %v", err)) } c.EndTimestamp = ts } c.Username = strings.TrimSpace(r.URL.Query().Get("username")) c.IP = strings.TrimSpace(r.URL.Query().Get("ip")) c.InstanceIDs = getCommaSeparatedQueryParam(r, "instance_ids") c.FromID = r.URL.Query().Get("from_id") return c, nil } func getFsSearchParamsFromRequest(r *http.Request) (eventsearcher.FsEventSearch, error) { var err error s := eventsearcher.FsEventSearch{} 
s.CommonSearchParams, err = getCommonSearchParamsFromRequest(r) if err != nil { return s, err } s.FsProvider = -1 if _, ok := r.URL.Query()["fs_provider"]; ok { provider := r.URL.Query().Get("fs_provider") val, err := strconv.Atoi(provider) if err != nil { return s, util.NewValidationError(fmt.Sprintf("invalid fs_provider: %v", provider)) } s.FsProvider = val } s.Actions = getCommaSeparatedQueryParam(r, "actions") s.SSHCmd = strings.TrimSpace(r.URL.Query().Get("ssh_cmd")) s.Bucket = strings.TrimSpace(r.URL.Query().Get("bucket")) s.Endpoint = strings.TrimSpace(r.URL.Query().Get("endpoint")) s.Protocols = getCommaSeparatedQueryParam(r, "protocols") statuses := getCommaSeparatedQueryParam(r, "statuses") for _, status := range statuses { val, err := strconv.ParseInt(status, 10, 32) if err != nil { return s, util.NewValidationError(fmt.Sprintf("invalid status: %v", status)) } s.Statuses = append(s.Statuses, int32(val)) } return s, nil } func getProviderSearchParamsFromRequest(r *http.Request) (eventsearcher.ProviderEventSearch, error) { var err error s := eventsearcher.ProviderEventSearch{} s.CommonSearchParams, err = getCommonSearchParamsFromRequest(r) if err != nil { return s, err } s.Actions = getCommaSeparatedQueryParam(r, "actions") s.ObjectName = strings.TrimSpace(r.URL.Query().Get("object_name")) s.ObjectTypes = getCommaSeparatedQueryParam(r, "object_types") return s, nil } func getLogSearchParamsFromRequest(r *http.Request) (eventsearcher.LogEventSearch, error) { var err error s := eventsearcher.LogEventSearch{} s.CommonSearchParams, err = getCommonSearchParamsFromRequest(r) if err != nil { return s, err } s.Protocols = getCommaSeparatedQueryParam(r, "protocols") events := getCommaSeparatedQueryParam(r, "events") for _, ev := range events { evType, err := strconv.ParseInt(ev, 10, 32) if err == nil { s.Events = append(s.Events, int32(evType)) } } return s, nil } func searchFsEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, 
maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } filters, err := getFsSearchParamsFromRequest(r) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } filters.Role = getRoleFilterForEventSearch(r, claims.Role) if getBoolQueryParam(r, "csv_export") { filters.Limit = 100 if err := exportFsEvents(w, &filters); err != nil { panic(http.ErrAbortHandler) } return } data, err := plugin.Handler.SearchFsEvents(&filters) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } w.Header().Set("Content-Type", "application/json") w.Write(data) //nolint:errcheck } func searchProviderEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var filters eventsearcher.ProviderEventSearch if filters, err = getProviderSearchParamsFromRequest(r); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } filters.Role = getRoleFilterForEventSearch(r, claims.Role) filters.OmitObjectData = getBoolQueryParam(r, "omit_object_data") if getBoolQueryParam(r, "csv_export") { filters.Limit = 100 filters.OmitObjectData = true if err := exportProviderEvents(w, &filters); err != nil { panic(http.ErrAbortHandler) } return } data, err := plugin.Handler.SearchProviderEvents(&filters) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } w.Header().Set("Content-Type", "application/json") w.Write(data) //nolint:errcheck } func searchLogEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", 
http.StatusBadRequest) return } var filters eventsearcher.LogEventSearch if filters, err = getLogSearchParamsFromRequest(r); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } filters.Role = getRoleFilterForEventSearch(r, claims.Role) if getBoolQueryParam(r, "csv_export") { filters.Limit = 100 if err := exportLogEvents(w, &filters); err != nil { panic(http.ErrAbortHandler) } return } data, err := plugin.Handler.SearchLogEvents(&filters) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } w.Header().Set("Content-Type", "application/json") w.Write(data) //nolint:errcheck } func exportFsEvents(w http.ResponseWriter, filters *eventsearcher.FsEventSearch) error { w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=fslogs-%s.csv", time.Now().Format("2006-01-02T15-04-05"))) w.Header().Set("Content-Type", "text/csv") w.Header().Set("Accept-Ranges", "none") w.WriteHeader(http.StatusOK) csvWriter := csv.NewWriter(w) ev := fsEvent{} err := csvWriter.Write(ev.getCSVHeader()) if err != nil { return err } results := make([]fsEvent, 0, filters.Limit) for { data, err := plugin.Handler.SearchFsEvents(filters) if err != nil { return err } if err := json.Unmarshal(data, &results); err != nil { return err } for _, event := range results { if err := csvWriter.Write(event.getCSVData()); err != nil { return err } } if len(results) == 0 || len(results) < filters.Limit { break } filters.StartTimestamp = results[len(results)-1].Timestamp filters.FromID = results[len(results)-1].ID results = nil } csvWriter.Flush() return csvWriter.Error() } func exportProviderEvents(w http.ResponseWriter, filters *eventsearcher.ProviderEventSearch) error { w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=providerlogs-%s.csv", time.Now().Format("2006-01-02T15-04-05"))) w.Header().Set("Content-Type", "text/csv") w.Header().Set("Accept-Ranges", "none") w.WriteHeader(http.StatusOK) ev := providerEvent{} csvWriter := 
csv.NewWriter(w) err := csvWriter.Write(ev.getCSVHeader()) if err != nil { return err } results := make([]providerEvent, 0, filters.Limit) for { data, err := plugin.Handler.SearchProviderEvents(filters) if err != nil { return err } if err := json.Unmarshal(data, &results); err != nil { return err } for _, event := range results { if err := csvWriter.Write(event.getCSVData()); err != nil { return err } } if len(results) < filters.Limit || len(results) == 0 { break } filters.FromID = results[len(results)-1].ID filters.StartTimestamp = results[len(results)-1].Timestamp results = nil } csvWriter.Flush() return csvWriter.Error() } func exportLogEvents(w http.ResponseWriter, filters *eventsearcher.LogEventSearch) error { w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", time.Now().Format("2006-01-02T15-04-05"))) w.Header().Set("Content-Type", "text/csv") w.Header().Set("Accept-Ranges", "none") w.WriteHeader(http.StatusOK) ev := logEvent{} csvWriter := csv.NewWriter(w) err := csvWriter.Write(ev.getCSVHeader()) if err != nil { return err } results := make([]logEvent, 0, filters.Limit) for { data, err := plugin.Handler.SearchLogEvents(filters) if err != nil { return err } if err := json.Unmarshal(data, &results); err != nil { return err } for _, event := range results { if err := csvWriter.Write(event.getCSVData()); err != nil { return err } } if len(results) == 0 || len(results) < filters.Limit { break } filters.StartTimestamp = results[len(results)-1].Timestamp filters.FromID = results[len(results)-1].ID results = nil } csvWriter.Flush() return csvWriter.Error() } func getRoleFilterForEventSearch(r *http.Request, defaultValue string) string { if defaultValue != "" { return defaultValue } return r.URL.Query().Get("role") } type fsEvent struct { ID string `json:"id"` Timestamp int64 `json:"timestamp"` Action string `json:"action"` Username string `json:"username"` FsPath string `json:"fs_path"` FsTargetPath string 
`json:"fs_target_path,omitempty"` VirtualPath string `json:"virtual_path"` VirtualTargetPath string `json:"virtual_target_path,omitempty"` SSHCmd string `json:"ssh_cmd,omitempty"` FileSize int64 `json:"file_size,omitempty"` Elapsed int64 `json:"elapsed,omitempty"` Status int `json:"status"` Protocol string `json:"protocol"` IP string `json:"ip,omitempty"` SessionID string `json:"session_id"` FsProvider int `json:"fs_provider"` Bucket string `json:"bucket,omitempty"` Endpoint string `json:"endpoint,omitempty"` OpenFlags int `json:"open_flags,omitempty"` Role string `json:"role,omitempty"` InstanceID string `json:"instance_id,omitempty"` } func (e *fsEvent) getCSVHeader() []string { return []string{"Time", "Action", "Path", "Size", "Elapsed", "Status", "User", "Protocol", "IP", "SSH command"} } func (e *fsEvent) getCSVData() []string { timestamp := time.Unix(0, e.Timestamp).UTC() var pathInfo strings.Builder pathInfo.Write([]byte(e.VirtualPath)) if e.VirtualTargetPath != "" { pathInfo.WriteString(" => ") pathInfo.WriteString(e.VirtualTargetPath) } var status string switch e.Status { case 1: status = "OK" case 2: status = "KO" case 3: status = "Quota exceeded" } var fileSize string if e.FileSize > 0 { fileSize = util.ByteCountIEC(e.FileSize) } var elapsed string if e.Elapsed > 0 { elapsed = (time.Duration(e.Elapsed) * time.Millisecond).String() } return []string{timestamp.Format(time.RFC3339Nano), e.Action, pathInfo.String(), fileSize, elapsed, status, e.Username, e.Protocol, e.IP, e.SSHCmd} } type providerEvent struct { ID string `json:"id"` Timestamp int64 `json:"timestamp"` Action string `json:"action"` Username string `json:"username"` IP string `json:"ip,omitempty"` ObjectType string `json:"object_type"` ObjectName string `json:"object_name"` ObjectData []byte `json:"object_data"` Role string `json:"role,omitempty"` InstanceID string `json:"instance_id,omitempty"` } func (e *providerEvent) getCSVHeader() []string { return []string{"Time", "Action", "Object Type", 
"Object Name", "User", "IP"} } func (e *providerEvent) getCSVData() []string { timestamp := time.Unix(0, e.Timestamp).UTC() return []string{timestamp.Format(time.RFC3339Nano), e.Action, e.ObjectType, e.ObjectName, e.Username, e.IP} } type logEvent struct { ID string `json:"id"` Timestamp int64 `json:"timestamp"` Event int `json:"event"` Protocol string `json:"protocol"` Username string `json:"username,omitempty"` IP string `json:"ip,omitempty"` Message string `json:"message,omitempty"` Role string `json:"role,omitempty"` } func (e *logEvent) getCSVHeader() []string { return []string{"Time", "Event", "Protocol", "User", "IP", "Message"} } func (e *logEvent) getCSVData() []string { timestamp := time.Unix(0, e.Timestamp).UTC() return []string{timestamp.Format(time.RFC3339Nano), getLogEventString(notifier.LogEventType(e.Event)), e.Protocol, e.Username, e.IP, e.Message} } func getLogEventString(event notifier.LogEventType) string { switch event { case notifier.LogEventTypeLoginFailed: return "Login failed" case notifier.LogEventTypeLoginNoUser: return "Login with non-existent user" case notifier.LogEventTypeNoLoginTried: return "No login tried" case notifier.LogEventTypeNotNegotiated: return "Algorithm negotiation failed" case notifier.LogEventTypeLoginOK: return "Login succeeded" default: return "" } } ================================================ FILE: internal/httpd/api_folder.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "context" "fmt" "net/http" "net/url" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) func getFolders(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) limit, offset, order, err := getSearchFilters(w, r) if err != nil { return } folders, err := dataprovider.GetFolders(limit, offset, order, false) if err != nil { sendAPIResponse(w, r, err, "", http.StatusInternalServerError) return } render.JSON(w, r, folders) } func addFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var folder vfs.BaseVirtualFolder err = render.DecodeJSON(r.Body, &folder) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } if err := dataprovider.AddFolder(&folder, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } w.Header().Add("Location", fmt.Sprintf("%s/%s", folderPath, url.PathEscape(folder.Name))) renderFolder(w, r, folder.Name, claims, http.StatusCreated) } func updateFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") folder, err := dataprovider.GetFolderByName(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } var 
updatedFolder vfs.BaseVirtualFolder err = render.DecodeJSON(r.Body, &updatedFolder) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } updatedFolder.ID = folder.ID updatedFolder.Name = folder.Name updatedFolder.FsConfig.SetEmptySecretsIfNil() updateEncryptedSecrets(&updatedFolder.FsConfig, &folder.FsConfig) err = dataprovider.UpdateFolder(&updatedFolder, folder.Users, folder.Groups, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, nil, "Folder updated", http.StatusOK) } func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { folder, err := dataprovider.GetFolderByName(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if hideConfidentialData(claims, r) { folder.PrepareForRendering() } if status != http.StatusOK { ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) render.JSON(w, r.WithContext(ctx), folder) } else { render.JSON(w, r, folder) } } func getFolderByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") renderFolder(w, r, name, claims, http.StatusOK) } func deleteFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") err = dataprovider.DeleteFolder(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, err, 
"Folder deleted", http.StatusOK) } ================================================ FILE: internal/httpd/api_group.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "context" "fmt" "net/http" "net/url" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) func getGroups(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) limit, offset, order, err := getSearchFilters(w, r) if err != nil { return } groups, err := dataprovider.GetGroups(limit, offset, order, false) if err != nil { sendAPIResponse(w, r, err, "", http.StatusInternalServerError) return } render.JSON(w, r, groups) } func addGroup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var group dataprovider.Group err = render.DecodeJSON(r.Body, &group) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } err = dataprovider.AddGroup(&group, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } 
w.Header().Add("Location", fmt.Sprintf("%s/%s", groupPath, url.PathEscape(group.Name))) renderGroup(w, r, group.Name, claims, http.StatusCreated) } func updateGroup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") group, err := dataprovider.GroupExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } var updatedGroup dataprovider.Group err = render.DecodeJSON(r.Body, &updatedGroup) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } updatedGroup.ID = group.ID updatedGroup.Name = group.Name updatedGroup.UserSettings.FsConfig.SetEmptySecretsIfNil() updateEncryptedSecrets(&updatedGroup.UserSettings.FsConfig, &group.UserSettings.FsConfig) err = dataprovider.UpdateGroup(&updatedGroup, group.Users, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, nil, "Group updated", http.StatusOK) } func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { group, err := dataprovider.GroupExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if hideConfidentialData(claims, r) { group.PrepareForRendering() } if status != http.StatusOK { ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) render.JSON(w, r.WithContext(ctx), group) } else { render.JSON(w, r, group) } } func getGroupByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, 
"name") renderGroup(w, r, name, claims, http.StatusOK) } func deleteGroup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") err = dataprovider.DeleteGroup(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, err, "Group deleted", http.StatusOK) } ================================================ FILE: internal/httpd/api_http_user.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "context" "errors" "fmt" "io" "mime/multipart" "net/http" "os" "path" "strconv" "strings" "github.com/go-chi/render" "github.com/rs/xid" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, error) { claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return nil, fmt.Errorf("invalid token claims %w", err) } user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") if err != nil { sendAPIResponse(w, r, nil, "Unable to retrieve your user", getRespStatus(err)) return nil, err } connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil { sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) return nil, err } baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return connection, err } return connection, nil } func readUserFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) lister, err := connection.ReadDir(name) if err != nil { sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err)) return } 
renderAPIDirContents(w, lister, false) } func createUserDir(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) connection.User.CheckFsRoot(connection.ID) //nolint:errcheck name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) if getBoolQueryParam(r, "mkdir_parents") { if err = connection.CheckParentDirs(path.Dir(name)); err != nil { sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err)) return } } err = connection.CreateDir(name, true) if err != nil { sendAPIResponse(w, r, err, fmt.Sprintf("Unable to create directory %q", name), getMappedStatusCode(err)) return } sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %q created", name), http.StatusCreated) } func deleteUserDir(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) err = connection.RemoveAll(name) if err != nil { sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete directory %q", name), getMappedStatusCode(err)) return } sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %q deleted", name), http.StatusOK) } func renameUserFsEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) newName := connection.User.GetCleanedPath(r.URL.Query().Get("target")) if !connection.IsSameResource(oldName, newName) { if err := connection.Copy(oldName, newName); err != nil { sendAPIResponse(w, r, err, fmt.Sprintf("Cannot perform copy step to rename %q -> 
%q", oldName, newName), getMappedStatusCode(err)) return } if err := connection.RemoveAll(oldName); err != nil { sendAPIResponse(w, r, err, fmt.Sprintf("Cannot perform remove step to rename %q -> %q", oldName, newName), getMappedStatusCode(err)) return } } else { if err := connection.Rename(oldName, newName); err != nil { sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename %q => %q", oldName, newName), getMappedStatusCode(err)) return } } sendAPIResponse(w, r, nil, fmt.Sprintf("%q renamed to %q", oldName, newName), http.StatusOK) } func copyUserFsEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) source := r.URL.Query().Get("path") target := r.URL.Query().Get("target") copyFromSource := strings.HasSuffix(source, "/") copyInTarget := strings.HasSuffix(target, "/") source = connection.User.GetCleanedPath(source) target = connection.User.GetCleanedPath(target) if copyFromSource { source += "/" } if copyInTarget { target += "/" } err = connection.Copy(source, target) if err != nil { sendAPIResponse(w, r, err, fmt.Sprintf("Unable to copy %q => %q", source, target), getMappedStatusCode(err)) return } sendAPIResponse(w, r, nil, fmt.Sprintf("%q copied to %q", source, target), http.StatusOK) } func getUserFile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) if name == "/" { sendAPIResponse(w, r, nil, "Please set the path to a valid file", http.StatusBadRequest) return } info, err := connection.Stat(name, 0) if err != nil { sendAPIResponse(w, r, err, "Unable to stat the requested file", getMappedStatusCode(err)) return } if info.IsDir() { sendAPIResponse(w, r, nil, 
fmt.Sprintf("Please set the path to a valid file, %q is a directory", name), http.StatusBadRequest)
		return
	}

	inline := r.URL.Query().Get("inline") != ""
	if status, err := downloadFile(w, r, connection, name, info, inline, nil); err != nil {
		resp := apiResponse{
			Error:   err.Error(),
			Message: http.StatusText(status),
		}
		ctx := r.Context()
		if status != 0 {
			ctx = context.WithValue(ctx, render.StatusCtxKey, status)
		}
		render.JSON(w, r.WithContext(ctx), resp)
	}
}

// setFileDirMetadata sets the access and modification times of the file or
// directory named by the "path" query parameter. The JSON request body must
// contain a "modification_time" value, converted with
// util.GetTimeFromMsecSinceEpoch (presumably milliseconds since the Unix epoch).
func setFileDirMetadata(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	metadata := make(map[string]int64)
	err := render.DecodeJSON(r.Body, &metadata)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	mTime, ok := metadata["modification_time"]
	if !ok || !r.URL.Query().Has("path") {
		sendAPIResponse(w, r, errors.New("please set a modification_time and a path"), "", http.StatusBadRequest)
		return
	}

	connection, err := getUserConnection(w, r)
	if err != nil {
		// NOTE(review): getUserConnection presumably writes the error response itself
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	attrs := common.StatAttributes{
		Flags: common.StatAttrTimes,
		Atime: util.GetTimeFromMsecSinceEpoch(mTime),
		Mtime: util.GetTimeFromMsecSinceEpoch(mTime),
	}
	err = connection.SetStat(name, &attrs)
	if err != nil {
		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to set metadata for path %q", name), getMappedStatusCode(err))
		return
	}
	sendAPIResponse(w, r, nil, "OK", http.StatusOK)
}

// uploadUserFile stores the raw request body as the file named by the "path"
// query parameter. When the "mkdir_parents" query parameter is set, the parent
// directories are handled by connection.CheckParentDirs before the upload.
// The request body is limited to maxUploadFileSize when that limit is positive.
func uploadUserFile(w http.ResponseWriter, r *http.Request) {
	if maxUploadFileSize > 0 {
		r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize)
	}

	if !r.URL.Query().Has("path") {
		sendAPIResponse(w, r, errors.New("please set a file path"), "", http.StatusBadRequest)
		return
	}

	connection, err := getUserConnection(w, r)
	if err != nil {
		return
	}
	defer common.Connections.Remove(connection.GetID())

	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	filePath := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	if getBoolQueryParam(r, "mkdir_parents") {
		if err = connection.CheckParentDirs(path.Dir(filePath)); err != nil {
			sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err))
			return
		}
	}
	doUploadFile(w, r, connection, filePath) //nolint:errcheck
}

// doUploadFile copies the request body into filePath and sends the API response.
// On success, the modification time requested through the mTimeHeader header, if
// any, is applied before replying with 201 Created. The returned error is the
// one already reported to the client.
func doUploadFile(w http.ResponseWriter, r *http.Request, connection *Connection, filePath string) error {
	writer, err := connection.getFileWriter(filePath)
	if err != nil {
		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to write file %q", filePath), getMappedStatusCode(err))
		return err
	}
	_, err = io.Copy(writer, r.Body)
	if err != nil {
		writer.Close() //nolint:errcheck
		sendAPIResponse(w, r, err, fmt.Sprintf("Error saving file %q", filePath), getMappedStatusCode(err))
		return err
	}
	// the Close error matters: it may be the only signal that the write failed
	err = writer.Close()
	if err != nil {
		sendAPIResponse(w, r, err, fmt.Sprintf("Error closing file %q", filePath), getMappedStatusCode(err))
		return err
	}
	setModificationTimeFromHeader(r, connection, filePath)
	sendAPIResponse(w, r, nil, "Upload completed", http.StatusCreated)
	return nil
}

// uploadUserFiles stores the files sent in the "filenames" multipart field into
// the directory named by the "path" query parameter. Transfer count and
// transfer quota limits are checked before the body is read, and the body is
// read through a throttled reader while the multipart form is parsed.
func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
	if maxUploadFileSize > 0 {
		r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize)
	}

	connection, err := getUserConnection(w, r)
	if err != nil {
		return
	}
	defer common.Connections.Remove(connection.GetID())

	if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil {
		connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits")
		sendAPIResponse(w, r, err, "Denying file write due to transfer count limits", http.StatusConflict)
		return
	}

	transferQuota := connection.GetTransferQuota()
	if !transferQuota.HasUploadSpace() {
		connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits")
		sendAPIResponse(w, r, common.ErrQuotaExceeded, "Denying file write due to transfer quota limits",
			http.StatusRequestEntityTooLarge)
		return
	}

	t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)
	r.Body = t
	err = r.ParseMultipartForm(maxMultipartMem)
	// t is only needed while the form is parsed; drop it from the connection's
	// transfers whatever the outcome
	connection.RemoveTransfer(t)
	if err != nil {
		sendAPIResponse(w, r, err, "Unable to parse multipart form", http.StatusBadRequest)
		return
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	parentDir := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	files := r.MultipartForm.File["filenames"]
	if len(files) == 0 {
		sendAPIResponse(w, r, nil, "No files uploaded!", http.StatusBadRequest)
		return
	}
	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	if getBoolQueryParam(r, "mkdir_parents") {
		if err = connection.CheckParentDirs(parentDir); err != nil {
			sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err))
			return
		}
	}
	doUploadFiles(w, r, connection, parentDir, files)
}

// doUploadFiles saves every multipart file into parentDir, using only the base
// name of each uploaded file name. It always sends exactly one API response.
// That response is the error for the first file that fails, in which case the
// remaining files are skipped, or "Upload completed" when all of them are
// stored. It returns the number of files successfully stored.
func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connection, parentDir string,
	files []*multipart.FileHeader,
) int {
	// NOTE(review): bandwidth limiting is disabled for the writes, presumably
	// because the request body was already throttled while the multipart form
	// was parsed (see uploadUserFiles) — confirm for every caller.
	connection.User.UploadBandwidth = 0

	// saveFile stores a single part and reports whether it succeeded. On failure
	// it has already sent the error response. Each part gets its own function
	// scope, so the deferred Close releases it as soon as it has been copied.
	// Otherwise every part, including disk-backed ones, would stay open until
	// doUploadFiles returns.
	saveFile := func(f *multipart.FileHeader) bool {
		file, err := f.Open()
		if err != nil {
			sendAPIResponse(w, r, err, fmt.Sprintf("Unable to read uploaded file %q", f.Filename), getMappedStatusCode(err))
			return false
		}
		defer file.Close()

		// only the base name of the client-supplied name is used, so the file
		// always lands directly inside parentDir
		filePath := path.Join(parentDir, path.Base(util.CleanPath(f.Filename)))
		writer, err := connection.getFileWriter(filePath)
		if err != nil {
			sendAPIResponse(w, r, err, fmt.Sprintf("Unable to write file %q", f.Filename), getMappedStatusCode(err))
			return false
		}
		if _, err = io.Copy(writer, file); err != nil {
			writer.Close() //nolint:errcheck
			sendAPIResponse(w, r, err, fmt.Sprintf("Error saving file %q", f.Filename), getMappedStatusCode(err))
			return false
		}
		if err = writer.Close(); err != nil {
			sendAPIResponse(w, r, err, fmt.Sprintf("Error closing file %q", f.Filename), getMappedStatusCode(err))
			return false
		}
		return true
	}

	uploaded := 0
	for _, f := range files {
		if !saveFile(f) {
			return uploaded
		}
		uploaded++
	}
	sendAPIResponse(w, r, nil, "Upload completed", http.StatusCreated)
	return uploaded
}

// deleteUserFile removes the file or symlink named by the "path" query
// parameter. Directories are rejected with 400 Bad Request.
func deleteUserFile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	connection, err := getUserConnection(w, r)
	if err != nil {
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	fs, p, err := connection.GetFsAndResolvedPath(name)
	if err != nil {
		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete file %q", name), getMappedStatusCode(err))
		return
	}

	// Lstat (assuming it mirrors os.Lstat) does not follow symlinks, so a
	// symlink to a directory is not rejected by the directory check below
	var fi os.FileInfo
	if fi, err = fs.Lstat(p); err != nil {
		connection.Log(logger.LevelError, "failed to remove file %q: stat error: %+v", p, err)
		err = connection.GetFsError(fs, err)
		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete file %q", name), getMappedStatusCode(err))
		return
	}

	if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 {
		connection.Log(logger.LevelDebug, "cannot remove %q is not a file/symlink", p)
		sendAPIResponse(w, r, nil, fmt.Sprintf("Unable to delete %q, it is not a file/symlink", name), http.StatusBadRequest)
		return
	}
	err = connection.RemoveFile(fs, p, name, fi)
	if err != nil {
		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete file %q", name), getMappedStatusCode(err))
		return
	}
	sendAPIResponse(w, r, nil, fmt.Sprintf("File %q deleted", name), http.StatusOK)
}

// getUserFilesAsZipStream streams the paths listed in the JSON array request
// body as a compressed attachment produced by renderCompressedFiles. The paths
// are cleaned and de-duplicated first.
func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	connection, err := getUserConnection(w, r)
	if err != nil {
		return
	}
	defer common.Connections.Remove(connection.GetID())

	var filesList []string
	err = render.DecodeJSON(r.Body, &filesList)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	baseDir := "/"
	for idx := range filesList {
		filesList[idx] = util.CleanPath(filesList[idx])
	}

	filesList = util.RemoveDuplicates(filesList, false)

	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"",
		getCompressedFileName(connection.GetUsername(), filesList)))
	renderCompressedFiles(w, connection, baseDir, filesList, nil)
}

// getUserProfile returns the editable profile fields of the user identified
// by the JWT claims in the request context.
func getUserProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	user, err := dataprovider.UserExists(claims.Username, "")
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	resp := userProfile{
		baseProfile: baseProfile{
			Email:           user.Email,
			Description:     user.Description,
			AllowAPIKeyAuth: user.Filters.AllowAPIKeyAuth,
		},
		AdditionalEmails: user.Filters.AdditionalEmails,
		PublicKeys:       user.PublicKeys,
		TLSCerts:         user.Filters.TLSCerts,
	}
	render.JSON(w, r, resp)
}

// updateUserProfile applies the submitted profile, but only the field groups
// the user is allowed to change. Permissions are evaluated on the merged
// (group-aware) user variant, while the stored, non-merged user is the one
// that gets updated.
func updateUserProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	var req userProfile
	err = render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	user, userMerged, err := dataprovider.GetUserVariants(claims.Username, "")
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if !userMerged.CanUpdateProfile() {
		sendAPIResponse(w, r, nil, "You are not allowed to change anything", http.StatusForbidden)
		return
	}
	if userMerged.CanManagePublicKeys() {
		user.PublicKeys = req.PublicKeys
	}
	if userMerged.CanManageTLSCerts() {
		user.Filters.TLSCerts = req.TLSCerts
	}
	if userMerged.CanChangeAPIKeyAuth() {
		user.Filters.AllowAPIKeyAuth = req.AllowAPIKeyAuth
	}
	if userMerged.CanChangeInfo() {
		user.Email = req.Email
		user.Filters.AdditionalEmails = req.AdditionalEmails
		user.Description = req.Description
	}
	if err := dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr),
		user.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Profile updated", http.StatusOK)
}

// changeUserPassword changes the password of the authenticated user and then
// calls invalidateToken for the current request. The REST payload has no
// confirmation field, so the new password is passed twice to
// doChangeUserPassword.
func changeUserPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var pwd pwdChange
	err := render.DecodeJSON(r.Body, &pwd)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	err = doChangeUserPassword(r, pwd.CurrentPassword, pwd.NewPassword, pwd.NewPassword)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	invalidateToken(r)
	sendAPIResponse(w, r, nil, "Password updated", http.StatusOK)
}

// doChangeUserPassword validates the password change request, verifies the
// current password with dataprovider.CheckUserAndPass and stores the new one.
// All validation failures are returned as i18n-wrapped validation errors.
func doChangeUserPassword(r *http.Request, currentPassword, newPassword, confirmNewPassword string) error {
	if currentPassword == "" || newPassword == "" || confirmNewPassword == "" {
		return util.NewI18nError(
			util.NewValidationError("please provide the current password and the new one two times"),
			util.I18nErrorChangePwdRequiredFields,
		)
	}
	if newPassword != confirmNewPassword {
		return util.NewI18nError(util.NewValidationError("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)
	}
	if currentPassword == newPassword {
		return util.NewI18nError(
			util.NewValidationError("the new password must be different from the current one"),
			util.I18nErrorChangePwdNoDifferent,
		)
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)
	}
	_, err = dataprovider.CheckUserAndPass(claims.Username, currentPassword, util.GetIPFromRemoteAddress(r.RemoteAddr),
		getProtocolFromRequest(r))
	if err != nil {
		return util.NewI18nError(util.NewValidationError("current password does not match"), util.I18nErrorChangePwdCurrentNoMatch)
	}

	return dataprovider.UpdateUserPassword(claims.Username, newPassword, dataprovider.ActionExecutorSelf,
		util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
}

// setModificationTimeFromHeader applies, on a best-effort basis, the
// modification time sent in the mTimeHeader request header to filePath. The
// header value is parsed as a base-10 integer and converted with
// util.GetTimeFromMsecSinceEpoch. Failures are only logged.
func setModificationTimeFromHeader(r *http.Request, c *Connection, filePath string) {
	mTimeString := r.Header.Get(mTimeHeader)
	if mTimeString != "" {
		// we don't return an error here if we fail to set the modification time
		mTime, err := strconv.ParseInt(mTimeString, 10, 64)
		if err == nil {
			attrs := common.StatAttributes{
				Flags: common.StatAttrTimes,
				Atime: util.GetTimeFromMsecSinceEpoch(mTime),
				Mtime: util.GetTimeFromMsecSinceEpoch(mTime),
			}
			err = c.SetStat(filePath, &attrs)
			c.Log(logger.LevelDebug, "requested modification time %v for file %q, error: %v",
				attrs.Mtime, filePath, err)
		} else {
			c.Log(logger.LevelInfo, "invalid modification time header was ignored: %v", mTimeString)
		}
	}
}

================================================
FILE: internal/httpd/api_iplist.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd

import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strconv"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getIPListEntries renders, as JSON, one page of the IP list selected by the
// "type" URL parameter. The "filter" and "from" query parameters and the
// standard search filters (limit/order) are forwarded to the data provider.
func getIPListEntries(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	limit, _, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}
	listType, _, err := getIPListPathParams(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()
	entries, err := dataprovider.GetIPListEntries(listType, query.Get("filter"), query.Get("from"), order, limit)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, entries)
}

// getIPListEntry renders the single entry addressed by the "type" and
// "ipornet" URL parameters.
func getIPListEntry(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	listType, ipOrNet, err := getIPListPathParams(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	found, err := dataprovider.IPListEntryExists(ipOrNet, listType)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, found)
}

// addIPListEntry creates the IP list entry described by the JSON body on
// behalf of the admin in the token claims. On success it replies 201 Created
// with a Location header pointing at the new entry.
func addIPListEntry(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	var newEntry dataprovider.IPListEntry
	if err = render.DecodeJSON(r.Body, &newEntry); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.AddIPListEntry(&newEntry, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	location := fmt.Sprintf("%s/%d/%s", ipListsPath, newEntry.Type, url.PathEscape(newEntry.IPOrNet))
	w.Header().Add("Location", location)
	sendAPIResponse(w, r, nil, "Entry added", http.StatusCreated)
}

// updateIPListEntry replaces the entry addressed by the URL parameters with
// the JSON body. The list type and the IP/network are taken from the stored
// entry, so they cannot be changed through this endpoint.
func updateIPListEntry(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	listType, ipOrNet, err := getIPListPathParams(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	current, err := dataprovider.IPListEntryExists(ipOrNet, listType)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var replacement dataprovider.IPListEntry
	if err = render.DecodeJSON(r.Body, &replacement); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	// the identifying fields always come from the stored entry
	replacement.Type = current.Type
	replacement.IPOrNet = current.IPOrNet

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.UpdateIPListEntry(&replacement, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Entry updated", http.StatusOK)
}

// deleteIPListEntry removes the entry addressed by the "type" and "ipornet"
// URL parameters on behalf of the admin in the token claims.
func deleteIPListEntry(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	listType, ipOrNet, err := getIPListPathParams(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.DeleteIPListEntry(ipOrNet, listType, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Entry deleted", http.StatusOK)
}

// getIPListPathParams parses and validates the "type" URL parameter and
// returns it together with the raw "ipornet" URL parameter. The parsed type is
// returned even on error, mirroring whatever strconv.Atoi produced.
func getIPListPathParams(r *http.Request) (dataprovider.IPListType, string, error) {
	rawType, convErr := strconv.Atoi(chi.URLParam(r, "type"))
	listType := dataprovider.IPListType(rawType)
	if convErr != nil {
		return listType, "", errors.New("invalid list type")
	}
	if err := dataprovider.CheckIPListType(listType); err != nil {
		return listType, "", err
	}
	return listType, getURLParam(r, "ipornet"), nil
}

================================================
FILE: internal/httpd/api_keys.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package httpd

import (
	"context"
	"fmt"
	"net/http"
	"net/url"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getAPIKeys renders one page of API keys selected by the standard search
// filters (limit, offset, order).
func getAPIKeys(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}
	keys, err := dataprovider.GetAPIKeys(limit, offset, order)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, keys)
}

// getAPIKeyByID renders the API key identified by the "id" URL parameter,
// after HideConfidentialData has been applied to it.
func getAPIKeyByID(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	key, err := dataprovider.APIKeyExists(getURLParam(r, "id"))
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	key.HideConfidentialData()
	render.JSON(w, r, key)
}

// addAPIKey creates a new API key from the JSON body and replies 201 Created.
// The reply includes the displayable key, which the response message says is
// only shown at this point.
func addAPIKey(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	var apiKey dataprovider.APIKey
	if err = render.DecodeJSON(r.Body, &apiKey); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	// discard client-supplied identifiers, secret and usage data; they are
	// presumably populated by dataprovider.AddAPIKey
	apiKey.ID = 0
	apiKey.KeyID = ""
	apiKey.Key = ""
	apiKey.LastUseAt = 0

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.AddAPIKey(&apiKey, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	body := map[string]string{
		"message": "API key created. This is the only time the API key is visible, please save it.",
		"key":     apiKey.DisplayKey(),
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", apiKeysPath, url.PathEscape(apiKey.KeyID)))
	w.Header().Add("X-Object-ID", apiKey.KeyID)
	created := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusCreated)
	render.JSON(w, r.WithContext(created), body)
}

// updateAPIKey replaces the API key identified by the "id" URL parameter with
// the JSON body. The key ID and the stored secret are preserved.
func updateAPIKey(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	keyID := getURLParam(r, "id")
	stored, err := dataprovider.APIKeyExists(keyID)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var incoming dataprovider.APIKey
	if err = render.DecodeJSON(r.Body, &incoming); err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	incoming.KeyID = keyID
	incoming.Key = stored.Key

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.UpdateAPIKey(&incoming, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "API key updated", http.StatusOK)
}

// deleteAPIKey removes the API key identified by the "id" URL parameter on
// behalf of the admin in the token claims.
func deleteAPIKey(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	keyID := getURLParam(r, "id")

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	clientIP := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err = dataprovider.DeleteAPIKey(keyID, claims.Username, clientIP, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "API key deleted", http.StatusOK)
}

================================================
FILE: internal/httpd/api_maintenance.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// validateBackupFile checks that outputFile is a non-empty relative path that
// contains no ".." and returns it joined to dataprovider.GetBackupsPath().
func validateBackupFile(outputFile string) (string, error) {
	if outputFile == "" {
		return "", errors.New("invalid or missing output-file")
	}
	if filepath.IsAbs(outputFile) {
		return "", fmt.Errorf("invalid output-file %q: it must be a relative path", outputFile)
	}
	// a relative path without ".." cannot escape the backups directory once joined
	if strings.Contains(outputFile, "..") {
		return "", fmt.Errorf("invalid output-file %q", outputFile)
	}
	outputFile = filepath.Join(dataprovider.GetBackupsPath(), outputFile)
	return outputFile, nil
}

// dumpData exports the provider data, optionally restricted by the
// comma-separated "scopes" query parameter. With "output-data=1" the dump is
// sent in the response as a JSON attachment. Otherwise it is written to the
// "output-file" path inside the backups directory, indented when "indent=1".
func dumpData(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	// Values.Get returns "" for a missing parameter and TrimSpace keeps it "",
	// so the string options need no separate presence check
	query := r.URL.Query()
	outputFile := strings.TrimSpace(query.Get("output-file"))
	outputData := strings.TrimSpace(query.Get("output-data"))
	indent := strings.TrimSpace(query.Get("indent"))
	var scopes []string
	if query.Has("scopes") {
		scopes = getCommaSeparatedQueryParam(r, "scopes")
	}

	if outputData != "1" {
		var err error
		outputFile, err = validateBackupFile(outputFile)
		if err != nil {
			sendAPIResponse(w, r, err, "", http.StatusBadRequest)
			return
		}

		err = os.MkdirAll(filepath.Dir(outputFile), 0700)
		if err != nil {
			logger.Error(logSender, "", "dumping data error: %v, output file: %q", err, outputFile)
			sendAPIResponse(w, r, err, "", getRespStatus(err))
			return
		}
		logger.Debug(logSender, "", "dumping data to: %q", outputFile)
	}

	backup, err := dataprovider.DumpData(scopes)
	if err != nil {
		logger.Error(logSender, "", "dumping data error: %v, output file: %q", err, outputFile)
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	if outputData == "1" {
		w.Header().Set("Content-Disposition", "attachment; filename=\"sftpgo-backup.json\"")
		render.JSON(w, r, backup)
		return
	}

	var dump []byte
	if indent == "1" {
		dump, err = json.MarshalIndent(backup, "", " ")
	} else {
		dump, err = json.Marshal(backup)
	}
	if err == nil {
		err = os.WriteFile(outputFile, dump, 0600)
	}
	if err != nil {
		logger.Warn(logSender, "", "dumping data error: %v, output file: %q", err, outputFile)
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	logger.Debug(logSender, "", "dumping data completed, output file: %q, error: %v", outputFile, err)
	sendAPIResponse(w, r, nil, "Data saved", http.StatusOK)
}

// loadDataFromRequest restores a backup sent as the request body, which is
// limited to MaxRestoreSize bytes. The "scan-quota" and "mode" query options
// are honored, while "input-file" is ignored here.
func loadDataFromRequest(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	_, scanQuota, mode, err := getLoaddataOptions(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	content, err := io.ReadAll(r.Body)
	// report an empty body as a validation error only when the read itself
	// succeeded: a real read error must not be masked, even if no byte was read
	if err == nil && len(content) == 0 {
		err = util.NewValidationError("request body is required")
	}
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	if err := restoreBackup(content, "", scanQuota, mode, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr),
		claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Data restored", http.StatusOK)
}

// loadData restores a backup from the server-side "input-file", which must be
// an absolute path to a file no larger than MaxRestoreSize bytes.
func loadData(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	inputFile, scanQuota, mode, err := getLoaddataOptions(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if !filepath.IsAbs(inputFile) {
		sendAPIResponse(w, r, fmt.Errorf("invalid input_file %q: it must be an absolute path", inputFile), "",
			http.StatusBadRequest)
		return
	}
	fi, err := os.Stat(inputFile)
	if err != nil {
		sendAPIResponse(w, r, fmt.Errorf("invalid input_file %q", inputFile), "", http.StatusBadRequest)
		return
	}
	// checked before reading so an oversized file is never loaded in memory
	if fi.Size() > MaxRestoreSize {
		sendAPIResponse(w, r, nil, fmt.Sprintf("Unable to restore input file: %q size too big: %d/%d bytes",
			inputFile, fi.Size(), MaxRestoreSize), http.StatusBadRequest)
		return
	}

	content, err := os.ReadFile(inputFile)
	if err != nil {
		sendAPIResponse(w, r, fmt.Errorf("invalid input_file %q", inputFile), "", http.StatusBadRequest)
		return
	}
	if err := restoreBackup(content, inputFile, scanQuota, mode, claims.Username,
		util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Data restored", http.StatusOK)
}

// restoreBackup parses a dump and restores its sections in a fixed order:
// configs, IP lists, roles, folders, groups, users, admins, API keys, shares,
// event actions, event rules. The order presumably ensures referenced objects
// exist before the objects that use them. The first failure stops the restore
// and is returned.
func restoreBackup(content []byte, inputFile string, scanQuota, mode int, executor, ipAddress, role string) error {
	dump, err := dataprovider.ParseDumpData(content)
	if err != nil {
		return util.NewI18nError(
			util.NewValidationError(fmt.Sprintf("invalid input_file %q", inputFile)),
			util.I18nErrorBackupFile,
		)
	}

	if err = RestoreConfigs(dump.Configs, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreIPListEntries(dump.IPLists, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreRoles(dump.Roles, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreFolders(dump.Folders, inputFile, mode, scanQuota, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreGroups(dump.Groups, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreUsers(dump.Users, inputFile, mode, scanQuota, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreAdmins(dump.Admins, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreAPIKeys(dump.APIKeys, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreShares(dump.Shares, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreEventActions(dump.EventActions, inputFile, mode, executor, ipAddress, role); err != nil {
		return err
	}

	if err = RestoreEventRules(dump.EventRules, inputFile, mode, executor, ipAddress, role, dump.Version); err != nil {
		return err
	}

	logger.Debug(logSender, "", "backup restored")
	return nil
}

// getLoaddataOptions reads the "input-file", "scan-quota" and "mode" query
// parameters. The integer options default to 0 when absent, and a
// non-integer value is reported as an error.
func getLoaddataOptions(r *http.Request) (string, int, int, error) {
	var inputFile string
	var err error
	scanQuota := 0
	restoreMode := 0
	if _, ok := r.URL.Query()["input-file"]; ok {
		inputFile = strings.TrimSpace(r.URL.Query().Get("input-file"))
	}
	if _, ok := r.URL.Query()["scan-quota"]; ok {
		scanQuota, err = strconv.Atoi(r.URL.Query().Get("scan-quota"))
		if err != nil {
			err = fmt.Errorf("invalid scan_quota: %v", err)
			return inputFile, scanQuota, restoreMode, err
		}
	}
	if _, ok := r.URL.Query()["mode"]; ok {
		restoreMode, err = strconv.Atoi(r.URL.Query().Get("mode"))
		if err != nil {
			err = fmt.Errorf("invalid mode: %v", err)
			return inputFile, scanQuota, restoreMode, err
		}
	}
	return inputFile, scanQuota, restoreMode, err
}

// RestoreFolders restores the specified folders.
// Existing folders, matched by name, are skipped in mode 1 and updated
// otherwise; missing ones are added. A quota scan is started for each
// restored folder when scanQuota >= 1.
func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, scanQuota int, executor, ipAddress, role string) error {
	for idx := range folders {
		folder := folders[idx]
		f, err := dataprovider.GetFolderByName(folder.Name)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing folder %q not updated", folder.Name)
				continue
			}
			folder.ID = f.ID
			folder.Name = f.Name
			err = dataprovider.UpdateFolder(&folder, f.Users, f.Groups, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring existing folder %q, dump file: %q, error: %v", folder.Name, inputFile, err)
		} else {
			folder.Users = nil
			err = dataprovider.AddFolder(&folder, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new folder %q, dump file: %q, error: %v", folder.Name, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore folder %q: %w", folder.Name, err)
		}
		if scanQuota >= 1 {
			if common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
				logger.Debug(logSender, "", "starting quota scan for restored folder: %q", folder.Name)
				go doFolderQuotaScan(folder) //nolint:errcheck
			}
		}
	}
	return nil
}

// RestoreShares restores the specified shares.
// Shares are flagged with IsRestore; existing ones, matched by share ID, are
// skipped in mode 1 and updated otherwise, while missing ones are added.
func RestoreShares(shares []dataprovider.Share, inputFile string, mode int, executor,
	ipAddress, role string,
) error {
	for idx := range shares {
		share := shares[idx]
		share.IsRestore = true
		s, err := dataprovider.ShareExists(share.ShareID, "")
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing share %q not updated", share.ShareID)
				continue
			}
			share.ID = s.ID
			err = dataprovider.UpdateShare(&share, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring existing share %q, dump file: %q, error: %v", share.ShareID, inputFile, err)
		} else {
			err = dataprovider.AddShare(&share, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new share %q, dump file: %q, error: %v", share.ShareID, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore share %q: %w", share.ShareID, err)
		}
	}
	return nil
}

// RestoreEventActions restores the specified event actions.
// Existing actions, matched by name, are skipped in mode 1 and updated
// otherwise; missing ones are added.
func RestoreEventActions(actions []dataprovider.BaseEventAction, inputFile string, mode int, executor, ipAddress, role string) error {
	for idx := range actions {
		action := actions[idx]
		a, err := dataprovider.EventActionExists(action.Name)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing event action %q not updated", a.Name)
				continue
			}
			action.ID = a.ID
			err = dataprovider.UpdateEventAction(&action, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring event action %q, dump file: %q, error: %v", action.Name, inputFile, err)
		} else {
			err = dataprovider.AddEventAction(&action, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new event action %q, dump file: %q, error: %v", action.Name, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore event action %q: %w", action.Name, err)
		}
	}
	return nil
}

// RestoreEventRules restores the specified event rules.
// Rules coming from dumps older than version 15 get Status forced to 1.
// Existing rules, matched by name, are skipped in mode 1 and updated
// otherwise; missing ones are added.
func RestoreEventRules(rules []dataprovider.EventRule, inputFile string, mode int, executor, ipAddress,
	role string, dumpVersion int,
) error {
	for idx := range rules {
		rule := rules[idx]
		if dumpVersion < 15 {
			rule.Status = 1
		}
		r, err := dataprovider.EventRuleExists(rule.Name)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing event rule %q not updated", r.Name)
				continue
			}
			rule.ID = r.ID
			err = dataprovider.UpdateEventRule(&rule, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring event rule %q, dump file: %q, error: %v", rule.Name, inputFile, err)
		} else {
			err = dataprovider.AddEventRule(&rule, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new event rule %q, dump file: %q, error: %v", rule.Name, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore event rule %q: %w", rule.Name, err)
		}
	}
	return nil
}

// RestoreAPIKeys restores the specified API keys.
// A key without its secret cannot be restored and aborts the restore.
// Existing keys, matched by key ID, are skipped in mode 1 and updated
// otherwise; missing ones are added.
func RestoreAPIKeys(apiKeys []dataprovider.APIKey, inputFile string, mode int, executor, ipAddress, role string) error {
	for idx := range apiKeys {
		apiKey := apiKeys[idx]
		if apiKey.Key == "" {
			logger.Warn(logSender, "", "cannot restore empty API key")
			return fmt.Errorf("cannot restore an empty API key: %+v", apiKey)
		}
		k, err := dataprovider.APIKeyExists(apiKey.KeyID)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing API key %q not updated", apiKey.KeyID)
				continue
			}
			apiKey.ID = k.ID
			err = dataprovider.UpdateAPIKey(&apiKey, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring existing API key %q, dump file: %q, error: %v", apiKey.KeyID, inputFile, err)
		} else {
			err = dataprovider.AddAPIKey(&apiKey, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new API key %q, dump file: %q, error: %v", apiKey.KeyID, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore API key %q: %w", apiKey.KeyID, err)
		}
	}
	return nil
}

// RestoreAdmins restores the specified admins.
// Existing admins, matched by username, are skipped in mode 1 and updated
// otherwise; missing ones are added.
func RestoreAdmins(admins []dataprovider.Admin, inputFile string, mode int, executor, ipAddress, role string) error {
	for idx := range admins {
		admin := admins[idx]
		a, err := dataprovider.AdminExists(admin.Username)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing admin %q not updated", a.Username)
				continue
			}
			admin.ID = a.ID
			admin.Username = a.Username
			err = dataprovider.UpdateAdmin(&admin, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring existing admin %q, dump file: %q, error: %v", admin.Username, inputFile, err)
		} else {
			err = dataprovider.AddAdmin(&admin, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new admin %q, dump file: %q, error: %v", admin.Username, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore admin %q: %w", admin.Username, err)
		}
	}

	return nil
}

// RestoreConfigs restores the specified provider configs.
// A nil configs is a no-op. In mode 1, configs that were already saved
// (UpdatedAt > 0) are left untouched.
func RestoreConfigs(configs *dataprovider.Configs, mode int, executor, ipAddress,
	executorRole string,
) error {
	if configs == nil {
		return nil
	}
	c, err := dataprovider.GetConfigs()
	if err != nil {
		return fmt.Errorf("unable to restore configs, error loading existing from db: %w", err)
	}
	if c.UpdatedAt > 0 {
		if mode == 1 {
			logger.Debug(logSender, "", "loaddata mode 1, existing configs not updated")
			return nil
		}
	}
	return dataprovider.UpdateConfigs(configs, executor, ipAddress, executorRole)
}

// RestoreIPListEntries restores the specified IP list entries.
// Existing entries, matched by IP/network and list type, are skipped in mode 1
// and updated otherwise; missing ones are added.
func RestoreIPListEntries(entries []dataprovider.IPListEntry, inputFile string, mode int, executor, ipAddress,
	executorRole string,
) error {
	for idx := range entries {
		entry := entries[idx]
		e, err := dataprovider.IPListEntryExists(entry.IPOrNet, entry.Type)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing IP list entry %s-%s not updated",
					e.Type.AsString(), e.IPOrNet)
				continue
			}
			err = dataprovider.UpdateIPListEntry(&entry, executor, ipAddress, executorRole)
			logger.Debug(logSender, "", "restoring existing IP list entry: %s-%s, dump file: %q, error: %v",
				entry.Type.AsString(), entry.IPOrNet, inputFile, err)
		} else {
			err = dataprovider.AddIPListEntry(&entry, executor, ipAddress, executorRole)
			logger.Debug(logSender, "", "adding new IP list entry %s-%s, dump file: %q, error: %v",
				entry.Type.AsString(), entry.IPOrNet, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore IP list entry %s-%s: %w", entry.Type.AsString(), entry.IPOrNet, err)
		}
	}
	return nil
}

// RestoreRoles restores the specified roles.
// Existing roles, matched by name, are skipped in mode 1 and updated
// otherwise; missing ones are added.
func RestoreRoles(roles []dataprovider.Role, inputFile string, mode int, executor, ipAddress, executorRole string) error {
	for idx := range roles {
		role := roles[idx]
		r, err := dataprovider.RoleExists(role.Name)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing role %q not updated", r.Name)
				continue
			}
			role.ID = r.ID
			err = dataprovider.UpdateRole(&role, executor, ipAddress, executorRole)
			logger.Debug(logSender, "", "restoring existing role: %q, dump file: %q, error: %v", role.Name, inputFile, err)
		} else {
			err = dataprovider.AddRole(&role, executor, ipAddress, executorRole)
			logger.Debug(logSender, "", "adding new role: %q, dump file: %q, error: %v", role.Name, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore role %q: %w", role.Name, err)
		}
	}
	return nil
}

// RestoreGroups restores the specified groups.
// Existing groups, matched by name, are skipped in mode 1 and otherwise
// updated while keeping their current users; missing ones are added.
func RestoreGroups(groups []dataprovider.Group, inputFile string, mode int, executor, ipAddress, role string) error {
	for idx := range groups {
		group := groups[idx]
		g, err := dataprovider.GroupExists(group.Name)
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing group %q not updated", g.Name)
				continue
			}
			group.ID = g.ID
			group.Name = g.Name
			err = dataprovider.UpdateGroup(&group, g.Users, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring existing group: %q, dump file: %q, error: %v", group.Name, inputFile, err)
		} else {
			err = dataprovider.AddGroup(&group, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new group: %q, dump file: %q, error: %v", group.Name, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore group %q: %w", group.Name, err)
		}
	}
	return nil
}

// RestoreUsers restores the specified users
func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota int, executor, ipAddress, role string) error {
	for idx := range users {
		user := users[idx]
		u, err := dataprovider.UserExists(user.Username, "")
		if err == nil {
			if mode == 1 {
				logger.Debug(logSender, "", "loaddata mode 1, existing user %q not updated", u.Username)
				continue
			}
			user.ID = u.ID
			user.Username = u.Username
			err = dataprovider.UpdateUser(&user, executor, ipAddress, role)
			logger.Debug(logSender, "", "restoring existing user: %q, dump file: %q, error: %v", user.Username, inputFile, err)
			if mode == 2 && err == nil {
				disconnectUser(user.Username, executor, role)
			}
		} else {
			err = dataprovider.AddUser(&user, executor, ipAddress, role)
			logger.Debug(logSender, "", "adding new user: %q, dump file: %q, error: %v", user.Username, inputFile, err)
		}
		if err != nil {
			return fmt.Errorf("unable to restore user %q: %w", user.Username, err)
		}
		if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) {
			user, err = dataprovider.GetUserWithGroupSettings(user.Username, "")
			if err == nil && common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
				logger.Debug(logSender, "", "starting quota scan for restored 
user: %q", user.Username) go doUserQuotaScan(&user) //nolint:errcheck } } } return nil } ================================================ FILE: internal/httpd/api_mfa.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "bytes" "errors" "fmt" "io" "net/http" "slices" "strconv" "strings" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( errRecoveryCodeForbidden = errors.New("recovery codes are not available with two-factor authentication disabled") ) type generateTOTPRequest struct { ConfigName string `json:"config_name"` } type generateTOTPResponse struct { ConfigName string `json:"config_name"` Issuer string `json:"issuer"` Secret string `json:"secret"` URL string `json:"url"` QRCode []byte `json:"qr_code"` } type validateTOTPRequest struct { ConfigName string `json:"config_name"` Passcode string `json:"passcode"` Secret string `json:"secret"` } type recoveryCode struct { Code string `json:"code"` Used bool `json:"used"` } func getTOTPConfigs(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) render.JSON(w, r, mfa.GetAvailableTOTPConfigs()) } func generateTOTPSecret(w 
http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var accountName string if hasUserAudience(claims) { accountName = fmt.Sprintf("User %q", claims.Username) } else { accountName = fmt.Sprintf("Admin %q", claims.Username) } var req generateTOTPRequest err = render.DecodeJSON(r.Body, &req) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } configName, key, qrCode, err := mfa.GenerateTOTPSecret(req.ConfigName, accountName) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } render.JSON(w, r, generateTOTPResponse{ ConfigName: configName, Issuer: key.Issuer(), Secret: key.Secret(), URL: key.URL(), QRCode: qrCode, }) } func getQRCode(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) img, err := mfa.GenerateQRCodeFromURL(r.URL.Query().Get("url"), 400, 400) if err != nil { sendAPIResponse(w, r, nil, "unable to generate qr code", http.StatusInternalServerError) return } imgSize := int64(len(img)) w.Header().Set("Content-Length", strconv.FormatInt(imgSize, 10)) w.Header().Set("Content-Type", "image/png") io.CopyN(w, bytes.NewBuffer(img), imgSize) //nolint:errcheck } func saveTOTPConfig(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } recoveryCodes := make([]dataprovider.RecoveryCode, 0, 12) for i := 0; i < 12; i++ { code := getNewRecoveryCode() recoveryCodes = append(recoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)}) } baseURL := webBaseClientPath if hasUserAudience(claims) { if err := saveUserTOTPConfig(claims.Username, r, 
recoveryCodes); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } } else { if err := saveAdminTOTPConfig(claims.Username, r, recoveryCodes); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } baseURL = webBasePath } if claims.MustSetTwoFactorAuth { // force logout defer func() { removeCookie(w, r, baseURL) }() } sendAPIResponse(w, r, nil, "TOTP configuration saved", http.StatusOK) } func validateTOTPPasscode(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) var req validateTOTPRequest err := render.DecodeJSON(r.Body, &req) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } match, err := mfa.ValidateTOTPPasscode(req.ConfigName, req.Passcode, req.Secret) if !match || err != nil { sendAPIResponse(w, r, err, "Invalid passcode", http.StatusBadRequest) return } sendAPIResponse(w, r, nil, "Passcode successfully validated", http.StatusOK) } func getRecoveryCodes(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } recoveryCodes := make([]recoveryCode, 0, 12) var accountRecoveryCodes []dataprovider.RecoveryCode if hasUserAudience(claims) { user, err := dataprovider.UserExists(claims.Username, "") if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !user.Filters.TOTPConfig.Enabled { sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) return } accountRecoveryCodes = user.Filters.RecoveryCodes } else { admin, err := dataprovider.AdminExists(claims.Username) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !admin.Filters.TOTPConfig.Enabled { sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) return } accountRecoveryCodes = admin.Filters.RecoveryCodes 
} for _, code := range accountRecoveryCodes { if err := code.Secret.Decrypt(); err != nil { sendAPIResponse(w, r, err, "Unable to decrypt recovery codes", getRespStatus(err)) return } recoveryCodes = append(recoveryCodes, recoveryCode{ Code: code.Secret.GetPayload(), Used: code.Used, }) } render.JSON(w, r, recoveryCodes) } func generateRecoveryCodes(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } recoveryCodes := make([]string, 0, 12) accountRecoveryCodes := make([]dataprovider.RecoveryCode, 0, 12) for i := 0; i < 12; i++ { code := getNewRecoveryCode() recoveryCodes = append(recoveryCodes, code) accountRecoveryCodes = append(accountRecoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)}) } if hasUserAudience(claims) { user, err := dataprovider.UserExists(claims.Username, "") if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !user.Filters.TOTPConfig.Enabled { sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) return } user.Filters.RecoveryCodes = accountRecoveryCodes if err := dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } } else { admin, err := dataprovider.AdminExists(claims.Username) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !admin.Filters.TOTPConfig.Enabled { sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) return } admin.Filters.RecoveryCodes = accountRecoveryCodes if err := dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), admin.Role); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } } 
render.JSON(w, r, recoveryCodes) } func getNewRecoveryCode() string { return fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID())) } func saveUserTOTPConfig(username string, r *http.Request, recoveryCodes []dataprovider.RecoveryCode) error { user, userMerged, err := dataprovider.GetUserVariants(username, "") if err != nil { return err } currentTOTPSecret := user.Filters.TOTPConfig.Secret user.Filters.TOTPConfig.Secret = nil err = render.DecodeJSON(r.Body, &user.Filters.TOTPConfig) if err != nil { return util.NewValidationError(fmt.Sprintf("unable to decode JSON body: %v", err)) } if !user.Filters.TOTPConfig.Enabled && len(userMerged.Filters.TwoFactorAuthProtocols) > 0 { return util.NewValidationError("two-factor authentication must be enabled") } for _, p := range userMerged.Filters.TwoFactorAuthProtocols { if !slices.Contains(user.Filters.TOTPConfig.Protocols, p) { return util.NewValidationError(fmt.Sprintf("totp: the following protocols are required: %q", strings.Join(userMerged.Filters.TwoFactorAuthProtocols, ", "))) } } if user.Filters.TOTPConfig.Secret == nil || !user.Filters.TOTPConfig.Secret.IsPlain() { user.Filters.TOTPConfig.Secret = currentTOTPSecret } if user.Filters.TOTPConfig.Enabled { if user.CountUnusedRecoveryCodes() < 5 && user.Filters.TOTPConfig.Enabled { user.Filters.RecoveryCodes = recoveryCodes } } else { user.Filters.RecoveryCodes = nil } return dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role) } func saveAdminTOTPConfig(username string, r *http.Request, recoveryCodes []dataprovider.RecoveryCode) error { admin, err := dataprovider.AdminExists(username) if err != nil { return err } currentTOTPSecret := admin.Filters.TOTPConfig.Secret admin.Filters.TOTPConfig.Secret = nil err = render.DecodeJSON(r.Body, &admin.Filters.TOTPConfig) if err != nil { return util.NewValidationError(fmt.Sprintf("unable to decode JSON body: %v", err)) } if !admin.Filters.TOTPConfig.Enabled && 
admin.Filters.RequireTwoFactor { return util.NewValidationError("two-factor authentication must be enabled") } if admin.Filters.TOTPConfig.Enabled { if admin.CountUnusedRecoveryCodes() < 5 && admin.Filters.TOTPConfig.Enabled { admin.Filters.RecoveryCodes = recoveryCodes } } else { admin.Filters.RecoveryCodes = nil } if admin.Filters.TOTPConfig.Secret == nil || !admin.Filters.TOTPConfig.Secret.IsPlain() { admin.Filters.TOTPConfig.Secret = currentTOTPSecret } return dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), admin.Role) } ================================================ FILE: internal/httpd/api_quota.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "errors" "fmt" "net/http" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( quotaUpdateModeAdd = "add" quotaUpdateModeReset = "reset" ) type quotaUsage struct { UsedQuotaSize int64 `json:"used_quota_size"` UsedQuotaFiles int `json:"used_quota_files"` } type transferQuotaUsage struct { UsedUploadDataTransfer int64 `json:"used_upload_data_transfer"` UsedDownloadDataTransfer int64 `json:"used_download_data_transfer"` } func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } render.JSON(w, r, common.QuotaScans.GetUsersQuotaScans(claims.Role)) } func getFoldersQuotaScans(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) render.JSON(w, r, common.QuotaScans.GetVFoldersQuotaScans()) } func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) var usage quotaUsage err := render.DecodeJSON(r.Body, &usage) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } doUpdateUserQuotaUsage(w, r, getURLParam(r, "username"), usage) } func updateFolderQuotaUsage(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) var usage quotaUsage err := render.DecodeJSON(r.Body, &usage) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } doUpdateFolderQuotaUsage(w, r, getURLParam(r, "name"), usage) } func startUserQuotaScan(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) 
doStartUserQuotaScan(w, r, getURLParam(r, "username")) } func startFolderQuotaScan(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) doStartFolderQuotaScan(w, r, getURLParam(r, "name")) } func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var usage transferQuotaUsage err = render.DecodeJSON(r.Body, &usage) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } if usage.UsedUploadDataTransfer < 0 || usage.UsedDownloadDataTransfer < 0 { sendAPIResponse(w, r, errors.New("invalid used transfer quota parameters, negative values are not allowed"), "", http.StatusBadRequest) return } mode, err := getQuotaUpdateMode(r) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } user, err := dataprovider.GetUserWithGroupSettings(getURLParam(r, "username"), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if mode == quotaUpdateModeAdd && !user.HasTransferQuotaRestrictions() && dataprovider.GetQuotaTracking() == 2 { sendAPIResponse(w, r, errors.New("this user has no transfer quota restrictions, only reset mode is supported"), "", http.StatusBadRequest) return } err = dataprovider.UpdateUserTransferQuota(&user, usage.UsedUploadDataTransfer, usage.UsedDownloadDataTransfer, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) } func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username string, usage quotaUsage) { claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } 
if usage.UsedQuotaFiles < 0 || usage.UsedQuotaSize < 0 { sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), "", http.StatusBadRequest) return } mode, err := getQuotaUpdateMode(r) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } user, err := dataprovider.GetUserWithGroupSettings(username, claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if mode == quotaUpdateModeAdd && !user.HasQuotaRestrictions() && dataprovider.GetQuotaTracking() == 2 { sendAPIResponse(w, r, errors.New("this user has no quota restrictions, only reset mode is supported"), "", http.StatusBadRequest) return } if !common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) { sendAPIResponse(w, r, err, "A quota scan is in progress for this user", http.StatusConflict) return } defer common.QuotaScans.RemoveUserQuotaScan(user.Username) err = dataprovider.UpdateUserQuota(&user, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) } func doUpdateFolderQuotaUsage(w http.ResponseWriter, r *http.Request, name string, usage quotaUsage) { if usage.UsedQuotaFiles < 0 || usage.UsedQuotaSize < 0 { sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), "", http.StatusBadRequest) return } mode, err := getQuotaUpdateMode(r) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } folder, err := dataprovider.GetFolderByName(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) { sendAPIResponse(w, r, err, "A quota scan is in progress for this folder", http.StatusConflict) return } defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) err = dataprovider.UpdateVirtualFolderQuota(&folder, 
usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) } else { sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) } } func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username string) { if dataprovider.GetQuotaTracking() == 0 { sendAPIResponse(w, r, nil, "Quota tracking is disabled!", http.StatusForbidden) return } claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } user, err := dataprovider.GetUserWithGroupSettings(username, claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) { sendAPIResponse(w, r, nil, fmt.Sprintf("Another scan is already in progress for user %q", username), http.StatusConflict) return } go doUserQuotaScan(&user) //nolint:errcheck sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted) } func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string) { if dataprovider.GetQuotaTracking() == 0 { sendAPIResponse(w, r, nil, "Quota tracking is disabled!", http.StatusForbidden) return } folder, err := dataprovider.GetFolderByName(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) { sendAPIResponse(w, r, err, fmt.Sprintf("Another scan is already in progress for folder %q", name), http.StatusConflict) return } go doFolderQuotaScan(folder) //nolint:errcheck sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted) } func doUserQuotaScan(user *dataprovider.User) error { defer common.QuotaScans.RemoveUserQuotaScan(user.Username) numFiles, size, err := user.ScanQuota() if err != nil { logger.Warn(logSender, "", "error scanning user quota %q: %v", user.Username, err) return err } err = dataprovider.UpdateUserQuota(user, 
numFiles, size, true) logger.Debug(logSender, "", "user quota scanned, user: %q, error: %v", user.Username, err) return err } func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error { defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) f := vfs.VirtualFolder{ BaseVirtualFolder: folder, VirtualPath: "/", } numFiles, size, err := f.ScanQuota() if err != nil { logger.Warn(logSender, "", "error scanning folder %q: %v", folder.Name, err) return err } err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true) logger.Debug(logSender, "", "virtual folder %q scanned, error: %v", folder.Name, err) return err } func getQuotaUpdateMode(r *http.Request) (string, error) { mode := quotaUpdateModeReset if _, ok := r.URL.Query()["mode"]; ok { mode = r.URL.Query().Get("mode") if mode != quotaUpdateModeReset && mode != quotaUpdateModeAdd { return "", errors.New("invalid mode") } } return mode, nil } ================================================ FILE: internal/httpd/api_retention.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "net/http" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/jwt" ) func getRetentionChecks(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } render.JSON(w, r, common.RetentionChecks.Get(claims.Role)) } ================================================ FILE: internal/httpd/api_role.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "context" "fmt" "net/http" "net/url" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) func getRoles(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) limit, offset, order, err := getSearchFilters(w, r) if err != nil { return } roles, err := dataprovider.GetRoles(limit, offset, order, false) if err != nil { sendAPIResponse(w, r, err, "", http.StatusInternalServerError) return } render.JSON(w, r, roles) } func addRole(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var role dataprovider.Role err = render.DecodeJSON(r.Body, &role) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } err = dataprovider.AddRole(&role, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) } else { w.Header().Add("Location", fmt.Sprintf("%s/%s", rolesPath, url.PathEscape(role.Name))) renderRole(w, r, role.Name, http.StatusCreated) } } func updateRole(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") role, err := dataprovider.RoleExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } var updatedRole dataprovider.Role err = render.DecodeJSON(r.Body, &updatedRole) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } updatedRole.ID = role.ID updatedRole.Name = role.Name err = 
dataprovider.UpdateRole(&updatedRole, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, nil, "Role updated", http.StatusOK) } func renderRole(w http.ResponseWriter, r *http.Request, name string, status int) { role, err := dataprovider.RoleExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if status != http.StatusOK { ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) render.JSON(w, r.WithContext(ctx), role) } else { render.JSON(w, r, role) } } func getRoleByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) name := getURLParam(r, "name") renderRole(w, r, name, http.StatusOK) } func deleteRole(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") err = dataprovider.DeleteRole(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, err, "Role deleted", http.StatusOK) } ================================================ FILE: internal/httpd/api_shares.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "context" "errors" "fmt" "net/http" "net/url" "os" "path" "slices" "strings" "time" "github.com/go-chi/render" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) func getShares(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } limit, offset, order, err := getSearchFilters(w, r) if err != nil { return } shares, err := dataprovider.GetShares(limit, offset, order, claims.Username) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } render.JSON(w, r, shares) } func getShareByID(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } shareID := getURLParam(r, "id") share, err := dataprovider.ShareExists(shareID, claims.Username) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } share.HideConfidentialData() render.JSON(w, r, share) } func addShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") if err != nil { sendAPIResponse(w, r, err, "Unable to 
retrieve your user", getRespStatus(err)) return } var share dataprovider.Share if user.Filters.DefaultSharesExpiration > 0 { share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(user.Filters.DefaultSharesExpiration))) } err = render.DecodeJSON(r.Body, &share) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(share.ExpiresAt)); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } share.ID = 0 share.ShareID = util.GenerateUniqueID() share.LastUseAt = 0 share.Username = claims.Username if share.Name == "" { share.Name = share.ShareID } if share.Password == "" { if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password", http.StatusForbidden) return } } err = dataprovider.AddShare(&share, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } w.Header().Add("Location", fmt.Sprintf("%s/%s", userSharesPath, url.PathEscape(share.ShareID))) w.Header().Add("X-Object-ID", share.ShareID) sendAPIResponse(w, r, nil, "Share created", http.StatusCreated) } func updateShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") if err != nil { sendAPIResponse(w, r, err, "Unable to retrieve your user", getRespStatus(err)) return } shareID := getURLParam(r, "id") share, err := dataprovider.ShareExists(shareID, claims.Username) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } var updatedShare dataprovider.Share err = 
render.DecodeJSON(r.Body, &updatedShare) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } updatedShare.ShareID = shareID updatedShare.Username = claims.Username if updatedShare.Password == redactedSecret { updatedShare.Password = share.Password } if updatedShare.Password == "" { if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password", http.StatusForbidden) return } } if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(updatedShare.ExpiresAt)); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } err = dataprovider.UpdateShare(&updatedShare, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, nil, "Share updated", http.StatusOK) } func deleteShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) shareID := getURLParam(r, "id") claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } err = dataprovider.DeleteShare(shareID, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } sendAPIResponse(w, r, err, "Share deleted", http.StatusOK) } func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite} share, connection, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } if err := validateBrowsableShare(share, connection); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } name, err := 
getBrowsableSharedPath(share.Paths[0], r) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) lister, err := connection.ReadDir(name) if err != nil { sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err)) return } renderAPIDirContents(w, lister, true) } func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite} share, connection, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } if err := validateBrowsableShare(share, connection); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } name, err := getBrowsableSharedPath(share.Paths[0], r) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) info, err := connection.Stat(name, 1) if err != nil { sendAPIResponse(w, r, err, "Unable to stat the requested file", getMappedStatusCode(err)) return } if info.IsDir() { sendAPIResponse(w, r, nil, fmt.Sprintf("Please set the path to a valid file, %q is a directory", name), http.StatusBadRequest) return } inline := r.URL.Query().Get("inline") != "" dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck if status, err := downloadFile(w, r, connection, name, info, inline, &share); err != nil { dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck resp := apiResponse{ Error: err.Error(), Message: http.StatusText(status), } ctx := r.Context() if status != 0 { ctx = 
context.WithValue(ctx, render.StatusCtxKey, status) } render.JSON(w, r.WithContext(ctx), resp) } } func (s *httpdServer) downloadFromShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite} share, connection, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) compress := true var info os.FileInfo if len(share.Paths) == 1 { info, err = connection.Stat(share.Paths[0], 1) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } if info.Mode().IsRegular() && r.URL.Query().Get("compress") == "false" { compress = false } } dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck if compress { transferQuota := connection.GetTransferQuota() if !transferQuota.HasDownloadSpace() { err = connection.GetReadQuotaExceededError() connection.Log(logger.LevelInfo, "denying share read due to quota limits") sendAPIResponse(w, r, err, "", getMappedStatusCode(err)) dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck return } baseDir := "/" if info != nil && info.IsDir() { baseDir = share.Paths[0] share.Paths[0] = "/" } w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"share-%v.zip\"", share.Name)) renderCompressedFiles(w, connection, baseDir, share.Paths, &share) return } if status, err := downloadFile(w, r, connection, share.Paths[0], info, false, &share); err != nil { dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck resp := apiResponse{ Error: err.Error(), Message: http.StatusText(status), } ctx := r.Context() if status != 0 { ctx = context.WithValue(ctx, render.StatusCtxKey, status) } render.JSON(w, r.WithContext(ctx), resp) } } func (s 
*httpdServer) uploadFileToShare(w http.ResponseWriter, r *http.Request) { if maxUploadFileSize > 0 { r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize) } name := getURLParam(r, "name") validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeWrite, dataprovider.ShareScopeReadWrite} share, connection, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } filePath := util.CleanPath(path.Join(share.Paths[0], name)) expectedPrefix := share.Paths[0] if !strings.HasSuffix(expectedPrefix, "/") { expectedPrefix += "/" } if !strings.HasPrefix(filePath, expectedPrefix) { sendAPIResponse(w, r, err, "Uploading outside the share is not allowed", http.StatusForbidden) return } dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) connection.User.CheckFsRoot(connection.ID) //nolint:errcheck if getBoolQueryParam(r, "mkdir_parents") { if err = connection.CheckParentDirs(path.Dir(filePath)); err != nil { sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err)) return } } if err := doUploadFile(w, r, connection, filePath); err != nil { dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck } } func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request) { if maxUploadFileSize > 0 { r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize) } validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeWrite, dataprovider.ShareScopeReadWrite} share, connection, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil { connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits") sendAPIResponse(w, r, err, "Denying file write due to transfer count limits", 
http.StatusConflict) return } transferQuota := connection.GetTransferQuota() if !transferQuota.HasUploadSpace() { connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits") sendAPIResponse(w, r, common.ErrQuotaExceeded, "Denying file write due to transfer quota limits", http.StatusRequestEntityTooLarge) return } if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection) r.Body = t err = r.ParseMultipartForm(maxMultipartMem) if err != nil { connection.RemoveTransfer(t) sendAPIResponse(w, r, err, "Unable to parse multipart form", http.StatusBadRequest) return } connection.RemoveTransfer(t) defer r.MultipartForm.RemoveAll() //nolint:errcheck files := r.MultipartForm.File["filenames"] if len(files) == 0 { sendAPIResponse(w, r, nil, "No files uploaded!", http.StatusBadRequest) return } if share.MaxTokens > 0 { if len(files) > (share.MaxTokens - share.UsedTokens) { sendAPIResponse(w, r, nil, "Allowed usage exceeded", http.StatusBadRequest) return } } dataprovider.UpdateShareLastUse(&share, len(files)) //nolint:errcheck connection.User.CheckFsRoot(connection.ID) //nolint:errcheck numUploads := doUploadFiles(w, r, connection, share.Paths[0], files) if numUploads != len(files) { dataprovider.UpdateShareLastUse(&share, numUploads-len(files)) //nolint:errcheck } } func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwt.Claims, error) { token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie) if err != nil || token == nil { return nil, nil, errInvalidToken } tokenString := jwt.TokenFromCookie(r) if tokenString == "" || invalidatedJWTTokens.Get(tokenString) { return nil, nil, errInvalidToken } if !token.Audience.Contains(tokenAudienceWebShare) { logger.Debug(logSender, "", "invalid token 
audience for share %q", shareID) return nil, nil, errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil { logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", shareID, ipAddr) return nil, nil, err } if token.Username != shareID { logger.Debug(logSender, "", "token not valid for share %q", shareID) return nil, nil, errInvalidToken } ctx := jwt.NewContext(r.Context(), token, nil) return ctx, token, nil } func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) error { doRedirect := func() { redirectURL := path.Join(webClientPubSharesPath, share.ShareID, fmt.Sprintf("login?next=%s", url.QueryEscape(r.RequestURI))) http.Redirect(w, r, redirectURL, http.StatusFound) } if _, _, err := s.getShareClaims(r, share.ShareID); err != nil { doRedirect() return err } return nil } func (s *httpdServer) checkPublicShare(w http.ResponseWriter, r *http.Request, validScopes []dataprovider.ShareScope, ) (dataprovider.Share, *Connection, error) { isWebClient := isWebClientRequest(r) renderError := func(err error, message string, statusCode int) { if isWebClient { s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, statusCode, err, message) } else { sendAPIResponse(w, r, err, message, statusCode) } } shareID := getURLParam(r, "id") share, err := dataprovider.ShareExists(shareID, "") if err != nil { statusCode := getRespStatus(err) if statusCode == http.StatusNotFound { err = util.NewI18nError(errors.New("share does not exist"), util.I18nError404Message) } renderError(err, "", statusCode) return share, nil, err } if !slices.Contains(validScopes, share.Scope) { err := errors.New("invalid share scope") renderError(util.NewI18nError(err, util.I18nErrorShareScope), "", http.StatusForbidden) return share, nil, err } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ok, err := share.IsUsable(ipAddr) if !ok || err != 
nil { renderError(err, "", getRespStatus(err)) return share, nil, err } if share.Password != "" { if isWebClient { if err := s.checkWebClientShareCredentials(w, r, &share); err != nil { return share, nil, dataprovider.ErrInvalidCredentials } } else { _, password, ok := r.BasicAuth() if !ok { w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) renderError(dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return share, nil, dataprovider.ErrInvalidCredentials } match, err := share.CheckCredentials(password) if !match || err != nil { handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) renderError(dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return share, nil, dataprovider.ErrInvalidCredentials } } common.DelayLogin(nil) } user, err := getUserForShare(share) if err != nil { renderError(err, "", getRespStatus(err)) return share, nil, err } connID := xid.New().String() baseConn := common.NewBaseConnection(connID, common.ProtocolHTTPShare, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) connection := newConnection(baseConn, w, r) return share, connection, nil } func getUserForShare(share dataprovider.Share) (dataprovider.User, error) { user, err := dataprovider.GetUserWithGroupSettings(share.Username, "") if err != nil { return user, err } if !user.CanManageShares() { return user, util.NewI18nError(util.NewRecordNotFoundError("this share does not exist"), util.I18nError404Message) } if share.Password == "" && slices.Contains(user.Filters.WebClient, sdk.WebClientShareNoPasswordDisabled) { return user, util.NewI18nError( fmt.Errorf("sharing without a password was disabled: %w", os.ErrPermission), util.I18nError403Message, ) } if user.MustSetSecondFactorForProtocol(common.ProtocolHTTP) { return user, util.NewI18nError( util.NewMethodDisabledError("two-factor 
authentication requirements not met"), util.I18nError403Message, ) } return user, nil } func validateBrowsableShare(share dataprovider.Share, connection *Connection) error { if len(share.Paths) != 1 { return util.NewI18nError( util.NewValidationError("a share with multiple paths is not browsable"), util.I18nErrorShareBrowsePaths, ) } basePath := share.Paths[0] info, err := connection.Stat(basePath, 0) if err != nil { connection.CloseFS() //nolint:errcheck return util.NewI18nError( fmt.Errorf("unable to check the share directory: %w", err), util.I18nErrorShareInvalidPath, ) } if !info.IsDir() { return util.NewI18nError( util.NewValidationError("the shared object is not a directory and so it is not browsable"), util.I18nErrorShareBrowseNoDir, ) } return nil } func getBrowsableSharedPath(shareBasePath string, r *http.Request) (string, error) { name := util.CleanPath(path.Join(shareBasePath, r.URL.Query().Get("path"))) if shareBasePath == "/" { return name, nil } if name != shareBasePath && !strings.HasPrefix(name, shareBasePath+"/") { return "", util.NewI18nError( util.NewValidationError(fmt.Sprintf("Invalid path %q", r.URL.Query().Get("path"))), util.I18nErrorPathInvalid, ) } return name, nil } ================================================ FILE: internal/httpd/api_user.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"strconv"
	"time"

	"github.com/go-chi/render"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// getUsers renders, as JSON, the users selected by the search filters,
// restricted to the role carried by the admin's token claims.
func getUsers(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		// getSearchFilters receives w and presumably already replied
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	users, err := dataprovider.GetUsers(limit, offset, order, claims.Role)
	if err == nil {
		render.JSON(w, r, users)
	} else {
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
	}
}

// getUserByUsername renders the user named by the "username" URL parameter.
func getUserByUsername(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	username := getURLParam(r, "username")
	renderUser(w, r, username, claims, http.StatusOK)
}

// renderUser looks up username, scoped to claims.Role, and writes it as JSON.
// user.PrepareForRendering is called when hideConfidentialData(claims, r)
// says so. A non-200 status is passed to render.JSON via render.StatusCtxKey.
func renderUser(w http.ResponseWriter, r *http.Request, username string, claims *jwt.Claims, status int) {
	user, err := dataprovider.UserExists(username, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hideConfidentialData(claims, r) {
		user.PrepareForRendering()
	}
	if status != http.StatusOK {
		ctx := context.WithValue(r.Context(), render.StatusCtxKey, status)
		render.JSON(w, r.WithContext(ctx), user)
	} else {
		render.JSON(w, r, user)
	}
}

// addUser creates a user from the JSON request body.
//
// If the admin has a default users expiration (in days), it is pre-set
// before decoding, so the body can still override it. A role-bound admin can
// only create users in its own role. Password-change time, recovery codes and
// TOTP settings are always reset, whatever the body says. On success it sets
// Location and renders the user with 201 Created.
func addUser(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	var user dataprovider.User
	if admin.Filters.Preferences.DefaultUsersExpiration > 0 {
		user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
	}
	err = render.DecodeJSON(r.Body, &user)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if claims.Role != "" {
		user.Role = claims.Role
	}
	user.LastPasswordChange = 0
	user.Filters.RecoveryCodes = nil
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled: false,
	}
	err = dataprovider.AddUser(&user, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", userPath, url.PathEscape(user.Username)))
	renderUser(w, r, user.Username, claims, http.StatusCreated)
}

// disableUser2FA clears the TOTP configuration and recovery codes of the user
// named by the "username" URL parameter. It answers 400 if 2FA is not enabled.
func disableUser2FA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	username := getURLParam(r, "username")
	user, err := dataprovider.UserExists(username, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if !user.Filters.TOTPConfig.Enabled {
		sendAPIResponse(w, r, nil, "two-factor authentication is not enabled", http.StatusBadRequest)
		return
	}
	user.Filters.RecoveryCodes = nil
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled: false,
	}
	if err := dataprovider.UpdateUser(&user, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "2FA disabled", http.StatusOK)
}

// updateUser replaces the user named by the "username" URL parameter with the
// JSON request body. The optional "disconnect" query parameter must be an
// integer; when it equals 1 the user's sessions are closed after the response
// has been written.
func updateUser(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	username := getURLParam(r, "username")
	disconnect := 0
	if _, ok := r.URL.Query()["disconnect"]; ok {
		disconnect, err = strconv.Atoi(r.URL.Query().Get("disconnect"))
		if err != nil {
			err = fmt.Errorf("invalid disconnect parameter: %v", err)
			sendAPIResponse(w, r, err, "", http.StatusBadRequest)
			return
		}
	}
	user, err := dataprovider.UserExists(username, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedUser dataprovider.User
	// Seeded before decoding, presumably so that a body without a password
	// keeps the stored one (decoding only overwrites fields present in the body).
	updatedUser.Password = user.Password
	err = render.DecodeJSON(r.Body, &updatedUser)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	// Identity and 2FA state can never be changed through this endpoint.
	updatedUser.ID = user.ID
	updatedUser.Username = user.Username
	updatedUser.Filters.RecoveryCodes = user.Filters.RecoveryCodes
	updatedUser.Filters.TOTPConfig = user.Filters.TOTPConfig
	updatedUser.LastPasswordChange = user.LastPasswordChange
	updatedUser.SetEmptySecretsIfNil()
	// keep the stored filesystem secrets unless new plain values were sent
	updateEncryptedSecrets(&updatedUser.FsConfig, &user.FsConfig)
	if claims.Role != "" {
		updatedUser.Role = claims.Role
	}
	err = dataprovider.UpdateUser(&updatedUser, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "User updated", http.StatusOK)
	if disconnect == 1 {
		disconnectUser(user.Username, claims.Username, claims.Role)
	}
}

// deleteUser removes the user named by the "username" URL parameter and then
// closes its sessions, using the name as normalized by dataprovider.ConvertName.
func deleteUser(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	username := getURLParam(r, "username")
	err = dataprovider.DeleteUser(username, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	sendAPIResponse(w, r, err, "User deleted", http.StatusOK)
	disconnectUser(dataprovider.ConvertName(username), claims.Username, claims.Role)
}

// forgotUserPassword starts the password reset flow for the user named by the
// "username" URL parameter. It requires an SMTP configuration.
func forgotUserPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	if !smtp.IsEnabled() {
		sendAPIResponse(w, r, nil, "No SMTP configuration", http.StatusBadRequest)
		return
	}

	err := handleForgotPassword(r, getURLParam(r, "username"), false)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	sendAPIResponse(w, r, err, "Check your email for the confirmation code", http.StatusOK)
}

// resetUserPassword completes the password reset using the confirmation code
// and new password from the JSON body. The password is passed twice,
// presumably as both the new value and its confirmation.
func resetUserPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req pwdReset
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	_, _, err = handleResetPassword(r, req.Code, req.Password, req.Password, false)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Password reset successful", http.StatusOK)
}

// disconnectUser closes every local connection for username. It then asks
// each node reported by getNodesConnections(admin, role) to close that user's
// connections there. Node errors are logged and skipped.
func disconnectUser(username, admin, role string) {
	for _, stat := range common.Connections.GetStats("") {
		if stat.Username == username {
			common.Connections.Close(stat.ConnectionID, "")
		}
	}
	for _, stat := range getNodesConnections(admin, role) {
		if stat.Username == username {
			n, err := dataprovider.GetNodeByName(stat.Node)
			if err != nil {
				logger.Warn(logSender, "", "unable to disconnect user %q, error getting node %q: %v", username, stat.Node, err)
				continue
			}
			perms := []string{dataprovider.PermAdminCloseConnections}
			uri := fmt.Sprintf("%s/%s", activeConnectionsPath, stat.ConnectionID)
			if err := n.SendDeleteRequest(admin, role, uri, perms); err != nil {
				logger.Warn(logSender, "", "unable to disconnect user %q from node %q, error: %v", username, n.Name, err)
			}
		}
	}
}

// updateEncryptedSecrets copies stored secrets from currentFsConfig into
// fsConfig when the submitted value is not a new plain-text secret, so that
// redacted/encrypted values sent back by clients do not replace the real ones.
func updateEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Filesystem) {
	// we use the new access secret if plain or empty, otherwise the old value
	switch fsConfig.Provider {
	case sdk.S3FilesystemProvider:
		if fsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() {
			fsConfig.S3Config.AccessSecret = currentFsConfig.S3Config.AccessSecret
		}
		if fsConfig.S3Config.SSECustomerKey.IsNotPlainAndNotEmpty() {
			fsConfig.S3Config.SSECustomerKey = currentFsConfig.S3Config.SSECustomerKey
		}
	case sdk.AzureBlobFilesystemProvider:
		if fsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() {
			fsConfig.AzBlobConfig.AccountKey = currentFsConfig.AzBlobConfig.AccountKey
		}
		if fsConfig.AzBlobConfig.SASURL.IsNotPlainAndNotEmpty() {
			fsConfig.AzBlobConfig.SASURL = currentFsConfig.AzBlobConfig.SASURL
		}
	case sdk.GCSFilesystemProvider:
		// GCS credentials are cleared when automatic credentials are enabled,
		// so keep the stored credentials unless new plain credentials are provided
		// (unlike the other providers, an empty value also keeps the old one).
		if !fsConfig.GCSConfig.Credentials.IsPlain() {
			fsConfig.GCSConfig.Credentials = currentFsConfig.GCSConfig.Credentials
		}
	case sdk.CryptedFilesystemProvider:
		if fsConfig.CryptConfig.Passphrase.IsNotPlainAndNotEmpty() {
			fsConfig.CryptConfig.Passphrase = currentFsConfig.CryptConfig.Passphrase
		}
	case sdk.SFTPFilesystemProvider:
		updateSFTPFsEncryptedSecrets(fsConfig, currentFsConfig)
	case sdk.HTTPFilesystemProvider:
		updateHTTPFsEncryptedSecrets(fsConfig, currentFsConfig)
	}
}

// updateSFTPFsEncryptedSecrets applies the "keep unless new plain value" rule
// to the SFTP password, private key and key passphrase.
func updateSFTPFsEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Filesystem) {
	if fsConfig.SFTPConfig.Password.IsNotPlainAndNotEmpty() {
		fsConfig.SFTPConfig.Password = currentFsConfig.SFTPConfig.Password
	}
	if fsConfig.SFTPConfig.PrivateKey.IsNotPlainAndNotEmpty() {
		fsConfig.SFTPConfig.PrivateKey = currentFsConfig.SFTPConfig.PrivateKey
	}
	if fsConfig.SFTPConfig.KeyPassphrase.IsNotPlainAndNotEmpty() {
		fsConfig.SFTPConfig.KeyPassphrase = currentFsConfig.SFTPConfig.KeyPassphrase
	}
}

// updateHTTPFsEncryptedSecrets applies the "keep unless new plain value" rule
// to the HTTP filesystem password and API key.
func updateHTTPFsEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Filesystem) {
	if fsConfig.HTTPConfig.Password.IsNotPlainAndNotEmpty() {
		fsConfig.HTTPConfig.Password = currentFsConfig.HTTPConfig.Password
	}
	if fsConfig.HTTPConfig.APIKey.IsNotPlainAndNotEmpty() {
		fsConfig.HTTPConfig.APIKey = currentFsConfig.HTTPConfig.APIKey
	}
}

================================================
FILE: internal/httpd/api_utils.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "io/fs" "mime" "net/http" "net/url" "os" "path" "slices" "strconv" "strings" "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/klauspost/compress/zip" "github.com/rs/xid" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) type pwdChange struct { CurrentPassword string `json:"current_password"` NewPassword string `json:"new_password"` } type pwdReset struct { Code string `json:"code"` Password string `json:"password"` } type baseProfile struct { Email string `json:"email,omitempty"` Description string `json:"description,omitempty"` AllowAPIKeyAuth bool `json:"allow_api_key_auth"` } type adminProfile struct { baseProfile } type userProfile struct { baseProfile AdditionalEmails []string `json:"additional_emails,omitempty"` PublicKeys []string `json:"public_keys,omitempty"` TLSCerts []string `json:"tls_certs,omitempty"` } func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { var errorString string if errors.Is(err, util.ErrNotFound) { errorString = http.StatusText(http.StatusNotFound) } else if err != nil { errorString = err.Error() } resp := apiResponse{ Error: errorString, Message: message, } ctx := context.WithValue(r.Context(), render.StatusCtxKey, code) render.JSON(w, r.WithContext(ctx), resp) } func getRespStatus(err error) int { if errors.Is(err, util.ErrValidation) { return http.StatusBadRequest } if errors.Is(err, util.ErrMethodDisabled) { return http.StatusForbidden } 
if errors.Is(err, util.ErrNotFound) { return http.StatusNotFound } if errors.Is(err, fs.ErrNotExist) { return http.StatusBadRequest } if errors.Is(err, fs.ErrPermission) || errors.Is(err, dataprovider.ErrLoginNotAllowedFromIP) { return http.StatusForbidden } if errors.Is(err, plugin.ErrNoSearcher) || errors.Is(err, dataprovider.ErrNotImplemented) { return http.StatusNotImplemented } if errors.Is(err, dataprovider.ErrDuplicatedKey) || errors.Is(err, dataprovider.ErrForeignKeyViolated) { return http.StatusConflict } return http.StatusInternalServerError } // mappig between fs errors for HTTP protocol and HTTP response status codes func getMappedStatusCode(err error) int { var statusCode int switch { case errors.Is(err, fs.ErrPermission): statusCode = http.StatusForbidden case errors.Is(err, common.ErrReadQuotaExceeded): statusCode = http.StatusForbidden case errors.Is(err, fs.ErrNotExist): statusCode = http.StatusNotFound case errors.Is(err, common.ErrQuotaExceeded): statusCode = http.StatusRequestEntityTooLarge case errors.Is(err, common.ErrOpUnsupported): statusCode = http.StatusBadRequest default: if _, ok := err.(*http.MaxBytesError); ok { statusCode = http.StatusRequestEntityTooLarge } else { statusCode = http.StatusInternalServerError } } return statusCode } func getURLParam(r *http.Request, key string) string { v := chi.URLParam(r, key) unescaped, err := url.PathUnescape(v) if err != nil { return v } return unescaped } func getURLPath(r *http.Request) string { rctx := chi.RouteContext(r.Context()) if rctx != nil && rctx.RoutePath != "" { return rctx.RoutePath } return r.URL.Path } func getCommaSeparatedQueryParam(r *http.Request, key string) []string { var result []string for val := range strings.SplitSeq(r.URL.Query().Get(key), ",") { val = strings.TrimSpace(val) if val != "" { result = append(result, val) } } return util.RemoveDuplicates(result, false) } func getBoolQueryParam(r *http.Request, param string) bool { return r.URL.Query().Get(param) == "true" } 
func getActiveConnections(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } stats := common.Connections.GetStats(claims.Role) if claims.NodeID == "" { stats = append(stats, getNodesConnections(claims.Username, claims.Role)...) } render.JSON(w, r, stats) } func handleCloseConnection(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } connectionID := getURLParam(r, "connectionID") if connectionID == "" { sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest) return } node := r.URL.Query().Get("node") if node == "" || node == dataprovider.GetNodeName() { if common.Connections.Close(connectionID, claims.Role) { sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK) } else { sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) } return } n, err := dataprovider.GetNodeByName(node) if err != nil { logger.Warn(logSender, "", "unable to get node with name %q: %v", node, err) status := getRespStatus(err) sendAPIResponse(w, r, nil, http.StatusText(status), status) return } perms := []string{dataprovider.PermAdminCloseConnections} uri := fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID) if err := n.SendDeleteRequest(claims.Username, claims.Role, uri, perms); err != nil { logger.Warn(logSender, "", "unable to delete connection id %q from node %q: %v", connectionID, n.Name, err) sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) return } sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK) } // getNodesConnections returns the active connections from other nodes. 
// Errors are silently ignored func getNodesConnections(admin, role string) []common.ConnectionStatus { nodes, err := dataprovider.GetNodes() if err != nil || len(nodes) == 0 { return nil } var results []common.ConnectionStatus var mu sync.Mutex var wg sync.WaitGroup for _, n := range nodes { wg.Add(1) go func(node dataprovider.Node) { defer wg.Done() var stats []common.ConnectionStatus perms := []string{dataprovider.PermAdminViewConnections} if err := node.SendGetRequest(admin, role, activeConnectionsPath, perms, &stats); err != nil { logger.Warn(logSender, "", "unable to get connections from node %s: %v", node.Name, err) return } mu.Lock() results = append(results, stats...) mu.Unlock() }(n) } wg.Wait() return results } func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) { var err error limit := 100 offset := 0 order := dataprovider.OrderASC if _, ok := r.URL.Query()["limit"]; ok { limit, err = strconv.Atoi(r.URL.Query().Get("limit")) if err != nil { err = errors.New("invalid limit") sendAPIResponse(w, r, err, "", http.StatusBadRequest) return limit, offset, order, err } if limit > 500 { limit = 500 } } if _, ok := r.URL.Query()["offset"]; ok { offset, err = strconv.Atoi(r.URL.Query().Get("offset")) if err != nil { err = errors.New("invalid offset") sendAPIResponse(w, r, err, "", http.StatusBadRequest) return limit, offset, order, err } } if _, ok := r.URL.Query()["order"]; ok { order = r.URL.Query().Get("order") if order != dataprovider.OrderASC && order != dataprovider.OrderDESC { err = errors.New("invalid order") sendAPIResponse(w, r, err, "", http.StatusBadRequest) return limit, offset, order, err } } return limit, offset, order, err } func renderAPIDirContents(w http.ResponseWriter, lister vfs.DirLister, omitNonRegularFiles bool) { defer lister.Close() dataGetter := func(limit, _ int) ([]byte, int, error) { contents, err := lister.Next(limit) if errors.Is(err, io.EOF) { err = nil } if err != nil { return nil, 0, err } 
results := make([]map[string]any, 0, len(contents)) for _, info := range contents { if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() { continue } res := make(map[string]any) res["name"] = info.Name() if info.Mode().IsRegular() { res["size"] = info.Size() } res["mode"] = info.Mode() res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339) results = append(results, res) } data, err := json.Marshal(results) count := limit if len(results) == 0 { count = 0 } return data, count, err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func streamData(w io.Writer, data []byte) { b := bytes.NewBuffer(data) _, err := io.CopyN(w, b, int64(len(data))) if err != nil { panic(http.ErrAbortHandler) } } func streamJSONArray(w http.ResponseWriter, chunkSize int, dataGetter func(limit, offset int) ([]byte, int, error)) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Accept-Ranges", "none") w.WriteHeader(http.StatusOK) streamData(w, []byte("[")) offset := 0 for { data, count, err := dataGetter(chunkSize, offset) if err != nil { panic(http.ErrAbortHandler) } if count == 0 { break } if offset > 0 { streamData(w, []byte(",")) } streamData(w, data[1:len(data)-1]) if count < chunkSize { break } offset += count } streamData(w, []byte("]")) } func renderPNGImage(w http.ResponseWriter, r *http.Request, b []byte) { if len(b) == 0 { ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusNotFound) render.PlainText(w, r.WithContext(ctx), http.StatusText(http.StatusNotFound)) return } w.Header().Set("Content-Type", "image/png") streamData(w, b) } func getCompressedFileName(username string, files []string) string { if len(files) == 1 { name := path.Base(files[0]) return fmt.Sprintf("%s-%s.zip", username, strings.TrimSuffix(name, path.Ext(name))) } return fmt.Sprintf("%s-download.zip", username) } func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir string, files []string, share *dataprovider.Share, ) { 
conn.User.CheckFsRoot(conn.ID) //nolint:errcheck w.Header().Set("Content-Type", "application/zip") w.Header().Set("Accept-Ranges", "none") w.Header().Set("Content-Transfer-Encoding", "binary") w.WriteHeader(http.StatusOK) wr := zip.NewWriter(w) for _, file := range files { fullPath := util.CleanPath(path.Join(baseDir, file)) if err := addZipEntry(wr, conn, fullPath, baseDir, nil, 0); err != nil { if share != nil { dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck } panic(http.ErrAbortHandler) } } if err := wr.Close(); err != nil { conn.Log(logger.LevelError, "unable to close zip file: %v", err) if share != nil { dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck } panic(http.ErrAbortHandler) } } func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string, info os.FileInfo, recursion int) error { if recursion >= util.MaxRecursion { conn.Log(logger.LevelDebug, "unable to add zip entry %q, recursion too depth: %d", entryPath, recursion) return util.ErrRecursionTooDeep } recursion++ var err error if info == nil { info, err = conn.Stat(entryPath, 1) if err != nil { conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err) return err } } entryName, err := getZipEntryName(entryPath, baseDir) if err != nil { conn.Log(logger.LevelError, "unable to get zip entry name: %v", err) return err } if info.IsDir() { _, err = wr.CreateHeader(&zip.FileHeader{ Name: entryName + "/", Method: zip.Deflate, Modified: info.ModTime(), }) if err != nil { conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) return err } lister, err := conn.ReadDir(entryPath) if err != nil { conn.Log(logger.LevelDebug, "unable to add zip entry %q, get list dir error: %v", entryPath, err) return err } defer lister.Close() for { contents, err := lister.Next(vfs.ListerBatchSize) finished := errors.Is(err, io.EOF) if err != nil && !finished { return err } for _, info := range contents { fullPath := 
util.CleanPath(path.Join(entryPath, info.Name())) if err := addZipEntry(wr, conn, fullPath, baseDir, info, recursion); err != nil { return err } } if finished { return nil } } } if !info.Mode().IsRegular() { // we only allow regular files conn.Log(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) return nil } return addFileToZipEntry(wr, conn, entryPath, entryName, info) } func addFileToZipEntry(wr *zip.Writer, conn *Connection, entryPath, entryName string, info os.FileInfo) error { reader, err := conn.getFileReader(entryPath, 0, http.MethodGet) if err != nil { conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err) return err } defer reader.Close() f, err := wr.CreateHeader(&zip.FileHeader{ Name: entryName, Method: zip.Deflate, Modified: info.ModTime(), }) if err != nil { conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) return err } _, err = io.Copy(f, reader) return err } func getZipEntryName(entryPath, baseDir string) (string, error) { if !strings.HasPrefix(entryPath, baseDir) { return "", fmt.Errorf("entry path %q is outside base dir %q", entryPath, baseDir) } entryPath = strings.TrimPrefix(entryPath, baseDir) return strings.TrimPrefix(entryPath, "/"), nil } func checkDownloadFileFromShare(share *dataprovider.Share, info os.FileInfo) error { if share != nil && !info.Mode().IsRegular() { return util.NewValidationError("non regular files are not supported for shares") } return nil } func downloadFile(w http.ResponseWriter, r *http.Request, connection *Connection, name string, info os.FileInfo, inline bool, share *dataprovider.Share, ) (int, error) { connection.User.CheckFsRoot(connection.ID) //nolint:errcheck err := checkDownloadFileFromShare(share, info) if err != nil { return http.StatusBadRequest, err } rangeHeader := r.Header.Get("Range") if rangeHeader != "" && checkIfRange(r, info.ModTime()) == condFalse { rangeHeader = "" } offset := int64(0) size := 
info.Size() responseStatus := http.StatusOK if strings.HasPrefix(rangeHeader, "bytes=") { if strings.Contains(rangeHeader, ",") { return http.StatusRequestedRangeNotSatisfiable, fmt.Errorf("unsupported range %q", rangeHeader) } offset, size, err = parseRangeRequest(rangeHeader[6:], size) if err != nil { return http.StatusRequestedRangeNotSatisfiable, err } responseStatus = http.StatusPartialContent } reader, err := connection.getFileReader(name, offset, r.Method) if err != nil { return getMappedStatusCode(err), fmt.Errorf("unable to read file %q: %v", name, err) } defer reader.Close() w.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat)) if checkPreconditions(w, r, info.ModTime()) { return 0, fmt.Errorf("%v", http.StatusText(http.StatusPreconditionFailed)) } ctype := mime.TypeByExtension(path.Ext(name)) if ctype == "" { ctype = "application/octet-stream" } if responseStatus == http.StatusPartialContent { w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, info.Size())) } w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) w.Header().Set("Content-Type", ctype) if !inline { w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", path.Base(name))) } w.Header().Set("Accept-Ranges", "bytes") w.WriteHeader(responseStatus) if r.Method != http.MethodHead { _, err = io.CopyN(w, reader, size) if err != nil { if share != nil { dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck } connection.Log(logger.LevelDebug, "error reading file to download: %v", err) panic(http.ErrAbortHandler) } } return http.StatusOK, nil } func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) bool { if checkIfUnmodifiedSince(r, modtime) == condFalse { w.WriteHeader(http.StatusPreconditionFailed) return true } if checkIfModifiedSince(r, modtime) == condFalse { w.WriteHeader(http.StatusNotModified) return true } return false } func checkIfUnmodifiedSince(r *http.Request, modtime 
time.Time) condResult { ius := r.Header.Get("If-Unmodified-Since") if ius == "" || isZeroTime(modtime) { return condNone } t, err := http.ParseTime(ius) if err != nil { return condNone } // The Last-Modified header truncates sub-second precision so // the modtime needs to be truncated too. modtime = modtime.Truncate(time.Second) if modtime.Before(t) || modtime.Equal(t) { return condTrue } return condFalse } func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { if r.Method != http.MethodGet && r.Method != http.MethodHead { return condNone } ims := r.Header.Get("If-Modified-Since") if ims == "" || isZeroTime(modtime) { return condNone } t, err := http.ParseTime(ims) if err != nil { return condNone } // The Last-Modified header truncates sub-second precision so // the modtime needs to be truncated too. modtime = modtime.Truncate(time.Second) if modtime.Before(t) || modtime.Equal(t) { return condFalse } return condTrue } func checkIfRange(r *http.Request, modtime time.Time) condResult { if r.Method != http.MethodGet && r.Method != http.MethodHead { return condNone } ir := r.Header.Get("If-Range") if ir == "" { return condNone } if modtime.IsZero() { return condFalse } t, err := http.ParseTime(ir) if err != nil { return condFalse } if modtime.Unix() == t.Unix() { return condTrue } return condFalse } func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) { var start, end int64 var err error values := strings.Split(bytesRange, "-") if values[0] == "" { start = -1 } else { start, err = strconv.ParseInt(values[0], 10, 64) if err != nil { return start, size, err } } if len(values) >= 2 { if values[1] != "" { end, err = strconv.ParseInt(values[1], 10, 64) if err != nil { return start, size, err } if end >= size { end = size - 1 } } } if start == -1 && end == 0 { return 0, 0, fmt.Errorf("unsupported range %q", bytesRange) } if end > 0 { if start == -1 { // we have something like -500 start = size - end size = end // start cannot be < 0 
here, we did end = size -1 above } else { // we have something like 500-600 size = end - start + 1 if size < 0 { return 0, 0, fmt.Errorf("unacceptable range %q", bytesRange) } } return start, size, nil } // we have something like 500- size -= start if size < 0 { return 0, 0, fmt.Errorf("unacceptable range %q", bytesRange) } return start, size, err } func handleDefenderEventLoginFailed(ipAddr string, err error) error { event := common.HostEventLoginFailed if errors.Is(err, util.ErrNotFound) { event = common.HostEventUserNotFound err = dataprovider.ErrInvalidCredentials } common.AddDefenderEvent(ipAddr, common.ProtocolHTTP, event) common.DelayLogin(err) return err } func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error, r *http.Request) { metric.AddLoginAttempt(loginMethod) var protocol string switch loginMethod { case dataprovider.LoginMethodIDP: protocol = common.ProtocolOIDC default: protocol = common.ProtocolHTTP } if err == nil { logger.LoginLog(user.Username, ip, loginMethod, protocol, "", r.UserAgent(), r.TLS != nil, "") plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, protocol, user.Username, ip, "", nil) common.DelayLogin(nil) } else if err != common.ErrInternalFailure && err != common.ErrNoCredentials { logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error()) err = handleDefenderEventLoginFailed(ip, err) logEv := notifier.LogEventTypeLoginFailed if errors.Is(err, util.ErrNotFound) { logEv = notifier.LogEventTypeLoginNoUser } plugin.Handler.NotifyLogEvent(logEv, protocol, user.Username, ip, "", err) } metric.AddLoginResult(loginMethod, err) dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err) } func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions, isOIDCLogin bool) error { if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) { logger.Info(logSender, connectionID, "cannot login user %q, protocol HTTP is not 
allowed", user.Username) return util.NewI18nError( fmt.Errorf("protocol HTTP is not allowed for user %q", user.Username), util.I18nErrorProtocolForbidden, ) } if !isLoggedInWithOIDC(r) && !isOIDCLogin && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) { logger.Info(logSender, connectionID, "cannot login user %q, password login method is not allowed", user.Username) return util.NewI18nError( fmt.Errorf("login method password is not allowed for user %q", user.Username), util.I18nErrorPwdLoginForbidden, ) } if checkSessions && user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) if activeSessions >= user.MaxSessions { logger.Info(logSender, connectionID, "authentication refused for user: %q, too many open sessions: %v/%v", user.Username, activeSessions, user.MaxSessions) return util.NewI18nError(fmt.Errorf("too many open sessions: %v", activeSessions), util.I18nError429Message) } } if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", user.Username, r.RemoteAddr) return util.NewI18nError( fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, r.RemoteAddr), util.I18nErrorIPForbidden, ) } return nil } func getActiveAdmin(username, ipAddr string) (dataprovider.Admin, error) { admin, err := dataprovider.AdminExists(username) if err != nil { return admin, err } if err := admin.CanLogin(ipAddr); err != nil { return admin, util.NewRecordNotFoundError(fmt.Sprintf("admin %q cannot login: %v", username, err)) } return admin, nil } func getActiveUser(username string, r *http.Request) (dataprovider.User, error) { user, err := dataprovider.GetUserWithGroupSettings(username, "") if err != nil { return user, err } if err := user.CheckLoginConditions(); err != nil { return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err)) } if err := 
checkHTTPClientUser(&user, r, xid.New().String(), false, false); err != nil { return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err)) } return user, nil } func handleForgotPassword(r *http.Request, username string, isAdmin bool) error { var emails []string var subject string var err error var admin dataprovider.Admin var user dataprovider.User if username == "" { return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) } if isAdmin { admin, err = getActiveAdmin(username, util.GetIPFromRemoteAddress(r.RemoteAddr)) if admin.Email != "" { emails = []string{admin.Email} } subject = fmt.Sprintf("Email Verification Code for admin %q", username) } else { user, err = getActiveUser(username, r) emails = user.GetEmailAddresses() subject = fmt.Sprintf("Email Verification Code for user %q", username) if err == nil { if !isUserAllowedToResetPassword(r, &user) { return util.NewI18nError( util.NewValidationError("you are not allowed to reset your password"), util.I18nErrorPwdResetForbidded, ) } } } if err != nil { if errors.Is(err, util.ErrNotFound) { handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), err) //nolint:errcheck logger.Debug(logSender, middleware.GetReqID(r.Context()), "username %q does not exists or cannot login, reset password request silently ignored, is admin? 
%t, err: %v", username, isAdmin, err) return nil } return util.NewI18nError(util.NewGenericError("Error retrieving your account, please try again later"), util.I18nErrorGetUser) } if len(emails) == 0 { return util.NewI18nError( util.NewValidationError("Your account does not have an email address, it is not possible to reset your password by sending an email verification code"), util.I18nErrorPwdResetNoEmail, ) } c := newResetCode(username, isAdmin) body := new(bytes.Buffer) data := make(map[string]string) data["Code"] = c.Code if err := smtp.RenderPasswordResetTemplate(body, data); err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to render password reset template: %v", err) return util.NewGenericError("Unable to render password reset template") } startTime := time.Now() if err := smtp.SendEmail(emails, nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil { logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to send password reset code via email: %v, elapsed: %v", err, time.Since(startTime)) return util.NewI18nError( util.NewGenericError(fmt.Sprintf("Error sending confirmation code via email: %v", err)), util.I18nErrorPwdResetSendEmail, ) } logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %q, emails: %+v, is admin? 
%v, elapsed: %v", username, emails, isAdmin, time.Since(startTime)) return resetCodesMgr.Add(c) } func handleResetPassword(r *http.Request, code, newPassword, confirmPassword string, isAdmin bool) ( *dataprovider.Admin, *dataprovider.User, error, ) { var admin dataprovider.Admin var user dataprovider.User var err error if newPassword == "" { return &admin, &user, util.NewValidationError("please set a password") } if code == "" { return &admin, &user, util.NewValidationError("please set a confirmation code") } if newPassword != confirmPassword { return &admin, &user, util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch) } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) resetCode, err := resetCodesMgr.Get(code) if err != nil { handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck return &admin, &user, util.NewValidationError("confirmation code not found") } if resetCode.IsAdmin != isAdmin { return &admin, &user, util.NewValidationError("invalid confirmation code") } if isAdmin { admin, err = getActiveAdmin(resetCode.Username, ipAddr) if err != nil { return &admin, &user, util.NewValidationError("unable to associate the confirmation code with an existing admin") } admin.Password = newPassword admin.Filters.RequirePasswordChange = false err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role) if err != nil { return &admin, &user, util.NewGenericError(fmt.Sprintf("unable to set the new password: %v", err)) } err = resetCodesMgr.Delete(code) return &admin, &user, err } user, err = getActiveUser(resetCode.Username, r) if err != nil { return &admin, &user, util.NewValidationError("Unable to associate the confirmation code with an existing user") } if !isUserAllowedToResetPassword(r, &user) { return &admin, &user, util.NewI18nError( util.NewValidationError("you are not allowed to reset your password"), util.I18nErrorPwdResetForbidded, ) } err = 
dataprovider.UpdateUserPassword(user.Username, newPassword, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role) if err == nil { err = resetCodesMgr.Delete(code) } user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now()) user.Filters.RequirePasswordChange = false return &admin, &user, err } func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool { if !user.CanResetPassword() { return false } if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) { return false } if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) { return false } if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { return false } return true } func getProtocolFromRequest(r *http.Request) string { if isLoggedInWithOIDC(r) { return common.ProtocolOIDC } return common.ProtocolHTTP } func hideConfidentialData(claims *jwt.Claims, r *http.Request) bool { if !claims.HasPerm(dataprovider.PermAdminAny) { return true } return r.URL.Query().Get("confidential_data") != "1" } func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) { if err := rc.SetReadDeadline(read); err != nil { logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err) } if err := rc.SetWriteDeadline(write); err != nil { logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err) } } ================================================ FILE: internal/httpd/auth_utils.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "crypto/rand" "errors" "fmt" "net/http" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) type tokenAudience = string const ( tokenAudienceWebAdmin tokenAudience = "WebAdmin" tokenAudienceWebClient tokenAudience = "WebClient" tokenAudienceWebShare tokenAudience = "WebShare" tokenAudienceWebAdminPartial tokenAudience = "WebAdminPartial" tokenAudienceWebClientPartial tokenAudience = "WebClientPartial" tokenAudienceAPI tokenAudience = "API" tokenAudienceAPIUser tokenAudience = "APIUser" tokenAudienceCSRF tokenAudience = "CSRF" tokenAudienceOAuth2 tokenAudience = "OAuth2" tokenAudienceWebLogin tokenAudience = "WebLogin" ) const ( tokenValidationModeDefault = 0 tokenValidationModeNoIPMatch = 1 tokenValidationModeUserSignature = 2 ) const ( basicRealm = "Basic realm=\"SFTPGo\"" ) var ( apiTokenDuration = 20 * time.Minute cookieTokenDuration = 20 * time.Minute shareTokenDuration = 2 * time.Hour // csrf token duration is greater than normal token duration to reduce issues // with the login form csrfTokenDuration = 4 * time.Hour cookieRefreshThreshold = 10 * time.Minute maxTokenDuration = 12 * time.Hour tokenValidationMode = tokenValidationModeDefault ) func isTokenDurationValid(minutes int) bool { return minutes >= 1 && minutes <= 720 } func updateTokensDuration(api, cookie, share int) { if isTokenDurationValid(api) { apiTokenDuration = time.Duration(api) * time.Minute } if isTokenDurationValid(cookie) { cookieTokenDuration = time.Duration(cookie) * time.Minute cookieRefreshThreshold = cookieTokenDuration / 2 if cookieTokenDuration > csrfTokenDuration { csrfTokenDuration = cookieTokenDuration } } if 
isTokenDurationValid(share) { shareTokenDuration = time.Duration(share) * time.Minute } logger.Debug(logSender, "", "API token duration %s, cookie token duration %s, cookie refresh threshold %s, share token duration %s, csrf token duration %s", apiTokenDuration, cookieTokenDuration, cookieRefreshThreshold, shareTokenDuration, csrfTokenDuration) } func getTokenDuration(audience tokenAudience) time.Duration { switch audience { case tokenAudienceWebShare: return shareTokenDuration case tokenAudienceWebLogin, tokenAudienceCSRF: return csrfTokenDuration case tokenAudienceAPI, tokenAudienceAPIUser: return apiTokenDuration case tokenAudienceWebAdmin, tokenAudienceWebClient: return cookieTokenDuration case tokenAudienceWebAdminPartial, tokenAudienceWebClientPartial, tokenAudienceOAuth2: return 5 * time.Minute default: logger.Error(logSender, "", "token duration not handled for audience: %q", audience) return 20 * time.Minute } } func getMaxCookieDuration() time.Duration { result := csrfTokenDuration if shareTokenDuration > result { result = shareTokenDuration } if cookieTokenDuration > result { result = cookieTokenDuration } return result } func hasUserAudience(claims *jwt.Claims) bool { return claims.HasAnyAudience([]string{tokenAudienceWebClient, tokenAudienceAPIUser}) } func createAndSetCookie(w http.ResponseWriter, r *http.Request, claims *jwt.Claims, tokenAuth *jwt.Signer, audience tokenAudience, ip string, ) error { duration := getTokenDuration(audience) token, err := tokenAuth.SignWithParams(claims, audience, ip, duration) if err != nil { return err } resp := claims.BuildTokenResponse(token) var basePath string if audience == tokenAudienceWebAdmin || audience == tokenAudienceWebAdminPartial { basePath = webBaseAdminPath } else { basePath = webBaseClientPath } setCookie(w, r, basePath, resp.Token, duration) return nil } func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) { http.SetCookie(w, &http.Cookie{ 
Name: jwt.CookieKey, Value: cookieValue, Path: cookiePath, Expires: time.Now().Add(duration), MaxAge: int(duration / time.Second), HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteStrictMode, }) } func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) { invalidateToken(r) http.SetCookie(w, &http.Cookie{ Name: jwt.CookieKey, Value: "", Path: cookiePath, Expires: time.Unix(0, 0), MaxAge: -1, HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteStrictMode, }) w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) } func oidcTokenFromContext(r *http.Request) string { if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok { return token } return "" } func isTLS(r *http.Request) bool { if r.TLS != nil { return true } if proto, ok := r.Context().Value(forwardedProtoKey).(string); ok { return proto == "https" //nolint:goconst } return false } func isTokenInvalidated(r *http.Request) bool { var findTokenFns []func(r *http.Request) string findTokenFns = append(findTokenFns, jwt.TokenFromHeader) findTokenFns = append(findTokenFns, jwt.TokenFromCookie) findTokenFns = append(findTokenFns, oidcTokenFromContext) isTokenFound := false for _, fn := range findTokenFns { token := fn(r) if token != "" { isTokenFound = true if invalidatedJWTTokens.Get(token) { return true } } } return !isTokenFound } func invalidateToken(r *http.Request) { tokenString := jwt.TokenFromHeader(r) if tokenString != "" { invalidateTokenString(r, tokenString, apiTokenDuration) } tokenString = jwt.TokenFromCookie(r) if tokenString != "" { invalidateTokenString(r, tokenString, getMaxCookieDuration()) } } func invalidateTokenString(r *http.Request, tokenString string, fallbackDuration time.Duration) { token, err := jwt.FromContext(r.Context()) if err != nil { invalidatedJWTTokens.Add(tokenString, time.Now().Add(fallbackDuration).UTC()) return } invalidatedJWTTokens.Add(tokenString, token.Expiry.Time().Add(1*time.Minute).UTC()) } func getUserFromToken(r 
*http.Request) *dataprovider.User { user := &dataprovider.User{} claims, err := jwt.FromContext(r.Context()) if err != nil { return user } user.Username = claims.Username user.Filters.WebClient = claims.Permissions user.Role = claims.Role return user } func getAdminFromToken(r *http.Request) *dataprovider.Admin { admin := &dataprovider.Admin{} claims, err := jwt.FromContext(r.Context()) if err != nil { return admin } admin.Username = claims.Username admin.Permissions = claims.Permissions admin.Filters.Preferences.HideUserPageSections = claims.HideUserPageSections admin.Role = claims.Role return admin } func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, basePath, ip string, ) { c := jwt.NewClaims(tokenAudienceWebLogin, ip, getTokenDuration(tokenAudienceWebLogin)) c.ID = tokenID resp, err := c.GenerateTokenResponse(csrfTokenAuth) if err != nil { return } setCookie(w, r, basePath, resp.Token, csrfTokenDuration) } func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, basePath string, ) string { ip := util.GetIPFromRemoteAddress(r.RemoteAddr) claims := jwt.NewClaims(tokenAudienceCSRF, ip, csrfTokenDuration) claims.ID = rand.Text() if tokenID != "" { createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip) claims.Ref = tokenID } else { if c, err := jwt.FromContext(r.Context()); err == nil { claims.Ref = c.ID } else { logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err) } } tokenString, err := csrfTokenAuth.Sign(claims) if err != nil { logger.Debug(logSender, "", "unable to create CSRF token: %v", err) return "" } return tokenString } func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error { tokenString := r.Form.Get(csrfFormToken) token, err := jwt.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err) return fmt.Errorf("unable to verify 
form token: %v", err) } if !token.Audience.Contains(tokenAudienceCSRF) { logger.Debug(logSender, "", "error validating CSRF token audience") return errors.New("the form token is not valid") } if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { logger.Debug(logSender, "", "error validating CSRF token IP audience") return errors.New("the form token is not valid") } return checkCSRFTokenRef(r, token) } func checkCSRFTokenRef(r *http.Request, token *jwt.Claims) error { claims, err := jwt.FromContext(r.Context()) if err != nil { logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err) return err } if token.ID == "" { logger.Debug(logSender, "", "error validating CSRF token, missing reference") return errors.New("the form token is not valid") } if claims.ID != token.Ref { logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.ID, token.ID) return errors.New("unexpected form token") } return nil } func verifyLoginCookie(r *http.Request) error { token, err := jwt.FromContext(r.Context()) if err != nil { logger.Debug(logSender, "", "error getting login token: %v", err) return errInvalidToken } if isTokenInvalidated(r) { logger.Debug(logSender, "", "the login token has been invalidated") return errInvalidToken } if !token.Audience.Contains(tokenAudienceWebLogin) { logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.ID, tokenAudienceWebLogin) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil { return err } return nil } func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error { if err := verifyLoginCookie(r); err != nil { return err } if err := verifyCSRFToken(r, csrfTokenAuth); err != nil { return err } return nil } func createOAuth2Token(csrfTokenAuth *jwt.Signer, state, ip string) string { claims := 
jwt.NewClaims(tokenAudienceOAuth2, ip, getTokenDuration(tokenAudienceOAuth2)) claims.ID = state tokenString, err := csrfTokenAuth.Sign(claims) if err != nil { logger.Debug(logSender, "", "unable to create OAuth2 token: %v", err) return "" } return tokenString } func verifyOAuth2Token(csrfTokenAuth *jwt.Signer, tokenString, ip string) (string, error) { token, err := jwt.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err) return "", util.NewI18nError( fmt.Errorf("unable to verify OAuth2 state: %v", err), util.I18nOAuth2ErrorVerifyState, ) } if !token.Audience.Contains(tokenAudienceOAuth2) { logger.Debug(logSender, "", "error validating OAuth2 token audience") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } if err := validateIPForToken(token, ip); err != nil { logger.Debug(logSender, "", "error validating OAuth2 token IP audience") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } if token.ID != "" { return token.ID, nil } logger.Debug(logSender, "", "jti not found in OAuth2 token") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } func validateIPForToken(token *jwt.Claims, ip string) error { if tokenValidationMode&tokenValidationModeNoIPMatch == 0 { if !token.Audience.Contains(ip) { return errInvalidToken } } return nil } func checkTokenSignature(r *http.Request, token *jwt.Claims) error { if _, ok := r.Context().Value(oidcTokenKey).(string); ok { return nil } var err error if tokenValidationMode&tokenValidationModeUserSignature != 0 { for _, audience := range token.Audience { switch audience { case tokenAudienceAPI, tokenAudienceWebAdmin: err = validateSignatureForToken(token, dataprovider.GetAdminSignature) case tokenAudienceAPIUser, tokenAudienceWebClient: err = validateSignatureForToken(token, 
dataprovider.GetUserSignature) } } } if err != nil { invalidateToken(r) } return err } func validateSignatureForToken(token *jwt.Claims, getter func(string) (string, error)) error { signature, err := getter(token.Username) if err != nil { logger.Debug(logSender, "", "unable to get signature for username %q: %v", token.Username, err) return errInvalidToken } if signature != "" && signature == token.Subject { return nil } logger.Debug(logSender, "", "signature mismatch for username %q, signature %q, token signature %q", token.Username, signature, token.Subject) return errInvalidToken } ================================================ FILE: internal/httpd/file.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "io" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/vfs" ) type httpdFile struct { *common.BaseTransfer writer io.WriteCloser reader io.ReadCloser isFinished bool } func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader) *httpdFile { var writer io.WriteCloser var reader io.ReadCloser if baseTransfer.File != nil { writer = baseTransfer.File reader = baseTransfer.File } else if pipeWriter != nil { writer = pipeWriter } else if pipeReader != nil { reader = pipeReader } return &httpdFile{ BaseTransfer: baseTransfer, writer: writer, reader: reader, isFinished: false, } } // Read reads the contents to downloads. func (f *httpdFile) Read(p []byte) (n int, err error) { if f.AbortTransfer.Load() { err := f.GetAbortError() f.TransferError(err) return 0, err } f.Connection.UpdateLastActivity() n, err = f.reader.Read(p) f.BytesSent.Add(int64(n)) if err == nil { err = f.CheckRead() } if err != nil && err != io.EOF { f.TransferError(err) err = f.ConvertError(err) return } f.HandleThrottle() return } // Write writes the contents to upload func (f *httpdFile) Write(p []byte) (n int, err error) { if f.AbortTransfer.Load() { err := f.GetAbortError() f.TransferError(err) return 0, err } f.Connection.UpdateLastActivity() n, err = f.writer.Write(p) f.BytesReceived.Add(int64(n)) if err == nil { err = f.CheckWrite() } if err != nil { f.TransferError(err) err = f.ConvertError(err) return } f.HandleThrottle() return } // Close closes the current transfer func (f *httpdFile) Close() error { if err := f.setFinished(); err != nil { return err } err := f.closeIO() errBaseClose := f.BaseTransfer.Close() if errBaseClose != nil { err = errBaseClose } return f.Connection.GetFsError(f.Fs, err) } func (f *httpdFile) closeIO() error { var err error if f.File != nil { err = f.File.Close() } else if f.writer != nil { err = f.writer.Close() f.Lock() // we set ErrTransfer here so quota is not 
updated, in this case the uploads are atomic if err != nil && f.ErrTransfer == nil { f.ErrTransfer = err } f.Unlock() } else if f.reader != nil { err = f.reader.Close() if metadater, ok := f.reader.(vfs.Metadater); ok { f.SetMetadata(metadater.Metadata()) } } return err } func (f *httpdFile) setFinished() error { f.Lock() defer f.Unlock() if f.isFinished { return common.ErrTransferClosed } f.isFinished = true return nil } ================================================ FILE: internal/httpd/flash.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "encoding/base64" "encoding/json" "net/http" "time" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( flashCookieName = "message" ) func newFlashMessage(errorStrig, i18nMessage string) flashMessage { return flashMessage{ ErrorString: errorStrig, I18nMessage: i18nMessage, } } type flashMessage struct { ErrorString string `json:"error"` I18nMessage string `json:"message"` } func (m *flashMessage) getI18nError() *util.I18nError { if m.ErrorString == "" && m.I18nMessage == "" { return nil } return util.NewI18nError( util.NewGenericError(m.ErrorString), m.I18nMessage, ) } func setFlashMessage(w http.ResponseWriter, r *http.Request, message flashMessage) { value, err := json.Marshal(message) if err != nil { return } http.SetCookie(w, &http.Cookie{ Name: flashCookieName, Value: base64.URLEncoding.EncodeToString(value), Path: "/", Expires: time.Now().Add(60 * time.Second), MaxAge: 60, HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteLaxMode, }) w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) } func getFlashMessage(w http.ResponseWriter, r *http.Request) flashMessage { var msg flashMessage cookie, err := r.Cookie(flashCookieName) if err != nil { return msg } http.SetCookie(w, &http.Cookie{ Name: flashCookieName, Value: "", Path: "/", Expires: time.Unix(0, 0), MaxAge: -1, HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteLaxMode, }) value, err := base64.URLEncoding.DecodeString(cookie.Value) if err != nil { return msg } err = json.Unmarshal(value, &msg) if err != nil { return flashMessage{} } return msg } ================================================ FILE: internal/httpd/flash_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "encoding/base64" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/util" ) func TestFlashMessages(t *testing.T) { rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, "/url", nil) require.NoError(t, err) message := flashMessage{ ErrorString: "error", I18nMessage: util.I18nChangePwdTitle, } setFlashMessage(rr, req, message) value, err := json.Marshal(message) assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("%v=%v", flashCookieName, base64.URLEncoding.EncodeToString(value))) msg := getFlashMessage(rr, req) assert.Equal(t, message, msg) assert.Equal(t, util.I18nChangePwdTitle, msg.getI18nError().Message) req.Header.Set("Cookie", fmt.Sprintf("%v=%v", flashCookieName, "a")) msg = getFlashMessage(rr, req) assert.Empty(t, msg) req.Header.Set("Cookie", fmt.Sprintf("%v=%v", flashCookieName, "YQ==")) msg = getFlashMessage(rr, req) assert.Empty(t, msg) } ================================================ FILE: internal/httpd/handler.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "io" "net/http" "os" "path" "strings" "sync" "sync/atomic" "time" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // Connection details for a HTTP connection used to inteact with an SFTPGo filesystem type Connection struct { *common.BaseConnection request *http.Request rc *http.ResponseController } func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection { rc := http.NewResponseController(w) responseControllerDeadlines(rc, time.Time{}, time.Time{}) return &Connection{ BaseConnection: conn, request: r, rc: rc, } } // GetClientVersion returns the connected client's version. 
func (c *Connection) GetClientVersion() string { if c.request != nil { return c.request.UserAgent() } return "" } // GetLocalAddress returns local connection address func (c *Connection) GetLocalAddress() string { return util.GetHTTPLocalAddress(c.request) } // GetRemoteAddress returns the connected client's address func (c *Connection) GetRemoteAddress() string { if c.request != nil { return c.request.RemoteAddr } return "" } // Disconnect closes the active transfer func (c *Connection) Disconnect() (err error) { if c.rc != nil { responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second)) } return c.SignalTransfersAbort() } // GetCommand returns the request method func (c *Connection) GetCommand() string { if c.request != nil { return strings.ToUpper(c.request.Method) } return "" } // Stat returns a FileInfo describing the named file/directory, or an error, // if any happens func (c *Connection) Stat(name string, mode int) (os.FileInfo, error) { c.UpdateLastActivity() if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) { return nil, c.GetPermissionDeniedError() } fi, err := c.DoStat(name, mode, true) if err != nil { return nil, err } return fi, err } // ReadDir returns a list of directory entries func (c *Connection) ReadDir(name string) (vfs.DirLister, error) { c.UpdateLastActivity() return c.ListDir(name) } func (c *Connection) getFileReader(name string, offset int64, method string) (io.ReadCloser, error) { c.UpdateLastActivity() if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { c.Log(logger.LevelInfo, "denying file read due to transfer count limits") return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message) } transferQuota := c.GetTransferQuota() if !transferQuota.HasDownloadSpace() { c.Log(logger.LevelInfo, "denying file read due to quota limits") return nil, util.NewI18nError(c.GetReadQuotaExceededError(), util.I18nErrorQuotaRead) } if 
!c.User.HasPerm(dataprovider.PermDownload, path.Dir(name)) { return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message) } if ok, policy := c.User.IsFileAllowed(name); !ok { c.Log(logger.LevelWarn, "reading file %q is not allowed", name) return nil, util.NewI18nError(c.GetErrorForDeniedFile(policy), util.I18nError403Message) } fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { return nil, err } if method != http.MethodHead { if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreDownload, p, name, 0, 0); err != nil { c.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", name, err) return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message) } } file, r, cancelFn, err := fs.Open(p, offset) if err != nil { c.Log(logger.LevelError, "could not open file %q for reading: %+v", p, err) return nil, c.GetFsError(fs, err) } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota) return newHTTPDFile(baseTransfer, nil, r), nil } func (c *Connection) getFileWriter(name string) (io.WriteCloser, error) { c.UpdateLastActivity() if ok, _ := c.User.IsFileAllowed(name); !ok { c.Log(logger.LevelWarn, "writing file %q is not allowed", name) return nil, c.GetPermissionDeniedError() } fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { return nil, err } filePath := p if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { filePath = fs.GetAtomicUploadPath(p) } stat, statErr := fs.Lstat(p) if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(name)) { return nil, c.GetPermissionDeniedError() } return c.handleUploadFile(fs, p, filePath, name, true, 0) } if statErr != nil { c.Log(logger.LevelError, "error performing file stat %q: %+v", p, statErr) return nil, c.GetFsError(fs, statErr) } // This 
happen if we upload a file that has the same name of an existing directory if stat.IsDir() { c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", p) return nil, c.GetOpUnsupportedError() } if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(name)) { return nil, c.GetPermissionDeniedError() } if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { _, _, err = fs.Rename(p, filePath, 0) if err != nil { c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v", p, filePath, err) return nil, c.GetFsError(fs, err) } } return c.handleUploadFile(fs, p, filePath, name, false, stat.Size()) } func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, isNewFile bool, fileSize int64) (io.WriteCloser, error) { if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { c.Log(logger.LevelInfo, "denying file write due to transfer count limits") return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message) } diskQuota, transferQuota := c.HasSpace(isNewFile, false, requestPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, os.O_TRUNC) if err != nil { c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) return nil, c.GetPermissionDeniedError() } maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported()) file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.GetCreateChecks(requestPath, isNewFile, false)) if err != nil { c.Log(logger.LevelError, "error opening existing file, source: %q, err: %+v", filePath, err) return nil, c.GetFsError(fs, err) } initialSize := int64(0) truncatedSize 
:= int64(0) // bytes truncated and not included in quota if !isNewFile { if vfs.HasTruncateSupport(fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false) } else { dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { initialSize = fileSize truncatedSize = fileSize } if maxWriteSize > 0 { maxWriteSize += fileSize } } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota) return newHTTPDFile(baseTransfer, w, nil), nil } func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader { t := &throttledReader{ id: conn.GetTransferID(), limit: limit, r: r, start: time.Now(), conn: conn, } t.bytesRead.Store(0) t.abortTransfer.Store(false) conn.AddTransfer(t) return t } type throttledReader struct { bytesRead atomic.Int64 id int64 limit int64 r io.ReadCloser abortTransfer atomic.Bool start time.Time conn *Connection mu sync.Mutex errAbort error } func (t *throttledReader) GetID() int64 { return t.id } func (t *throttledReader) GetType() int { return common.TransferUpload } func (t *throttledReader) GetSize() int64 { return t.bytesRead.Load() } func (t *throttledReader) GetDownloadedSize() int64 { return 0 } func (t *throttledReader) GetUploadedSize() int64 { return t.bytesRead.Load() } func (t *throttledReader) GetVirtualPath() string { return "**reading request body**" } func (t *throttledReader) GetStartTime() time.Time { return t.start } func (t *throttledReader) GetAbortError() error { t.mu.Lock() defer t.mu.Unlock() if t.errAbort != nil { return t.errAbort } return common.ErrTransferAborted } func (t *throttledReader) SignalClose(err error) { t.mu.Lock() t.errAbort = err 
t.mu.Unlock() t.abortTransfer.Store(true) } func (t *throttledReader) GetTruncatedSize() int64 { return 0 } func (t *throttledReader) HasSizeLimit() bool { return false } func (t *throttledReader) Truncate(_ string, _ int64) (int64, error) { return 0, vfs.ErrVfsUnsupported } func (t *throttledReader) GetRealFsPath(_ string) string { return "" } func (t *throttledReader) GetFsPath() string { return "" } func (t *throttledReader) SetTimes(_ string, _ time.Time, _ time.Time) bool { return false } func (t *throttledReader) Read(p []byte) (n int, err error) { if t.abortTransfer.Load() { return 0, t.GetAbortError() } t.conn.UpdateLastActivity() n, err = t.r.Read(p) if t.limit > 0 { t.bytesRead.Add(int64(n)) trasferredBytes := t.bytesRead.Load() elapsed := time.Since(t.start).Nanoseconds() / 1000000 wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit if wantedElapsed > elapsed { toSleep := time.Duration(wantedElapsed - elapsed) time.Sleep(toSleep * time.Millisecond) } } return } func (t *throttledReader) Close() error { return t.r.Close() } ================================================ FILE: internal/httpd/httpd.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package httpd implements REST API and Web interface for SFTPGo. 
// The OpenAPI 3 schema for the supported API can be found inside the source tree: // https://github.com/drakkan/sftpgo/blob/main/openapi/openapi.yaml package httpd import ( "crypto/sha256" "errors" "fmt" "net" "net/http" "net/url" "os" "path" "path/filepath" "runtime" "strings" "sync" "time" "github.com/go-chi/chi/v5" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) const ( logSender = "httpd" tokenPath = "/api/v2/token" logoutPath = "/api/v2/logout" userTokenPath = "/api/v2/user/token" userLogoutPath = "/api/v2/user/logout" activeConnectionsPath = "/api/v2/connections" quotasBasePath = "/api/v2/quotas" userPath = "/api/v2/users" versionPath = "/api/v2/version" folderPath = "/api/v2/folders" groupPath = "/api/v2/groups" serverStatusPath = "/api/v2/status" dumpDataPath = "/api/v2/dumpdata" loadDataPath = "/api/v2/loaddata" defenderHosts = "/api/v2/defender/hosts" adminPath = "/api/v2/admins" adminPwdPath = "/api/v2/admin/changepwd" adminProfilePath = "/api/v2/admin/profile" userPwdPath = "/api/v2/user/changepwd" userDirsPath = "/api/v2/user/dirs" userFilesPath = "/api/v2/user/files" userFileActionsPath = "/api/v2/user/file-actions" userStreamZipPath = "/api/v2/user/streamzip" userUploadFilePath = "/api/v2/user/files/upload" userFilesDirsMetadataPath = "/api/v2/user/files/metadata" apiKeysPath = "/api/v2/apikeys" adminTOTPConfigsPath = "/api/v2/admin/totp/configs" adminTOTPGeneratePath = "/api/v2/admin/totp/generate" adminTOTPValidatePath = "/api/v2/admin/totp/validate" adminTOTPSavePath = "/api/v2/admin/totp/save" admin2FARecoveryCodesPath = "/api/v2/admin/2fa/recoverycodes" userTOTPConfigsPath = 
"/api/v2/user/totp/configs" userTOTPGeneratePath = "/api/v2/user/totp/generate" userTOTPValidatePath = "/api/v2/user/totp/validate" userTOTPSavePath = "/api/v2/user/totp/save" user2FARecoveryCodesPath = "/api/v2/user/2fa/recoverycodes" userProfilePath = "/api/v2/user/profile" userSharesPath = "/api/v2/user/shares" retentionChecksPath = "/api/v2/retention/users/checks" fsEventsPath = "/api/v2/events/fs" providerEventsPath = "/api/v2/events/provider" logEventsPath = "/api/v2/events/logs" sharesPath = "/api/v2/shares" eventActionsPath = "/api/v2/eventactions" eventRulesPath = "/api/v2/eventrules" rolesPath = "/api/v2/roles" ipListsPath = "/api/v2/iplists" healthzPath = "/healthz" webRootPathDefault = "/" webBasePathDefault = "/web" webBasePathAdminDefault = "/web/admin" webBasePathClientDefault = "/web/client" webAdminSetupPathDefault = "/web/admin/setup" webAdminLoginPathDefault = "/web/admin/login" webAdminOIDCLoginPathDefault = "/web/admin/oidclogin" webOIDCRedirectPathDefault = "/web/oidc/redirect" webOAuth2RedirectPathDefault = "/web/oauth2/redirect" webOAuth2TokenPathDefault = "/web/admin/oauth2/token" webAdminTwoFactorPathDefault = "/web/admin/twofactor" webAdminTwoFactorRecoveryPathDefault = "/web/admin/twofactor-recovery" webLogoutPathDefault = "/web/admin/logout" webUsersPathDefault = "/web/admin/users" webUserPathDefault = "/web/admin/user" webConnectionsPathDefault = "/web/admin/connections" webFoldersPathDefault = "/web/admin/folders" webFolderPathDefault = "/web/admin/folder" webGroupsPathDefault = "/web/admin/groups" webGroupPathDefault = "/web/admin/group" webStatusPathDefault = "/web/admin/status" webAdminsPathDefault = "/web/admin/managers" webAdminPathDefault = "/web/admin/manager" webMaintenancePathDefault = "/web/admin/maintenance" webBackupPathDefault = "/web/admin/backup" webRestorePathDefault = "/web/admin/restore" webScanVFolderPathDefault = "/web/admin/quotas/scanfolder" webQuotaScanPathDefault = "/web/admin/quotas/scanuser" 
webChangeAdminPwdPathDefault = "/web/admin/changepwd" webAdminForgotPwdPathDefault = "/web/admin/forgot-password" webAdminResetPwdPathDefault = "/web/admin/reset-password" webAdminProfilePathDefault = "/web/admin/profile" webAdminMFAPathDefault = "/web/admin/mfa" webAdminEventRulesPathDefault = "/web/admin/eventrules" webAdminEventRulePathDefault = "/web/admin/eventrule" webAdminEventActionsPathDefault = "/web/admin/eventactions" webAdminEventActionPathDefault = "/web/admin/eventaction" webAdminRolesPathDefault = "/web/admin/roles" webAdminRolePathDefault = "/web/admin/role" webAdminTOTPGeneratePathDefault = "/web/admin/totp/generate" webAdminTOTPValidatePathDefault = "/web/admin/totp/validate" webAdminTOTPSavePathDefault = "/web/admin/totp/save" webAdminRecoveryCodesPathDefault = "/web/admin/recoverycodes" webTemplateUserDefault = "/web/admin/template/user" webTemplateFolderDefault = "/web/admin/template/folder" webDefenderPathDefault = "/web/admin/defender" webIPListsPathDefault = "/web/admin/ip-lists" webIPListPathDefault = "/web/admin/ip-list" webDefenderHostsPathDefault = "/web/admin/defender/hosts" webEventsPathDefault = "/web/admin/events" webEventsFsSearchPathDefault = "/web/admin/events/fs" webEventsProviderSearchPathDefault = "/web/admin/events/provider" webEventsLogSearchPathDefault = "/web/admin/events/logs" webConfigsPathDefault = "/web/admin/configs" webClientLoginPathDefault = "/web/client/login" webClientOIDCLoginPathDefault = "/web/client/oidclogin" webClientTwoFactorPathDefault = "/web/client/twofactor" webClientTwoFactorRecoveryPathDefault = "/web/client/twofactor-recovery" webClientFilesPathDefault = "/web/client/files" webClientFilePathDefault = "/web/client/file" webClientFileActionsPathDefault = "/web/client/file-actions" webClientSharesPathDefault = "/web/client/shares" webClientSharePathDefault = "/web/client/share" webClientEditFilePathDefault = "/web/client/editfile" webClientDirsPathDefault = "/web/client/dirs" 
webClientDownloadZipPathDefault = "/web/client/downloadzip" webClientProfilePathDefault = "/web/client/profile" webClientPingPathDefault = "/web/client/ping" webClientMFAPathDefault = "/web/client/mfa" webClientTOTPGeneratePathDefault = "/web/client/totp/generate" webClientTOTPValidatePathDefault = "/web/client/totp/validate" webClientTOTPSavePathDefault = "/web/client/totp/save" webClientRecoveryCodesPathDefault = "/web/client/recoverycodes" webChangeClientPwdPathDefault = "/web/client/changepwd" webClientLogoutPathDefault = "/web/client/logout" webClientPubSharesPathDefault = "/web/client/pubshares" webClientForgotPwdPathDefault = "/web/client/forgot-password" webClientResetPwdPathDefault = "/web/client/reset-password" webClientViewPDFPathDefault = "/web/client/viewpdf" webClientGetPDFPathDefault = "/web/client/getpdf" webClientExistPathDefault = "/web/client/exist" webClientTasksPathDefault = "/web/client/tasks" webStaticFilesPathDefault = "/static" webOpenAPIPathDefault = "/openapi" // MaxRestoreSize defines the max size for the loaddata input file MaxRestoreSize = 20 * 1048576 // 20 MB maxRequestSize = 1048576 // 1MB maxLoginBodySize = 262144 // 256 KB httpdMaxEditFileSize = 2 * 1048576 // 2 MB maxMultipartMem = 10 * 1048576 // 10 MB osWindows = "windows" otpHeaderCode = "X-SFTPGO-OTP" mTimeHeader = "X-SFTPGO-MTIME" acmeChallengeURI = "/.well-known/acme-challenge/" ) var ( certMgr *common.CertManager cleanupTicker *time.Ticker cleanupDone chan bool invalidatedJWTTokens tokenManager webRootPath string webBasePath string webBaseAdminPath string webBaseClientPath string webOIDCRedirectPath string webOAuth2RedirectPath string webOAuth2TokenPath string webAdminSetupPath string webAdminOIDCLoginPath string webAdminLoginPath string webAdminTwoFactorPath string webAdminTwoFactorRecoveryPath string webLogoutPath string webUsersPath string webUserPath string webConnectionsPath string webFoldersPath string webFolderPath string webGroupsPath string webGroupPath string 
webStatusPath string webAdminsPath string webAdminPath string webMaintenancePath string webBackupPath string webRestorePath string webScanVFolderPath string webQuotaScanPath string webAdminProfilePath string webAdminMFAPath string webAdminEventRulesPath string webAdminEventRulePath string webAdminEventActionsPath string webAdminEventActionPath string webAdminRolesPath string webAdminRolePath string webAdminTOTPGeneratePath string webAdminTOTPValidatePath string webAdminTOTPSavePath string webAdminRecoveryCodesPath string webChangeAdminPwdPath string webAdminForgotPwdPath string webAdminResetPwdPath string webTemplateUser string webTemplateFolder string webDefenderPath string webIPListPath string webIPListsPath string webEventsPath string webEventsFsSearchPath string webEventsProviderSearchPath string webEventsLogSearchPath string webConfigsPath string webDefenderHostsPath string webClientLoginPath string webClientOIDCLoginPath string webClientTwoFactorPath string webClientTwoFactorRecoveryPath string webClientFilesPath string webClientFilePath string webClientFileActionsPath string webClientSharesPath string webClientSharePath string webClientEditFilePath string webClientDirsPath string webClientDownloadZipPath string webClientProfilePath string webClientPingPath string webChangeClientPwdPath string webClientMFAPath string webClientTOTPGeneratePath string webClientTOTPValidatePath string webClientTOTPSavePath string webClientRecoveryCodesPath string webClientPubSharesPath string webClientLogoutPath string webClientForgotPwdPath string webClientResetPwdPath string webClientViewPDFPath string webClientGetPDFPath string webClientExistPath string webClientTasksPath string webStaticFilesPath string webOpenAPIPath string // max upload size for http clients, 1GB by default maxUploadFileSize = int64(1048576000) hideSupportLink bool installationCode string installationCodeHint string fnInstallationCodeResolver FnInstallationCodeResolver configurationDir string 
dbBrandingConfig brandingCache ) func init() { updateWebAdminURLs("") updateWebClientURLs("") acme.SetReloadHTTPDCertsFn(ReloadCertificateMgr) common.SetUpdateBrandingFn(dbBrandingConfig.Set) } type brandingCache struct { mu sync.RWMutex configs *dataprovider.BrandingConfigs } func (b *brandingCache) Set(configs *dataprovider.BrandingConfigs) { b.mu.Lock() defer b.mu.Unlock() b.configs = configs } func (b *brandingCache) getWebAdminLogo() []byte { b.mu.RLock() defer b.mu.RUnlock() return b.configs.WebAdmin.Logo } func (b *brandingCache) getWebAdminFavicon() []byte { b.mu.RLock() defer b.mu.RUnlock() return b.configs.WebAdmin.Favicon } func (b *brandingCache) getWebClientLogo() []byte { b.mu.RLock() defer b.mu.RUnlock() return b.configs.WebClient.Logo } func (b *brandingCache) getWebClientFavicon() []byte { b.mu.RLock() defer b.mu.RUnlock() return b.configs.WebClient.Favicon } func (b *brandingCache) mergeBrandingConfig(branding UIBranding, isWebClient bool) UIBranding { b.mu.RLock() defer b.mu.RUnlock() var urlPrefix string var cfg dataprovider.BrandingConfig if isWebClient { cfg = b.configs.WebClient urlPrefix = "webclient" } else { cfg = b.configs.WebAdmin urlPrefix = "webadmin" } if cfg.Name != "" { branding.Name = cfg.Name } if cfg.ShortName != "" { branding.ShortName = cfg.ShortName } if cfg.DisclaimerName != "" { branding.DisclaimerName = cfg.DisclaimerName } if cfg.DisclaimerURL != "" { branding.DisclaimerPath = cfg.DisclaimerURL } if len(cfg.Logo) > 0 { branding.LogoPath = path.Join("/", "branding", urlPrefix, "logo.png") } if len(cfg.Favicon) > 0 { branding.FaviconPath = path.Join("/", "branding", urlPrefix, "favicon.png") } return branding } // FnInstallationCodeResolver defines a method to get the installation code. // If the installation code cannot be resolved the provided default must be returned type FnInstallationCodeResolver func(defaultInstallationCode string) string // HTTPSProxyHeader defines an HTTPS proxy header as key/value. 
// For example Key could be "X-Forwarded-Proto" and Value "https" type HTTPSProxyHeader struct { Key string Value string } // SecurityConf allows to add some security related headers to HTTP responses and to restrict allowed hosts type SecurityConf struct { // Set to true to enable the security configurations Enabled bool `json:"enabled" mapstructure:"enabled"` // AllowedHosts is a list of fully qualified domain names that are allowed. // Default is empty list, which allows any and all host names. AllowedHosts []string `json:"allowed_hosts" mapstructure:"allowed_hosts"` // AllowedHostsAreRegex determines if the provided allowed hosts contains valid regular expressions AllowedHostsAreRegex bool `json:"allowed_hosts_are_regex" mapstructure:"allowed_hosts_are_regex"` // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. HostsProxyHeaders []string `json:"hosts_proxy_headers" mapstructure:"hosts_proxy_headers"` // Set to true to redirect HTTP requests to HTTPS HTTPSRedirect bool `json:"https_redirect" mapstructure:"https_redirect"` // HTTPSHost defines the host name that is used to redirect HTTP requests to HTTPS. // Default is "", which indicates to use the same host. HTTPSHost string `json:"https_host" mapstructure:"https_host"` // HTTPSProxyHeaders is a list of header keys with associated values that would indicate a valid https request. HTTPSProxyHeaders []HTTPSProxyHeader `json:"https_proxy_headers" mapstructure:"https_proxy_headers"` // STSSeconds is the max-age of the Strict-Transport-Security header. // Default is 0, which would NOT include the header. STSSeconds int64 `json:"sts_seconds" mapstructure:"sts_seconds"` // If STSIncludeSubdomains is set to true, the "includeSubdomains" will be appended to the // Strict-Transport-Security header. Default is false. 
STSIncludeSubdomains bool `json:"sts_include_subdomains" mapstructure:"sts_include_subdomains"` // If STSPreload is set to true, the `preload` flag will be appended to the // Strict-Transport-Security header. Default is false. STSPreload bool `json:"sts_preload" mapstructure:"sts_preload"` // If ContentTypeNosniff is true, adds the X-Content-Type-Options header with the value "nosniff". Default is false. ContentTypeNosniff bool `json:"content_type_nosniff" mapstructure:"content_type_nosniff"` // ContentSecurityPolicy allows to set the Content-Security-Policy header value. Default is "". ContentSecurityPolicy string `json:"content_security_policy" mapstructure:"content_security_policy"` // PermissionsPolicy allows to set the Permissions-Policy header value. Default is "". PermissionsPolicy string `json:"permissions_policy" mapstructure:"permissions_policy"` // CrossOriginOpenerPolicy allows to set the Cross-Origin-Opener-Policy header value. Default is "". CrossOriginOpenerPolicy string `json:"cross_origin_opener_policy" mapstructure:"cross_origin_opener_policy"` // CrossOriginResourcePolicy allows to set the Cross-Origin-Resource-Policy header value. Default is "". CrossOriginResourcePolicy string `json:"cross_origin_resource_policy" mapstructure:"cross_origin_resource_policy"` // CrossOriginEmbedderPolicy allows to set the Cross-Origin-Embedder-Policy header value. Default is "". CrossOriginEmbedderPolicy string `json:"cross_origin_embedder_policy" mapstructure:"cross_origin_embedder_policy"` // CacheControl allows to set the Cache-Control header value. CacheControl string `json:"cache_control" mapstructure:"cache_control"` // ReferrerPolicy allows to set the Referrer-Policy header values. 
ReferrerPolicy string `json:"referrer_policy" mapstructure:"referrer_policy"` proxyHeaders []string } func (s *SecurityConf) updateProxyHeaders() { if !s.Enabled { s.proxyHeaders = nil return } s.proxyHeaders = s.HostsProxyHeaders for _, httpsProxyHeader := range s.HTTPSProxyHeaders { s.proxyHeaders = append(s.proxyHeaders, httpsProxyHeader.Key) } } func (s *SecurityConf) getHTTPSProxyHeaders() map[string]string { headers := make(map[string]string) for _, httpsProxyHeader := range s.HTTPSProxyHeaders { headers[httpsProxyHeader.Key] = httpsProxyHeader.Value } return headers } func (s *SecurityConf) redirectHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !isTLS(r) && !strings.HasPrefix(r.RequestURI, acmeChallengeURI) { url := r.URL url.Scheme = "https" if s.HTTPSHost != "" { url.Host = s.HTTPSHost } else { host := r.Host for _, header := range s.HostsProxyHeaders { if h := r.Header.Get(header); h != "" { host = h break } } url.Host = host } http.Redirect(w, r, url.String(), http.StatusTemporaryRedirect) return } next.ServeHTTP(w, r) }) } // UIBranding defines the supported customizations for the web UIs type UIBranding struct { // Name defines the text to show at the login page and as HTML title Name string `json:"name" mapstructure:"name"` // ShortName defines the name to show next to the logo image ShortName string `json:"short_name" mapstructure:"short_name"` // Path to your logo relative to "static_files_path". // For example, if you create a directory named "branding" inside the static dir and // put the "mylogo.png" file in it, you must set "/branding/mylogo.png" as logo path. 
LogoPath string `json:"logo_path" mapstructure:"logo_path"` // Path to your favicon relative to "static_files_path" FaviconPath string `json:"favicon_path" mapstructure:"favicon_path"` // DisclaimerName defines the name for the link to your optional disclaimer DisclaimerName string `json:"disclaimer_name" mapstructure:"disclaimer_name"` // Path to the HTML page for your disclaimer relative to "static_files_path" // or an absolute http/https URL. DisclaimerPath string `json:"disclaimer_path" mapstructure:"disclaimer_path"` // Path to custom CSS files, relative to "static_files_path", which replaces // the default CSS files DefaultCSS []string `json:"default_css" mapstructure:"default_css"` // Additional CSS file paths, relative to "static_files_path", to include ExtraCSS []string `json:"extra_css" mapstructure:"extra_css"` DefaultLogoPath string `json:"-" mapstructure:"-"` DefaultFaviconPath string `json:"-" mapstructure:"-"` } func (b *UIBranding) check() { b.DefaultLogoPath = "/img/logo.png" b.DefaultFaviconPath = "/favicon.png" if b.LogoPath != "" { b.LogoPath = util.CleanPath(b.LogoPath) } else { b.LogoPath = b.DefaultLogoPath } if b.FaviconPath != "" { b.FaviconPath = util.CleanPath(b.FaviconPath) } else { b.FaviconPath = b.DefaultFaviconPath } if b.DisclaimerPath != "" { if !strings.HasPrefix(b.DisclaimerPath, "https://") && !strings.HasPrefix(b.DisclaimerPath, "http://") { b.DisclaimerPath = path.Join(webStaticFilesPath, util.CleanPath(b.DisclaimerPath)) } } if len(b.DefaultCSS) > 0 { for idx := range b.DefaultCSS { b.DefaultCSS[idx] = util.CleanPath(b.DefaultCSS[idx]) } } else { b.DefaultCSS = []string{ "/assets/plugins/global/plugins.bundle.css", "/assets/css/style.bundle.css", } } for idx := range b.ExtraCSS { b.ExtraCSS[idx] = util.CleanPath(b.ExtraCSS[idx]) } } // Branding defines the branding-related customizations supported type Branding struct { WebAdmin UIBranding `json:"web_admin" mapstructure:"web_admin"` WebClient UIBranding `json:"web_client" 
mapstructure:"web_client"` } // WebClientIntegration defines the configuration for an external Web Client integration type WebClientIntegration struct { // Files with these extensions can be sent to the configured URL FileExtensions []string `json:"file_extensions" mapstructure:"file_extensions"` // URL that will receive the files URL string `json:"url" mapstructure:"url"` } // Binding defines the configuration for a network listener type Binding struct { // The address to listen on. A blank value means listen on all available network interfaces. Address string `json:"address" mapstructure:"address"` // The port used for serving requests Port int `json:"port" mapstructure:"port"` // Enable the built-in admin interface. // You have to define TemplatesPath and StaticFilesPath for this to work EnableWebAdmin bool `json:"enable_web_admin" mapstructure:"enable_web_admin"` // Enable the built-in client interface. // You have to define TemplatesPath and StaticFilesPath for this to work EnableWebClient bool `json:"enable_web_client" mapstructure:"enable_web_client"` // Enable REST API EnableRESTAPI bool `json:"enable_rest_api" mapstructure:"enable_rest_api"` // Defines the login methods available for the WebAdmin and WebClient UIs: // // - 0 means any configured method: username/password login form and OIDC, if enabled // - 1 means OIDC for the WebAdmin UI // - 2 means OIDC for the WebClient UI // - 4 means login form for the WebAdmin UI // - 8 means login form for the WebClient UI // // You can combine the values. For example 3 means that you can only login using OIDC on // both WebClient and WebAdmin UI. 
// Deprecated because it is not extensible, use DisabledLoginMethods EnabledLoginMethods int `json:"enabled_login_methods" mapstructure:"enabled_login_methods"` // Defines the login methods disabled for the WebAdmin and WebClient UIs: // // - 1 means OIDC for the WebAdmin UI // - 2 means OIDC for the WebClient UI // - 4 means login form for the WebAdmin UI // - 8 means login form for the WebClient UI // - 16 means basic auth for admin REST API // - 32 means basic auth for user REST API // - 64 means API key auth for admins // - 128 means API key auth for users // You can combine the values. For example 12 means that you can only login using OIDC on // both WebClient and WebAdmin UI. DisabledLoginMethods int `json:"disabled_login_methods" mapstructure:"disabled_login_methods"` // you also need to provide a certificate for enabling HTTPS EnableHTTPS bool `json:"enable_https" mapstructure:"enable_https"` // Certificate and matching private key for this specific binding, if empty the global // ones will be used, if any CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` // Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2 MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"` // set to 1 to require client certificate authentication in addition to basic auth. // You need to define at least a certificate authority for this to work ClientAuthType int `json:"client_auth_type" mapstructure:"client_auth_type"` // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. // If CipherSuites is nil/empty, a default list of secure cipher suites // is used, with a preference order based on hardware performance. // Note that TLS 1.3 ciphersuites are not configurable. 
// The supported ciphersuites names are defined here: // // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 // // any invalid name will be silently ignored. // The order matters, the ciphers listed first will be the preferred ones. TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` // HTTP protocols in preference order. Supported values: http/1.1, h2 Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"` // Defines whether to use the common proxy protocol configuration or the // binding-specific proxy header configuration. ProxyMode int `json:"proxy_mode" mapstructure:"proxy_mode"` // List of IP addresses and IP ranges allowed to set client IP proxy headers and // X-Forwarded-Proto header. ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` // Allowed client IP proxy header such as "X-Forwarded-For", "X-Real-IP" ClientIPProxyHeader string `json:"client_ip_proxy_header" mapstructure:"client_ip_proxy_header"` // Some client IP headers such as "X-Forwarded-For" can contain multiple IP address, this setting // define the position to trust starting from the right. For example if we have: // "10.0.0.1,11.0.0.1,12.0.0.1,13.0.0.1" and the depth is 0, SFTPGo will use "13.0.0.1" // as client IP, if depth is 1, "12.0.0.1" will be used and so on ClientIPHeaderDepth int `json:"client_ip_header_depth" mapstructure:"client_ip_header_depth"` // If both web admin and web client are enabled each login page will show a link // to the other one. This setting allows to hide this link: // - 0 login links are displayed on both admin and client login page. This is the default // - 1 the login link to the web client login page is hidden on admin login page // - 2 the login link to the web admin login page is hidden on client login page // The flags can be combined, for example 3 will disable both login links. 
HideLoginURL int `json:"hide_login_url" mapstructure:"hide_login_url"` // Enable the built-in OpenAPI renderer RenderOpenAPI bool `json:"render_openapi" mapstructure:"render_openapi"` // BaseURL defines the external base URL for generating public links // (currently share access link), bypassing the default browser-based // detection. BaseURL string `json:"base_url" mapstructure:"base_url"` // Languages defines the list of enabled translations for the WebAdmin and WebClient UI. Languages []string `json:"languages" mapstructure:"languages"` // Defining an OIDC configuration the web admin and web client UI will use OpenID to authenticate users. OIDC OIDC `json:"oidc" mapstructure:"oidc"` // Security defines security headers to add to HTTP responses and allows to restrict allowed hosts Security SecurityConf `json:"security" mapstructure:"security"` // Branding defines customizations to suit your brand Branding Branding `json:"branding" mapstructure:"branding"` allowHeadersFrom []func(net.IP) bool } func (b *Binding) checkBranding() { b.Branding.WebAdmin.check() b.Branding.WebClient.check() if b.Branding.WebAdmin.Name == "" { b.Branding.WebAdmin.Name = "SFTPGo WebAdmin" } if b.Branding.WebAdmin.ShortName == "" { b.Branding.WebAdmin.ShortName = "WebAdmin" } if b.Branding.WebClient.Name == "" { b.Branding.WebClient.Name = "SFTPGo WebClient" } if b.Branding.WebClient.ShortName == "" { b.Branding.WebClient.ShortName = "WebClient" } } func (b *Binding) webAdminBranding() UIBranding { return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebAdmin, false) } func (b *Binding) webClientBranding() UIBranding { return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebClient, true) } func (b *Binding) languages() []string { return b.Languages } func (b *Binding) validateBaseURL() error { if b.BaseURL == "" { return nil } u, err := url.ParseRequestURI(b.BaseURL) if err != nil { return err } if u.Scheme != "http" && u.Scheme != "https" { return fmt.Errorf("invalid base URL schema 
%s", b.BaseURL) } if u.Host == "" { return fmt.Errorf("invalid base URL host %s", b.BaseURL) } b.BaseURL = strings.TrimRight(u.String(), "/") return nil } func (b *Binding) parseAllowedProxy() error { if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 { // unix domain socket b.allowHeadersFrom = []func(net.IP) bool{func(_ net.IP) bool { return true }} return nil } allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed) if err != nil { return err } b.allowHeadersFrom = allowedFuncs return nil } // GetAddress returns the binding address func (b *Binding) GetAddress() string { return fmt.Sprintf("%s:%d", b.Address, b.Port) } // IsValid returns true if the binding is valid func (b *Binding) IsValid() bool { if !b.EnableRESTAPI && !b.EnableWebAdmin && !b.EnableWebClient { return false } if b.Port > 0 { return true } if filepath.IsAbs(b.Address) && runtime.GOOS != osWindows { return true } return false } func (b *Binding) check() error { if err := b.parseAllowedProxy(); err != nil { return err } if err := b.validateBaseURL(); err != nil { return err } b.checkBranding() b.Security.updateProxyHeaders() return nil } func (b *Binding) isWebAdminOIDCLoginDisabled() bool { if b.EnableWebAdmin { return b.DisabledLoginMethods&1 != 0 } return false } func (b *Binding) isWebClientOIDCLoginDisabled() bool { if b.EnableWebClient { return b.DisabledLoginMethods&2 != 0 } return false } func (b *Binding) isWebAdminLoginFormDisabled() bool { if b.EnableWebAdmin { return b.DisabledLoginMethods&4 != 0 } return false } func (b *Binding) isWebClientLoginFormDisabled() bool { if b.EnableWebClient { return b.DisabledLoginMethods&8 != 0 } return false } func (b *Binding) isAdminTokenEndpointDisabled() bool { return b.DisabledLoginMethods&16 != 0 } func (b *Binding) isUserTokenEndpointDisabled() bool { return b.DisabledLoginMethods&32 != 0 } func (b *Binding) isAdminAPIKeyAuthDisabled() bool { return b.DisabledLoginMethods&64 != 0 } func (b *Binding) isUserAPIKeyAuthDisabled() 
bool {
	return b.DisabledLoginMethods&128 != 0
}

// hasLoginForAPI reports whether at least one REST API authentication method
// (admin/user token endpoint or admin/user API key) is still enabled.
func (b *Binding) hasLoginForAPI() bool {
	return !b.isAdminTokenEndpointDisabled() || !b.isUserTokenEndpointDisabled() ||
		!b.isAdminAPIKeyAuthDisabled() || !b.isUserAPIKeyAuthDisabled()
}

// convertLoginMethods checks if the deprecated EnabledLoginMethods is set and
// converts the value to DisabledLoginMethods: each of the bits 1, 2, 4 and 8
// that is NOT set in EnabledLoginMethods is set in DisabledLoginMethods.
func (b *Binding) convertLoginMethods() {
	if b.DisabledLoginMethods > 0 || b.EnabledLoginMethods == 0 {
		// DisabledLoginMethods already in use or EnabledLoginMethods not set.
		return
	}
	// DisabledLoginMethods is 0 here, so OR-ing the missing bits is the same as
	// summing them.
	if b.EnabledLoginMethods&1 == 0 {
		b.DisabledLoginMethods |= 1
	}
	if b.EnabledLoginMethods&2 == 0 {
		b.DisabledLoginMethods |= 2
	}
	if b.EnabledLoginMethods&4 == 0 {
		b.DisabledLoginMethods |= 4
	}
	if b.EnabledLoginMethods&8 == 0 {
		b.DisabledLoginMethods |= 8
	}
}

// checkLoginMethods converts the deprecated EnabledLoginMethods setting, if used,
// and then returns an error if the WebAdmin UI, the WebClient UI or (when enabled)
// the REST API is left without a usable login method.
func (b *Binding) checkLoginMethods() error {
	b.convertLoginMethods()

	if b.isWebAdminLoginFormDisabled() && b.isWebAdminOIDCLoginDisabled() {
		return errors.New("no login method available for WebAdmin UI")
	}
	// With the login form disabled, OIDC must be the WebAdmin login method and
	// it is only accepted when the OIDC configuration defines roles.
	if !b.isWebAdminOIDCLoginDisabled() && b.isWebAdminLoginFormDisabled() && !b.OIDC.hasRoles() {
		return errors.New("no login method available for WebAdmin UI")
	}
	if b.isWebClientLoginFormDisabled() && b.isWebClientOIDCLoginDisabled() {
		return errors.New("no login method available for WebClient UI")
	}
	// With the login form disabled, OIDC must be the WebClient login method and
	// the OIDC configuration must be enabled.
	if !b.isWebClientOIDCLoginDisabled() && b.isWebClientLoginFormDisabled() && !b.OIDC.isEnabled() {
		return errors.New("no login method available for WebClient UI")
	}
	if b.EnableRESTAPI && !b.hasLoginForAPI() {
		return errors.New("no login method available for REST API")
	}
	return nil
}

// showAdminLoginURL reports whether the WebAdmin login URL should be displayed:
// WebAdmin must be enabled and bit 2 of HideLoginURL must be unset.
func (b *Binding) showAdminLoginURL() bool {
	return b.EnableWebAdmin && b.HideLoginURL&2 == 0
}

// showClientLoginURL reports whether the WebClient login URL should be displayed:
// WebClient must be enabled and bit 1 of HideLoginURL must be unset.
func (b *Binding) showClientLoginURL() bool {
	return b.EnableWebClient && b.HideLoginURL&1 == 0
}

// isMutualTLSEnabled reports whether ClientAuthType is set to 1.
func (b *Binding) isMutualTLSEnabled() bool {
	return b.ClientAuthType == 1
}

// listenerWrapper returns the proxy listener factory from the common
// configuration when ProxyMode is 1, nil otherwise.
func (b *Binding) listenerWrapper() func(net.Listener) (net.Listener, error) {
	if b.ProxyMode == 1 {
		return common.Config.GetProxyListener
	}
	return nil
}

type defenderStatus struct {
	IsActive bool `json:"is_active"`
}

type allowListStatus struct {
	IsActive bool `json:"is_active"`
}

type rateLimiters struct {
	IsActive  bool     `json:"is_active"`
	Protocols []string `json:"protocols"`
}

// GetProtocolsAsString returns the enabled protocols as a comma separated string
func (r *rateLimiters) GetProtocolsAsString() string {
	return strings.Join(r.Protocols, ", ")
}

// ServicesStatus keeps the state of the running services
type ServicesStatus struct {
	SSH          sftpd.ServiceStatus         `json:"ssh"`
	FTP          ftpd.ServiceStatus          `json:"ftp"`
	WebDAV       webdavd.ServiceStatus       `json:"webdav"`
	DataProvider dataprovider.ProviderStatus `json:"data_provider"`
	Defender     defenderStatus              `json:"defender"`
	MFA          mfa.ServiceStatus           `json:"mfa"`
	AllowList    allowListStatus             `json:"allow_list"`
	RateLimiters rateLimiters                `json:"rate_limiters"`
}

// SetupConfig defines the configuration parameters for the initial web admin setup
type SetupConfig struct {
	// Installation code to require when creating the first admin account.
	// As for the other configurations, this value is read at SFTPGo startup and not at runtime
	// even if set using an environment variable.
	// This is not a license key or similar, the purpose here is to prevent anyone who can access
	// the initial setup screen from creating an admin user
	InstallationCode string `json:"installation_code" mapstructure:"installation_code"`
	// Description for the installation code input field
	InstallationCodeHint string `json:"installation_code_hint" mapstructure:"installation_code_hint"`
}

// CorsConfig defines the CORS configuration
type CorsConfig struct {
	AllowedOrigins       []string `json:"allowed_origins" mapstructure:"allowed_origins"`
	AllowedMethods       []string `json:"allowed_methods" mapstructure:"allowed_methods"`
	AllowedHeaders       []string `json:"allowed_headers" mapstructure:"allowed_headers"`
	ExposedHeaders       []string `json:"exposed_headers" mapstructure:"exposed_headers"`
	AllowCredentials     bool     `json:"allow_credentials" mapstructure:"allow_credentials"`
	Enabled              bool     `json:"enabled" mapstructure:"enabled"`
	MaxAge               int      `json:"max_age" mapstructure:"max_age"`
	OptionsPassthrough   bool     `json:"options_passthrough" mapstructure:"options_passthrough"`
	OptionsSuccessStatus int      `json:"options_success_status" mapstructure:"options_success_status"`
	AllowPrivateNetwork  bool     `json:"allow_private_network" mapstructure:"allow_private_network"`
}

// Conf httpd daemon configuration
type Conf struct {
	// Addresses and ports to bind to
	Bindings []Binding `json:"bindings" mapstructure:"bindings"`
	// Path to the HTML web templates. This can be an absolute path or a path relative to the config dir
	TemplatesPath string `json:"templates_path" mapstructure:"templates_path"`
	// Path to the static files for the web interface. This can be an absolute path or a path relative to the config dir.
	// If both TemplatesPath and StaticFilesPath are empty the built-in web interface will be disabled
	StaticFilesPath string `json:"static_files_path" mapstructure:"static_files_path"`
	// Path to the directory that contains the OpenAPI schema and the default renderer.
	// This can be an absolute path or a path relative to the config dir
	OpenAPIPath string `json:"openapi_path" mapstructure:"openapi_path"`
	// Defines a base URL for the web admin and client interfaces. If empty web admin and client resources will
	// be available at the root ("/") URI. If defined it must be an absolute URI or it will be ignored.
	WebRoot string `json:"web_root" mapstructure:"web_root"`
	// If files containing a certificate and matching private key for the server are provided you can enable
	// HTTPS connections for the configured bindings.
	// Certificate and key files can be reloaded on demand sending a "SIGHUP" signal on Unix based systems and a
	// "paramchange" request to the running service on Windows.
	CertificateFile    string `json:"certificate_file" mapstructure:"certificate_file"`
	CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"`
	// CACertificates defines the set of root certificate authorities to be used to verify client certificates.
	CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"`
	// CARevocationLists defines a set of revocation lists, one for each root CA, to be used to check
	// if a client certificate has been revoked
	CARevocationLists []string `json:"ca_revocation_lists" mapstructure:"ca_revocation_lists"`
	// SigningPassphrase defines the passphrase to use to derive the signing key for JWT and CSRF tokens.
	// If empty a random signing key will be generated each time SFTPGo starts. If you set a
	// signing passphrase you should consider rotating it periodically for added security
	SigningPassphrase string `json:"signing_passphrase" mapstructure:"signing_passphrase"`
	// SigningPassphraseFile, if set, is read by Initialize and its content replaces SigningPassphrase
	SigningPassphraseFile string `json:"signing_passphrase_file" mapstructure:"signing_passphrase_file"`
	// TokenValidation allows defining how to validate JWT tokens, cookies and CSRF tokens.
	// By default all the available security checks are enabled. Set to 1 to disable the requirement
	// that a token must be used by the same IP for which it was issued.
	TokenValidation int `json:"token_validation" mapstructure:"token_validation"`
	// CookieLifetime defines the duration of cookies for WebAdmin and WebClient
	CookieLifetime int `json:"cookie_lifetime" mapstructure:"cookie_lifetime"`
	// ShareCookieLifetime defines the duration of cookies for public shares
	ShareCookieLifetime int `json:"share_cookie_lifetime" mapstructure:"share_cookie_lifetime"`
	// JWTLifetime defines the duration of JWT tokens used in REST API
	JWTLifetime int `json:"jwt_lifetime" mapstructure:"jwt_lifetime"`
	// MaxUploadFileSize defines the maximum request body size, in bytes, for Web Client/API HTTP upload requests.
	// 0 means no limit
	MaxUploadFileSize int64 `json:"max_upload_file_size" mapstructure:"max_upload_file_size"`
	// CORS configuration
	Cors CorsConfig `json:"cors" mapstructure:"cors"`
	// Initial setup configuration
	Setup SetupConfig `json:"setup" mapstructure:"setup"`
	// If enabled, the link to the sponsors section will not appear on the setup screen page
	HideSupportLink bool `json:"hide_support_link" mapstructure:"hide_support_link"`
	// ACME domain set by loadFromProvider; when not empty, getKeyPairs uses its
	// certificate as the default key pair instead of CertificateFile/CertificateKeyFile
	acmeDomain string
}

type apiResponse struct {
	Error   string `json:"error,omitempty"`
	Message string `json:"message"`
}

// ShouldBind returns true if there is at least a valid binding
func (c *Conf) ShouldBind() bool {
	for _, binding := range c.Bindings {
		if binding.IsValid() {
			return true
		}
	}
	return false
}

// isWebAdminEnabled reports whether at least one binding enables WebAdmin.
func (c *Conf) isWebAdminEnabled() bool {
	for _, binding := range c.Bindings {
		if binding.EnableWebAdmin {
			return true
		}
	}
	return false
}

// isWebClientEnabled reports whether at least one binding enables WebClient.
func (c *Conf) isWebClientEnabled() bool {
	for _, binding := range c.Bindings {
		if binding.EnableWebClient {
			return true
		}
	}
	return false
}

// checkRequiredDirs returns an error if WebAdmin or WebClient is enabled on any
// binding but the static files path or the templates path is empty.
func (c *Conf) checkRequiredDirs(staticFilesPath, templatesPath string) error {
	if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") {
		return fmt.Errorf("required directory is invalid, static file path: %q template path: %q",
			staticFilesPath, templatesPath)
	}
	return nil
}

// getRedacted returns a copy of the configuration, suitable for logging, with the
// signing passphrase, the installation code and the OIDC client ID and secret of
// every binding replaced by a placeholder. The receiver is not modified.
func (c *Conf) getRedacted() Conf {
	redacted := "[redacted]"
	conf := *c
	if conf.SigningPassphrase != "" {
		conf.SigningPassphrase = redacted
	}
	if conf.Setup.InstallationCode != "" {
		conf.Setup.InstallationCode = redacted
	}
	// Rebuild the slice so the redacted bindings do not alias c.Bindings.
	conf.Bindings = nil
	for _, binding := range c.Bindings {
		if binding.OIDC.ClientID != "" {
			binding.OIDC.ClientID = redacted
		}
		if binding.OIDC.ClientSecret != "" {
			binding.OIDC.ClientSecret = redacted
		}
		conf.Bindings = append(conf.Bindings, binding)
	}
	return conf
}

// getKeyPairs returns the TLS key pairs to load. It includes one pair for each
// binding that has both a certificate and a key (identified by the binding
// address). It also includes the default pair, taken from the ACME domain if set
// or from CertificateFile/CertificateKeyFile otherwise.
func (c *Conf) getKeyPairs(configDir string) []common.TLSKeyPair {
	var keyPairs []common.TLSKeyPair

	for _, binding := range c.Bindings {
		certificateFile := getConfigPath(binding.CertificateFile, configDir)
		certificateKeyFile := getConfigPath(binding.CertificateKeyFile, configDir)
		if certificateFile != "" && certificateKeyFile != "" {
			keyPairs = append(keyPairs, common.TLSKeyPair{
				Cert: certificateFile,
				Key:  certificateKeyFile,
				ID:   binding.GetAddress(),
			})
		}
	}

	var certificateFile, certificateKeyFile string
	if c.acmeDomain != "" {
		certificateFile, certificateKeyFile = util.GetACMECertificateKeyPair(c.acmeDomain)
	} else {
		certificateFile = getConfigPath(c.CertificateFile, configDir)
		certificateKeyFile = getConfigPath(c.CertificateKeyFile, configDir)
	}
	if certificateFile != "" && certificateKeyFile != "" {
		keyPairs = append(keyPairs, common.TLSKeyPair{
			Cert: certificateFile,
			Key:  certificateKeyFile,
			ID:   common.DefaultTLSKeyPaidID,
		})
	}
	return keyPairs
}

// setTokenValidationMode copies TokenValidation into the package-level
// tokenValidationMode.
func (c *Conf) setTokenValidationMode() {
	tokenValidationMode = c.TokenValidation
}

// loadFromProvider loads the branding configuration and, if an ACME domain with
// the HTTP protocol is configured and its certificate and key files exist,
// enables HTTPS on the bindings and records the ACME domain. A missing ACME file
// is logged and is not an error; only a failure to read the configs is returned.
func (c *Conf) loadFromProvider() error {
	configs, err := dataprovider.GetConfigs()
	if err != nil {
		return fmt.Errorf("unable to load config from provider: %w", err)
	}
	configs.SetNilsToEmpty()
	dbBrandingConfig.Set(configs.Branding)
	if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolHTTP) {
		return nil
	}
	crt, key := util.GetACMECertificateKeyPair(configs.ACME.Domain)
	if crt == "" || key == "" {
		return nil
	}
	if _, err := os.Stat(crt); err != nil {
		logger.Error(logSender, "", "unable to load acme cert file %q: %v", crt, err)
		return nil
	}
	if _, err := os.Stat(key); err != nil {
		logger.Error(logSender, "", "unable to load acme key file %q: %v", key, err)
		return nil
	}
	for idx := range c.Bindings {
		// Bindings configured with the security HTTPS redirect are left untouched.
		if c.Bindings[idx].Security.Enabled && c.Bindings[idx].Security.HTTPSRedirect {
			continue
		}
		c.Bindings[idx].EnableHTTPS = true
	}
	c.acmeDomain = configs.ACME.Domain
	logger.Info(logSender, "", "acme domain set to %q", c.acmeDomain)
	return nil
}

// loadTemplates updates the URLs and loads the templates for the WebAdmin and
// WebClient interfaces that are enabled on at least one binding.
func (c *Conf) loadTemplates(templatesPath string) {
	if c.isWebAdminEnabled() {
		updateWebAdminURLs(c.WebRoot)
		loadAdminTemplates(templatesPath)
	} else {
		logger.Info(logSender, "", "built-in web admin interface disabled")
	}
	if c.isWebClientEnabled() {
		updateWebClientURLs(c.WebRoot)
		loadClientTemplates(templatesPath)
	} else {
		logger.Info(logSender, "", "built-in web client interface disabled")
	}
}

// Initialize configures and starts the HTTP server
func (c *Conf) Initialize(configDir string, isShared int) error {
	if err := c.loadFromProvider(); err != nil {
		return err
	}
	logger.Info(logSender, "", "initializing HTTP server with config %+v", c.getRedacted())
	configurationDir = configDir
	invalidatedJWTTokens = newTokenManager(isShared)
	resetCodesMgr = newResetCodeManager(isShared)
	oidcMgr = newOIDCManager(isShared)
	oauth2Mgr = newOAuth2Manager(isShared)
	webTaskMgr = newWebTaskManager(isShared)
	staticFilesPath := util.FindSharedDataPath(c.StaticFilesPath, configDir)
	templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir)
	openAPIPath := util.FindSharedDataPath(c.OpenAPIPath, configDir)
	if err := c.checkRequiredDirs(staticFilesPath, templatesPath); err != nil {
		return err
	}
	c.loadTemplates(templatesPath)
	keyPairs := c.getKeyPairs(configDir)
	if len(keyPairs) > 0 {
		mgr, err := common.NewCertManager(keyPairs, configDir, logSender)
		if err != nil {
			return err
		}
		mgr.SetCACertificates(c.CACertificates)
		if err := mgr.LoadRootCAs(); err != nil {
			return err
		}
		mgr.SetCARevocationLists(c.CARevocationLists)
		if err := mgr.LoadCRLs(); err != nil {
			return err
		}
		certMgr = mgr
	}
	if c.SigningPassphraseFile != "" {
		passphrase, err := util.ReadConfigFromFile(c.SigningPassphraseFile, configDir)
		if err != nil {
			return err
		}
		c.SigningPassphrase = passphrase
	}
	hideSupportLink = c.HideSupportLink

	exitChannel := make(chan error, 1)

	for _, binding := range c.Bindings {
		if !binding.IsValid() {
			continue
		}
		if err := binding.check(); err != nil {
			return err
		}

		go func(b Binding) {
			if err := b.OIDC.initialize(); err != nil {
				exitChannel <- err
				return
			}
			if err := b.checkLoginMethods(); err != nil {
				exitChannel <- err
				return
			}
			server :=
newHttpdServer(b, staticFilesPath, c.SigningPassphrase, c.Cors, openAPIPath)
			server.setShared(isShared)

			exitChannel <- server.listenAndServe()
		}(binding)
	}

	maxUploadFileSize = c.MaxUploadFileSize
	installationCode = c.Setup.InstallationCode
	installationCodeHint = c.Setup.InstallationCodeHint
	updateTokensDuration(c.JWTLifetime, c.CookieLifetime, c.ShareCookieLifetime)
	startCleanupTicker(10 * time.Minute)
	c.setTokenValidationMode()
	return <-exitChannel
}

// isWebRequest reports whether the request URI is below webBasePath.
func isWebRequest(r *http.Request) bool {
	return strings.HasPrefix(r.RequestURI, webBasePath+"/")
}

// isWebClientRequest reports whether the request URI is below webBaseClientPath.
func isWebClientRequest(r *http.Request) bool {
	return strings.HasPrefix(r.RequestURI, webBaseClientPath+"/")
}

// ReloadCertificateMgr reloads the certificate manager
func ReloadCertificateMgr() error {
	if certMgr != nil {
		return certMgr.Reload()
	}
	return nil
}

// getConfigPath returns "" if util.IsFileInputValid rejects name, name joined to
// configDir if it is a non-empty relative path, and name unchanged otherwise.
func getConfigPath(name, configDir string) string {
	if !util.IsFileInputValid(name) {
		return ""
	}
	if name != "" && !filepath.IsAbs(name) {
		return filepath.Join(configDir, name)
	}
	return name
}

// getServicesStatus collects the current status of the SSH, FTP, WebDAV, data
// provider, defender, MFA, allow list and rate limiter services.
func getServicesStatus() *ServicesStatus {
	rtlEnabled, rtlProtocols := common.Config.GetRateLimitersStatus()
	status := &ServicesStatus{
		SSH:          sftpd.GetStatus(),
		FTP:          ftpd.GetStatus(),
		WebDAV:       webdavd.GetStatus(),
		DataProvider: dataprovider.GetProviderStatus(),
		Defender: defenderStatus{
			IsActive: common.Config.DefenderConfig.Enabled,
		},
		MFA: mfa.GetStatus(),
		AllowList: allowListStatus{
			IsActive: common.Config.IsAllowListEnabled(),
		},
		RateLimiters: rateLimiters{
			IsActive:  rtlEnabled,
			Protocols: rtlProtocols,
		},
	}
	return status
}

// fileServer serves the files in root under path on the given router. A path
// without a trailing slash is permanently redirected to path + "/". When
// disableDirectoryIndex is true, root is wrapped in neuteredFileSystem.
func fileServer(r chi.Router, path string, root http.FileSystem, disableDirectoryIndex bool) {
	if path != "/" && !strings.HasSuffix(path, "/") {
		r.Get(path, http.RedirectHandler(path+"/", http.StatusMovedPermanently).ServeHTTP)
		path += "/"
	}
	path += "*"

	// Wrap the file system once, at registration time. Wrapping inside the
	// handler reassigned the captured root on every request: each request added
	// another neuteredFileSystem layer and concurrent requests raced on it.
	if disableDirectoryIndex {
		root = neuteredFileSystem{root}
	}
	fileHandler := http.FileServer(root)

	r.Get(path, func(w http.ResponseWriter, req *http.Request) {
		rctx := chi.RouteContext(req.Context())
		pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
		http.StripPrefix(pathPrefix, fileHandler).ServeHTTP(w, req)
	})
}

func updateWebClientURLs(baseURL string) {
	if !path.IsAbs(baseURL) {
		baseURL = "/"
	}
	webRootPath = path.Join(baseURL, webRootPathDefault)
	webBasePath = path.Join(baseURL, webBasePathDefault)
	webBaseClientPath = path.Join(baseURL, webBasePathClientDefault)
	webOIDCRedirectPath = path.Join(baseURL, webOIDCRedirectPathDefault)
	webClientLoginPath = path.Join(baseURL, webClientLoginPathDefault)
	webClientOIDCLoginPath = path.Join(baseURL, webClientOIDCLoginPathDefault)
	webClientTwoFactorPath = path.Join(baseURL, webClientTwoFactorPathDefault)
	webClientTwoFactorRecoveryPath = path.Join(baseURL, webClientTwoFactorRecoveryPathDefault)
	webClientFilesPath = path.Join(baseURL, webClientFilesPathDefault)
	webClientFilePath = path.Join(baseURL, webClientFilePathDefault)
	webClientFileActionsPath = path.Join(baseURL, webClientFileActionsPathDefault)
	webClientSharesPath = path.Join(baseURL, webClientSharesPathDefault)
	webClientPubSharesPath = path.Join(baseURL, webClientPubSharesPathDefault)
	webClientSharePath = path.Join(baseURL, webClientSharePathDefault)
	webClientEditFilePath = path.Join(baseURL, webClientEditFilePathDefault)
	webClientDirsPath = path.Join(baseURL, webClientDirsPathDefault)
	webClientDownloadZipPath = path.Join(baseURL, webClientDownloadZipPathDefault)
	webClientProfilePath = path.Join(baseURL, webClientProfilePathDefault)
	webClientPingPath = path.Join(baseURL, webClientPingPathDefault)
	webChangeClientPwdPath = path.Join(baseURL, webChangeClientPwdPathDefault)
	webClientLogoutPath = path.Join(baseURL, webClientLogoutPathDefault)
	webClientMFAPath = path.Join(baseURL, webClientMFAPathDefault)
	webClientTOTPGeneratePath = path.Join(baseURL, webClientTOTPGeneratePathDefault)
	webClientTOTPValidatePath = path.Join(baseURL, webClientTOTPValidatePathDefault)
	webClientTOTPSavePath = path.Join(baseURL, webClientTOTPSavePathDefault)
	webClientRecoveryCodesPath =
path.Join(baseURL, webClientRecoveryCodesPathDefault) webClientForgotPwdPath = path.Join(baseURL, webClientForgotPwdPathDefault) webClientResetPwdPath = path.Join(baseURL, webClientResetPwdPathDefault) webClientViewPDFPath = path.Join(baseURL, webClientViewPDFPathDefault) webClientGetPDFPath = path.Join(baseURL, webClientGetPDFPathDefault) webClientExistPath = path.Join(baseURL, webClientExistPathDefault) webClientTasksPath = path.Join(baseURL, webClientTasksPathDefault) webStaticFilesPath = path.Join(baseURL, webStaticFilesPathDefault) webOpenAPIPath = path.Join(baseURL, webOpenAPIPathDefault) } func updateWebAdminURLs(baseURL string) { if !path.IsAbs(baseURL) { baseURL = "/" } webRootPath = path.Join(baseURL, webRootPathDefault) webBasePath = path.Join(baseURL, webBasePathDefault) webBaseAdminPath = path.Join(baseURL, webBasePathAdminDefault) webOIDCRedirectPath = path.Join(baseURL, webOIDCRedirectPathDefault) webOAuth2RedirectPath = path.Join(baseURL, webOAuth2RedirectPathDefault) webOAuth2TokenPath = path.Join(baseURL, webOAuth2TokenPathDefault) webAdminSetupPath = path.Join(baseURL, webAdminSetupPathDefault) webAdminLoginPath = path.Join(baseURL, webAdminLoginPathDefault) webAdminOIDCLoginPath = path.Join(baseURL, webAdminOIDCLoginPathDefault) webAdminTwoFactorPath = path.Join(baseURL, webAdminTwoFactorPathDefault) webAdminTwoFactorRecoveryPath = path.Join(baseURL, webAdminTwoFactorRecoveryPathDefault) webLogoutPath = path.Join(baseURL, webLogoutPathDefault) webUsersPath = path.Join(baseURL, webUsersPathDefault) webUserPath = path.Join(baseURL, webUserPathDefault) webConnectionsPath = path.Join(baseURL, webConnectionsPathDefault) webFoldersPath = path.Join(baseURL, webFoldersPathDefault) webFolderPath = path.Join(baseURL, webFolderPathDefault) webGroupsPath = path.Join(baseURL, webGroupsPathDefault) webGroupPath = path.Join(baseURL, webGroupPathDefault) webStatusPath = path.Join(baseURL, webStatusPathDefault) webAdminsPath = path.Join(baseURL, 
webAdminsPathDefault) webAdminPath = path.Join(baseURL, webAdminPathDefault) webMaintenancePath = path.Join(baseURL, webMaintenancePathDefault) webBackupPath = path.Join(baseURL, webBackupPathDefault) webRestorePath = path.Join(baseURL, webRestorePathDefault) webScanVFolderPath = path.Join(baseURL, webScanVFolderPathDefault) webQuotaScanPath = path.Join(baseURL, webQuotaScanPathDefault) webChangeAdminPwdPath = path.Join(baseURL, webChangeAdminPwdPathDefault) webAdminForgotPwdPath = path.Join(baseURL, webAdminForgotPwdPathDefault) webAdminResetPwdPath = path.Join(baseURL, webAdminResetPwdPathDefault) webAdminProfilePath = path.Join(baseURL, webAdminProfilePathDefault) webAdminMFAPath = path.Join(baseURL, webAdminMFAPathDefault) webAdminEventRulesPath = path.Join(baseURL, webAdminEventRulesPathDefault) webAdminEventRulePath = path.Join(baseURL, webAdminEventRulePathDefault) webAdminEventActionsPath = path.Join(baseURL, webAdminEventActionsPathDefault) webAdminEventActionPath = path.Join(baseURL, webAdminEventActionPathDefault) webAdminRolesPath = path.Join(baseURL, webAdminRolesPathDefault) webAdminRolePath = path.Join(baseURL, webAdminRolePathDefault) webAdminTOTPGeneratePath = path.Join(baseURL, webAdminTOTPGeneratePathDefault) webAdminTOTPValidatePath = path.Join(baseURL, webAdminTOTPValidatePathDefault) webAdminTOTPSavePath = path.Join(baseURL, webAdminTOTPSavePathDefault) webAdminRecoveryCodesPath = path.Join(baseURL, webAdminRecoveryCodesPathDefault) webTemplateUser = path.Join(baseURL, webTemplateUserDefault) webTemplateFolder = path.Join(baseURL, webTemplateFolderDefault) webDefenderHostsPath = path.Join(baseURL, webDefenderHostsPathDefault) webDefenderPath = path.Join(baseURL, webDefenderPathDefault) webIPListPath = path.Join(baseURL, webIPListPathDefault) webIPListsPath = path.Join(baseURL, webIPListsPathDefault) webEventsPath = path.Join(baseURL, webEventsPathDefault) webEventsFsSearchPath = path.Join(baseURL, webEventsFsSearchPathDefault) 
webEventsProviderSearchPath = path.Join(baseURL, webEventsProviderSearchPathDefault) webEventsLogSearchPath = path.Join(baseURL, webEventsLogSearchPathDefault) webConfigsPath = path.Join(baseURL, webConfigsPathDefault) webStaticFilesPath = path.Join(baseURL, webStaticFilesPathDefault) webOpenAPIPath = path.Join(baseURL, webOpenAPIPathDefault) } // GetHTTPRouter returns an HTTP handler suitable to use for test cases func GetHTTPRouter(b Binding) (http.Handler, error) { server := newHttpdServer(b, filepath.Join("..", "..", "static"), "", CorsConfig{}, filepath.Join("..", "..", "openapi")) if err := server.initializeRouter(); err != nil { return nil, err } return server.router, nil } // the ticker cannot be started/stopped from multiple goroutines func startCleanupTicker(duration time.Duration) { stopCleanupTicker() cleanupTicker = time.NewTicker(duration) cleanupDone = make(chan bool) go func() { counter := int64(0) for { select { case <-cleanupDone: return case <-cleanupTicker.C: counter++ invalidatedJWTTokens.Cleanup() resetCodesMgr.Cleanup() webTaskMgr.Cleanup() if counter%2 == 0 { oidcMgr.cleanup() oauth2Mgr.cleanup() } } } }() } func stopCleanupTicker() { if cleanupTicker != nil { cleanupTicker.Stop() cleanupDone <- true cleanupTicker = nil } } func getSigningKey(signingPassphrase string) []byte { var key []byte if signingPassphrase != "" { key = []byte(signingPassphrase) } else { key = util.GenerateRandomBytes(32) } sk := sha256.Sum256(key) return sk[:] } // SetInstallationCodeResolver sets a function to call to resolve the installation code func SetInstallationCodeResolver(fn FnInstallationCodeResolver) { fnInstallationCodeResolver = fn } func resolveInstallationCode() string { if fnInstallationCodeResolver != nil { return fnInstallationCodeResolver(installationCode) } return installationCode } type neuteredFileSystem struct { fs http.FileSystem } func (nfs neuteredFileSystem) Open(name string) (http.File, error) { f, err := nfs.fs.Open(name) if err != nil { 
return nil, err } s, err := f.Stat() if err != nil { return nil, err } if s.IsDir() { index := path.Join(name, "index.html") if _, err := nfs.fs.Open(index); err != nil { defer f.Close() return nil, err } } return f, nil } ================================================ FILE: internal/httpd/httpd_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd_test import ( "bytes" "crypto/rand" "encoding/json" "errors" "fmt" "image" "image/color" "image/png" "io" "io/fs" "math" "mime/multipart" "net" "net/http" "net/http/httptest" "net/url" "os" "path" "path/filepath" "regexp" "runtime" "slices" "strconv" "strings" "sync" "testing" "time" "github.com/go-chi/render" _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v5/stdlib" "github.com/lithammer/shortuuid/v4" _ "github.com/mattn/go-sqlite3" "github.com/mhale/smtpd" "github.com/pkg/sftp" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "golang.org/x/crypto/ssh" "golang.org/x/net/html" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" 
"github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( defaultUsername = "test_user" defaultPassword = "test_password" testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" testPubKey1 = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCd60+/j+y8f0tLftihWV1YN9RSahMI9btQMDIMqts/jeNbD8jgoogM3nhF7KxfcaMKURuD47KC4Ey6iAJUJ0sWkSNNxOcIYuvA+5MlspfZDsa8Ag76Fe1vyz72WeHMHMeh/hwFo2TeIeIXg480T1VI6mzfDrVp2GzUx0SS0dMsQBjftXkuVR8YOiOwMCAH2a//M1OrvV7d/NBk6kBN0WnuIBb2jKm15PAA7+jQQG7tzwk2HedNH3jeL5GH31xkSRwlBczRK0xsCQXehAlx6cT/e/s44iJcJTHfpPKoSk6UAhPJYe7Z1QnuoawY9P9jQaxpyeImBZxxUEowhjpj2avBxKdRGBVK8R7EL8tSOeLbhdyWe5Mwc1+foEbq9Zz5j5Kd+hn3Wm1UnsGCrXUUUoZp1jnlNl0NakCto+5KmqnT9cHxaY+ix2RLUWAZyVFlRq71OYux1UHJnEJPiEI1/tr4jFBSL46qhQZv/TfpkfVW8FLz0lErfqu0gQEZnNHr3Fc= nicola@p1" defaultTokenAuthUser = "admin" defaultTokenAuthPass = "password" altAdminUsername = "newTestAdmin" altAdminPassword = "password1" csrfFormToken = "_form_token" tokenPath = "/api/v2/token" userTokenPath = "/api/v2/user/token" userLogoutPath = 
"/api/v2/user/logout" userPath = "/api/v2/users" adminPath = "/api/v2/admins" adminPwdPath = "/api/v2/admin/changepwd" folderPath = "/api/v2/folders" groupPath = "/api/v2/groups" activeConnectionsPath = "/api/v2/connections" serverStatusPath = "/api/v2/status" quotasBasePath = "/api/v2/quotas" quotaScanPath = "/api/v2/quotas/users/scans" quotaScanVFolderPath = "/api/v2/quotas/folders/scans" defenderHosts = "/api/v2/defender/hosts" versionPath = "/api/v2/version" logoutPath = "/api/v2/logout" userPwdPath = "/api/v2/user/changepwd" userDirsPath = "/api/v2/user/dirs" userFilesPath = "/api/v2/user/files" userFileActionsPath = "/api/v2/user/file-actions" userStreamZipPath = "/api/v2/user/streamzip" userUploadFilePath = "/api/v2/user/files/upload" userFilesDirsMetadataPath = "/api/v2/user/files/metadata" apiKeysPath = "/api/v2/apikeys" adminTOTPConfigsPath = "/api/v2/admin/totp/configs" adminTOTPGeneratePath = "/api/v2/admin/totp/generate" adminTOTPValidatePath = "/api/v2/admin/totp/validate" adminTOTPSavePath = "/api/v2/admin/totp/save" admin2FARecoveryCodesPath = "/api/v2/admin/2fa/recoverycodes" adminProfilePath = "/api/v2/admin/profile" userTOTPConfigsPath = "/api/v2/user/totp/configs" userTOTPGeneratePath = "/api/v2/user/totp/generate" userTOTPValidatePath = "/api/v2/user/totp/validate" userTOTPSavePath = "/api/v2/user/totp/save" user2FARecoveryCodesPath = "/api/v2/user/2fa/recoverycodes" userProfilePath = "/api/v2/user/profile" userSharesPath = "/api/v2/user/shares" fsEventsPath = "/api/v2/events/fs" providerEventsPath = "/api/v2/events/provider" logEventsPath = "/api/v2/events/logs" sharesPath = "/api/v2/shares" eventActionsPath = "/api/v2/eventactions" eventRulesPath = "/api/v2/eventrules" rolesPath = "/api/v2/roles" ipListsPath = "/api/v2/iplists" healthzPath = "/healthz" webBasePath = "/web" webBasePathAdmin = "/web/admin" webAdminSetupPath = "/web/admin/setup" webLoginPath = "/web/admin/login" webLogoutPath = "/web/admin/logout" webUsersPath = 
"/web/admin/users" webUserPath = "/web/admin/user" webGroupsPath = "/web/admin/groups" webGroupPath = "/web/admin/group" webFoldersPath = "/web/admin/folders" webFolderPath = "/web/admin/folder" webConnectionsPath = "/web/admin/connections" webStatusPath = "/web/admin/status" webAdminsPath = "/web/admin/managers" webAdminPath = "/web/admin/manager" webMaintenancePath = "/web/admin/maintenance" webRestorePath = "/web/admin/restore" webChangeAdminPwdPath = "/web/admin/changepwd" webAdminProfilePath = "/web/admin/profile" webTemplateUser = "/web/admin/template/user" webTemplateFolder = "/web/admin/template/folder" webDefenderPath = "/web/admin/defender" webIPListsPath = "/web/admin/ip-lists" webIPListPath = "/web/admin/ip-list" webAdminTwoFactorPath = "/web/admin/twofactor" webAdminTwoFactorRecoveryPath = "/web/admin/twofactor-recovery" webAdminMFAPath = "/web/admin/mfa" webAdminTOTPSavePath = "/web/admin/totp/save" webAdminForgotPwdPath = "/web/admin/forgot-password" webAdminResetPwdPath = "/web/admin/reset-password" webAdminEventRulesPath = "/web/admin/eventrules" webAdminEventRulePath = "/web/admin/eventrule" webAdminEventActionsPath = "/web/admin/eventactions" webAdminEventActionPath = "/web/admin/eventaction" webAdminRolesPath = "/web/admin/roles" webAdminRolePath = "/web/admin/role" webEventsPath = "/web/admin/events" webConfigsPath = "/web/admin/configs" webOAuth2TokenPath = "/web/admin/oauth2/token" webBasePathClient = "/web/client" webClientLoginPath = "/web/client/login" webClientFilesPath = "/web/client/files" webClientEditFilePath = "/web/client/editfile" webClientDirsPath = "/web/client/dirs" webClientDownloadZipPath = "/web/client/downloadzip" webChangeClientPwdPath = "/web/client/changepwd" webClientProfilePath = "/web/client/profile" webClientPingPath = "/web/client/ping" webClientTwoFactorPath = "/web/client/twofactor" webClientTwoFactorRecoveryPath = "/web/client/twofactor-recovery" webClientLogoutPath = "/web/client/logout" webClientMFAPath = 
"/web/client/mfa" webClientTOTPSavePath = "/web/client/totp/save" webClientSharesPath = "/web/client/shares" webClientSharePath = "/web/client/share" webClientPubSharesPath = "/web/client/pubshares" webClientForgotPwdPath = "/web/client/forgot-password" webClientResetPwdPath = "/web/client/reset-password" webClientViewPDFPath = "/web/client/viewpdf" webClientGetPDFPath = "/web/client/getpdf" webClientExistPath = "/web/client/exist" webClientTasksPath = "/web/client/tasks" webClientFileMovePath = "/web/client/file-actions/move" webClientFileCopyPath = "/web/client/file-actions/copy" jsonAPISuffix = "/json" httpBaseURL = "http://127.0.0.1:8081" defaultRemoteAddr = "127.0.0.1:1234" sftpServerAddr = "127.0.0.1:8022" smtpServerAddr = "127.0.0.1:3525" httpsCert = `-----BEGIN CERTIFICATE----- MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY /8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI dV4vKmHUzwK/eIx+8Ay3neE= -----END CERTIFICATE-----` httpsKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= -----END EC PRIVATE KEY-----` sftpPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- 
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW QyNTUxOQAAACB+RB4yNTZz9mHOkawwUibNdemijVV3ErMeLxWUBlCN/gAAAJA7DjpfOw46 XwAAAAtzc2gtZWQyNTUxOQAAACB+RB4yNTZz9mHOkawwUibNdemijVV3ErMeLxWUBlCN/g AAAEA0E24gi8ab/XRSvJ85TGZJMe6HVmwxSG4ExPfTMwwe2n5EHjI1NnP2Yc6RrDBSJs11 6aKNVXcSsx4vFZQGUI3+AAAACW5pY29sYUBwMQECAwQ= -----END OPENSSH PRIVATE KEY-----` sftpPkeyFingerprint = "SHA256:QVQ06XHZZbYZzqfrsZcf3Yozy2WTnqQPeLOkcJCdbP0" // password protected private key testPrivateKeyPwd = `-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABAvfwQQcs +PyMsCLTNFcKiQAAAAEAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q +8w23flfgskjIlKViEwMfjJR4mrbAAAAkHp5xgG8J1XW90M/fT59ZUQht8sZzzP17rEKlX waYKvLzDxkPK6LFIYs55W1EX1eVt/2Maq+zQ7k2SOUmhPNknsUOlPV2gytX3uIYvXF7u2F FTBIJuzZ+UQ14wFbraunliE9yye9DajVG1kz2cz2wVgXUbee+gp5NyFVvln+TcTxXwMsWD qwlk5iw/jQekxThg== -----END OPENSSH PRIVATE KEY----- ` testPubKeyPwd = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q+8w23flfgskjIlKViEwMfjJR4mrb" privateKeyPwd = "password" rsa1024PrivKey = `-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn NhAAAAAwEAAQAAAIEAxgrZ84gJyU7Qz8JbYuYh0fgTN29h4qVkqDkEE0lWZe7L4QRcQHrB vycJO5vjfitY5JTojV3nbDNHN6XGVX8QNurwXmxv0EmEbqPoNO/rTf1t7qqwMBBAfSJJ5H TXsO37vqcWSOt1Ki5yjRm232UfPo3AYXaZdOKDWKpzI12FfqkAAAIAondFqKJ3RagAAAAH c3NoLXJzYQAAAIEAxgrZ84gJyU7Qz8JbYuYh0fgTN29h4qVkqDkEE0lWZe7L4QRcQHrBvy cJO5vjfitY5JTojV3nbDNHN6XGVX8QNurwXmxv0EmEbqPoNO/rTf1t7qqwMBBAfSJJ5HTX sO37vqcWSOt1Ki5yjRm232UfPo3AYXaZdOKDWKpzI12FfqkAAAADAQABAAAAgC7V5COG+a GFJTbtJQWnnTn17D2A9upN6RcrnL4e6vLiXY8So+qP3YAicDmLrWpqP/SXDsRX/+ID4oTT jKstiJy5jTvXAozwBbFCvNDk1qifs8p/HKzel3t0172j6gLOa2h9+clJ4BYyCk6ue4f8fV yKTIc9chdJSpeINNY60CJxAAAAQQDhYpGXljD2Xy/CzqRXyoF+iMtOImLlbgQYswTXegk3 7JoCNvwqg8xP+JxGpvUGpX23VWh0nBhzcAKHGlssiYQuAAAAQQDwB6s7s1WIRZ2Jsz8f6l 7/ebpPrAMyKmWkXc7KyvR53zuMkMIdvujM5NkOWh1ON8jtNumArey2dWuGVh+pXbdVAAAA 
QQDTOAaMcyTfXMH/oSMsp+5obvT/RuewaRLHdBiCy0y1Jw0ykOcOCkswr/btDL26hImaHF SheorO+2We7dnFuUIFAAAACW5pY29sYUBwMQE= -----END OPENSSH PRIVATE KEY-----` rsa1024PubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDGCtnziAnJTtDPwlti5iHR+BM3b2HipWSoOQQTSVZl7svhBFxAesG/Jwk7m+N+K1jklOiNXedsM0c3pcZVfxA26vBebG/QSYRuo+g07+tN/W3uqrAwEEB9IknkdNew7fu+pxZI63UqLnKNGbbfZR8+jcBhdpl04oNYqnMjXYV+qQ==" redactedSecret = "[**redacted**]" osWindows = "windows" oidcMockAddr = "127.0.0.1:11111" ) var ( configDir = filepath.Join(".", "..", "..") defaultPerms = []string{dataprovider.PermAny} homeBasePath string backupsPath string testServer *httptest.Server postConnectPath string preActionPath string lastResetCode string ) type fakeConnection struct { *common.BaseConnection command string } func (c *fakeConnection) Disconnect() error { common.Connections.Remove(c.GetID()) return nil } func (c *fakeConnection) GetClientVersion() string { return "" } func (c *fakeConnection) GetCommand() string { return c.command } func (c *fakeConnection) GetLocalAddress() string { return "" } func (c *fakeConnection) GetRemoteAddress() string { return "" } type generateTOTPRequest struct { ConfigName string `json:"config_name"` } type generateTOTPResponse struct { ConfigName string `json:"config_name"` Issuer string `json:"issuer"` Secret string `json:"secret"` QRCode []byte `json:"qr_code"` } type validateTOTPRequest struct { ConfigName string `json:"config_name"` Passcode string `json:"passcode"` Secret string `json:"secret"` } type recoveryCode struct { Code string `json:"code"` Used bool `json:"used"` } func TestMain(m *testing.M) { //nolint:gocyclo homeBasePath = os.TempDir() logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") logger.InitLogger(logfilePath, 5, 1, 28, false, false, zerolog.DebugLevel) os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") 
os.Setenv("SFTPGO_DATA_PROVIDER__NAMING_RULES", "0") os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") os.Setenv("SFTPGO_HTTPD__MAX_UPLOAD_FILE_SIZE", "1048576000") err := config.LoadConfig(configDir, "") if err != nil { logger.WarnToConsole("error loading configuration: %v", err) os.Exit(1) } wdPath, err := os.Getwd() if err != nil { logger.WarnToConsole("error getting exe path: %v", err) os.Exit(1) } pluginsConfig := []plugin.Config{ { Type: "eventsearcher", Cmd: filepath.Join(wdPath, "..", "..", "tests", "eventsearcher", "eventsearcher"), AutoMTLS: true, }, } if runtime.GOOS == osWindows { pluginsConfig[0].Cmd += ".exe" } providerConf := config.GetProviderConf() logger.InfoToConsole("Starting HTTPD tests, provider: %v", providerConf.Driver) backupsPath = filepath.Join(os.TempDir(), "test_backups") providerConf.BackupsPath = backupsPath err = os.MkdirAll(backupsPath, os.ModePerm) if err != nil { logger.ErrorToConsole("error creating backups path: %v", err) os.Exit(1) } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing kms: %v", err) os.Exit(1) } err = plugin.Initialize(pluginsConfig, "debug") if err != nil { logger.ErrorToConsole("error initializing plugin: %v", err) os.Exit(1) } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing MFA: %v", err) os.Exit(1) } err = dataprovider.Initialize(providerConf, configDir, true) if err != nil { logger.WarnToConsole("error initializing data provider: %v", err) os.Exit(1) } err = common.Initialize(config.GetCommonConfig(), 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) } postConnectPath = filepath.Join(homeBasePath, "postconnect.sh") preActionPath = filepath.Join(homeBasePath, "preaction.sh") httpConfig := config.GetHTTPConfig() httpConfig.RetryMax = 1 httpConfig.Timeout = 5 
httpConfig.Initialize(configDir) //nolint:errcheck httpdConf := config.GetHTTPDConfig() httpdConf.Bindings[0].Port = 8081 httpdConf.Bindings[0].Security = httpd.SecurityConf{ Enabled: true, HTTPSProxyHeaders: []httpd.HTTPSProxyHeader{ { Key: "X-Forwarded-Proto", Value: "https", }, }, CacheControl: "private", } httpdtest.SetBaseURL(httpBaseURL) // required to test sftpfs sftpdConf := config.GetSFTPDConfig() sftpdConf.Bindings = []sftpd.Binding{ { Port: 8022, }, } hostKeyPath := filepath.Join(os.TempDir(), "id_rsa") sftpdConf.HostKeys = []string{hostKeyPath} go func() { if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } }() go func() { if err := sftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server: %v", err) os.Exit(1) } }() startSMTPServer() startOIDCMockServer() waitTCPListening(httpdConf.Bindings[0].GetAddress()) waitTCPListening(sftpdConf.Bindings[0].GetAddress()) httpd.ReloadCertificateMgr() //nolint:errcheck // now start an https server certPath := filepath.Join(os.TempDir(), "test.crt") keyPath := filepath.Join(os.TempDir(), "test.key") err = os.WriteFile(certPath, []byte(httpsCert), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing HTTPS certificate: %v", err) os.Exit(1) } err = os.WriteFile(keyPath, []byte(httpsKey), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing HTTPS private key: %v", err) os.Exit(1) } httpdConf.Bindings[0].Port = 8443 httpdConf.Bindings[0].EnableHTTPS = true httpdConf.Bindings[0].CertificateFile = certPath httpdConf.Bindings[0].CertificateKeyFile = keyPath httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{}) go func() { if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTPS server: %v", err) os.Exit(1) } }() waitTCPListening(httpdConf.Bindings[0].GetAddress()) httpd.ReloadCertificateMgr() //nolint:errcheck handler, err 
:= httpd.GetHTTPRouter(httpdConf.Bindings[0]) if err != nil { logger.ErrorToConsole("unable to get http test handler: %v", err) os.Exit(1) } testServer = httptest.NewServer(handler) defer testServer.Close() exitCode := m.Run() os.Remove(logfilePath) os.RemoveAll(backupsPath) os.Remove(certPath) os.Remove(keyPath) os.Remove(hostKeyPath) os.Remove(hostKeyPath + ".pub") os.Remove(postConnectPath) os.Remove(preActionPath) os.Exit(exitCode) } func TestInitialization(t *testing.T) { isShared := 0 err := config.LoadConfig(configDir, "") assert.NoError(t, err) invalidFile := "invalid file" passphraseFile := filepath.Join(os.TempDir(), util.GenerateUniqueID()) err = os.WriteFile(passphraseFile, []byte("my secret"), 0600) assert.NoError(t, err) defer os.Remove(passphraseFile) httpdConf := config.GetHTTPDConfig() httpdConf.SigningPassphraseFile = invalidFile err = httpdConf.Initialize(configDir, isShared) assert.ErrorIs(t, err, fs.ErrNotExist) httpdConf.SigningPassphraseFile = passphraseFile defaultTemplatesPath := httpdConf.TemplatesPath defaultStaticPath := httpdConf.StaticFilesPath httpdConf.CertificateFile = invalidFile httpdConf.CertificateKeyFile = invalidFile err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.CertificateFile = "" httpdConf.CertificateKeyFile = "" httpdConf.TemplatesPath = "." 
err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf = config.GetHTTPDConfig() httpdConf.TemplatesPath = defaultTemplatesPath httpdConf.CertificateFile = invalidFile httpdConf.CertificateKeyFile = invalidFile httpdConf.StaticFilesPath = "" httpdConf.TemplatesPath = "" err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.StaticFilesPath = defaultStaticPath httpdConf.TemplatesPath = defaultTemplatesPath httpdConf.CertificateFile = filepath.Join(os.TempDir(), "test.crt") httpdConf.CertificateKeyFile = filepath.Join(os.TempDir(), "test.key") httpdConf.CACertificates = append(httpdConf.CACertificates, invalidFile) err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.CACertificates = nil httpdConf.CARevocationLists = append(httpdConf.CARevocationLists, invalidFile) err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.CARevocationLists = nil httpdConf.SigningPassphraseFile = passphraseFile httpdConf.Bindings[0].ProxyAllowed = []string{"invalid ip/network"} err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is not a valid IP range") } assert.Equal(t, "my secret", httpdConf.SigningPassphrase) httpdConf.Bindings[0].ProxyAllowed = nil httpdConf.Bindings[0].EnableWebAdmin = false httpdConf.Bindings[0].EnableWebClient = false httpdConf.Bindings[0].Port = 8081 httpdConf.Bindings[0].EnableHTTPS = true httpdConf.Bindings[0].ClientAuthType = 1 httpdConf.TokenValidation = 1 err = httpdConf.Initialize(configDir, 0) assert.Error(t, err) httpdConf.TokenValidation = 0 err = httpdConf.Initialize(configDir, 0) assert.Error(t, err) httpdConf.Bindings[0].OIDC = httpd.OIDC{ ClientID: "123", ClientSecret: "secret", ConfigURL: "http://127.0.0.1:11111", } err = httpdConf.Initialize(configDir, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "oidc") } httpdConf.Bindings[0].OIDC.UsernameField = "preferred_username" err = 
httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "oidc") } httpdConf.Bindings[0].OIDC = httpd.OIDC{} httpdConf.Bindings[0].BaseURL = "ftp://127.0.0.1" err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "URL schema") } httpdConf.Bindings[0].BaseURL = "" httpdConf.Bindings[0].EnableWebClient = true httpdConf.Bindings[0].EnableWebAdmin = true httpdConf.Bindings[0].EnableRESTAPI = true httpdConf.Bindings[0].DisabledLoginMethods = 14 err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for WebAdmin UI") } httpdConf.Bindings[0].DisabledLoginMethods = 13 err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for WebAdmin UI") } httpdConf.Bindings[0].DisabledLoginMethods = 9 err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for WebClient UI") } httpdConf.Bindings[0].DisabledLoginMethods = 11 err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for WebClient UI") } httpdConf.Bindings[0].DisabledLoginMethods = 12 err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for WebAdmin UI") } httpdConf.Bindings[0].EnableWebAdmin = false err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for WebClient UI") } httpdConf.Bindings[0].EnableWebClient = false httpdConf.Bindings[0].DisabledLoginMethods = 240 err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "no login method available for REST API") } err = dataprovider.Close() assert.NoError(t, err) err = 
httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to load config from provider") } err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } func TestBasicUserHandling(t *testing.T) { u := getTestUser() u.Email = "user@user.com" u.Filters.AdditionalEmails = []string{"email1@user.com", "email2@user.com"} user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) _, resp, err = httpdtest.AddUser(u, http.StatusConflict) assert.NoError(t, err, string(resp)) lastPwdChange := user.LastPasswordChange assert.Greater(t, lastPwdChange, int64(0)) user.MaxSessions = 10 user.QuotaSize = 4096 user.QuotaFiles = 2 user.UploadBandwidth = 128 user.DownloadBandwidth = 64 user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) user.AdditionalInfo = "some free text" user.Filters.TLSUsername = sdk.TLSUsernameCN user.Email = "user@example.net" user.OIDCCustomFields = &map[string]any{ "field1": "value1", } user.Filters.WebClient = append(user.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientWriteDisabled) originalUser := user user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Equal(t, originalUser.ID, user.ID) assert.Equal(t, lastPwdChange, user.LastPasswordChange) user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Nil(t, user.OIDCCustomFields) assert.True(t, user.HasPassword) user.Email = "invalid@email" user.FsConfig.OSConfig = sdk.OSFsConfig{ ReadBufferSize: 1, WriteBufferSize: 2, } _, body, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "") assert.NoError(t, err) assert.Contains(t, string(body), "Validation error: email") user.Email = "" user.Filters.AdditionalEmails = []string{"invalid@email"} _, body, err = httpdtest.UpdateUser(user, 
http.StatusBadRequest, "") assert.NoError(t, err) assert.Contains(t, string(body), "Validation error: email") _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestBasicRoleHandling(t *testing.T) { r := getTestRole() role, resp, err := httpdtest.AddRole(r, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Greater(t, role.CreatedAt, int64(0)) assert.Greater(t, role.UpdatedAt, int64(0)) roleGet, _, err := httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role, roleGet) roles, _, err := httpdtest.GetRoles(0, 0, http.StatusOK) assert.NoError(t, err) if assert.GreaterOrEqual(t, len(roles), 1) { found := false for _, ro := range roles { if ro.Name == r.Name { assert.Equal(t, role, ro) found = true } } assert.True(t, found) } roles, _, err = httpdtest.GetRoles(0, int64(len(roles)), http.StatusOK) assert.NoError(t, err) assert.Len(t, roles, 0) role.Description = "updated desc" _, _, err = httpdtest.UpdateRole(role, http.StatusOK) assert.NoError(t, err) roleGet, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role.Description, roleGet.Description) _, _, err = httpdtest.GetRoleByName(role.Name+"_", http.StatusNotFound) assert.NoError(t, err) // adding the same role again should fail _, _, err = httpdtest.AddRole(r, http.StatusConflict) assert.NoError(t, err) _, err = httpdtest.RemoveRole(role, http.StatusOK) assert.NoError(t, err) } func TestRoleRelations(t *testing.T) { r := getTestRole() role, resp, err := httpdtest.AddRole(r, http.StatusCreated) assert.NoError(t, err, string(resp)) a := getTestAdmin() a.Username = altAdminUsername a.Password = altAdminPassword a.Role = role.Name _, resp, err = httpdtest.AddAdmin(a, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "a role admin cannot be a super admin") a.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers, 
dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers} a.Role = "missing admin role" _, _, err = httpdtest.AddAdmin(a, http.StatusConflict) assert.NoError(t, err) a.Role = role.Name admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) assert.NoError(t, err) admin.Role = "invalid role" _, resp, err = httpdtest.UpdateAdmin(admin, http.StatusConflict) assert.NoError(t, err, string(resp)) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role.Name, admin.Role) resp, err = httpdtest.RemoveRole(role, http.StatusOK) assert.Error(t, err, "removing a referenced role should fail") assert.Contains(t, string(resp), "is referenced") role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, role.Admins, 1) { assert.Equal(t, admin.Username, role.Admins[0]) } u1 := getTestUser() u1.Username = defaultUsername + "1" u1.Role = "missing role" _, _, err = httpdtest.AddUser(u1, http.StatusConflict) assert.NoError(t, err) u1.Role = role.Name user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) assert.NoError(t, err) assert.Equal(t, role.Name, user1.Role) user1.Role = "missing" _, _, err = httpdtest.UpdateUser(user1, http.StatusConflict, "") assert.NoError(t, err) user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role.Name, user1.Role) role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, role.Admins, 1) { assert.Equal(t, admin.Username, role.Admins[0]) } if assert.Len(t, role.Users, 1) { assert.Equal(t, user1.Username, role.Users[0]) } roles, _, err := httpdtest.GetRoles(0, 0, http.StatusOK) assert.NoError(t, err) for _, r := range roles { if r.Name == role.Name { if assert.Len(t, role.Admins, 1) { assert.Equal(t, admin.Username, role.Admins[0]) } if assert.Len(t, role.Users, 1) { assert.Equal(t, user1.Username, role.Users[0]) } } } u2 := 
getTestUser() user2, _, err := httpdtest.AddUser(u2, http.StatusCreated) assert.NoError(t, err) // the global admin can list all users users, _, err := httpdtest.GetUsers(0, 0, http.StatusOK) assert.NoError(t, err) assert.GreaterOrEqual(t, len(users), 2) _, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) assert.NoError(t, err) // the role admin can only list users with its role token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword) assert.NoError(t, err) httpdtest.SetJWTToken(token) users, _, err = httpdtest.GetUsers(0, 0, http.StatusOK) assert.NoError(t, err) assert.Len(t, users, 1) _, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusNotFound) assert.NoError(t, err) // the role admin can only update/delete users with its role _, _, err = httpdtest.UpdateUser(user1, http.StatusOK, "") assert.NoError(t, err) _, _, err = httpdtest.UpdateUser(user2, http.StatusNotFound, "") assert.NoError(t, err) _, err = httpdtest.RemoveUser(user2, http.StatusNotFound) assert.NoError(t, err) // new users created by a role admin have the same role u3 := getTestUser() u3.Username = defaultUsername + "3" _, _, err = httpdtest.AddUser(u3, http.StatusCreated) if assert.Error(t, err) { assert.Equal(t, err.Error(), "role mismatch") } user3, _, err := httpdtest.GetUserByUsername(u3.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role.Name, user3.Role) _, err = httpdtest.RemoveUser(user3, http.StatusOK) assert.NoError(t, err) httpdtest.SetJWTToken("") role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role.Admins, []string{altAdminUsername}) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveRole(role, http.StatusOK) assert.NoError(t, err) 
user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) assert.NoError(t, err) assert.Empty(t, user1.Role) _, err = httpdtest.RemoveUser(user1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user2, http.StatusOK) assert.NoError(t, err) } func TestRSAKeyInvalidSize(t *testing.T) { u := getTestUser() u.PublicKeys = append(u.PublicKeys, rsa1024PubKey) _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "invalid size") u = getTestSFTPUser() u.FsConfig.SFTPConfig.Password = nil u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(rsa1024PrivKey) _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "rsa key with size 1024 not accepted") } func TestTLSCert(t *testing.T) { u := getTestUser() u.Filters.TLSCerts = []string{"not a cert"} _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "invalid TLS certificate") u.Filters.TLSCerts = []string{httpsCert} user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) if assert.Len(t, user.Filters.TLSCerts, 1) { assert.Equal(t, httpsCert, user.Filters.TLSCerts[0]) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSortRelatedFolders(t *testing.T) { folder1 := util.GenerateUniqueID() folder2 := util.GenerateUniqueID() folder3 := util.GenerateUniqueID() f1 := vfs.BaseVirtualFolder{ Name: folder1, MappedPath: filepath.Clean(os.TempDir()), } f2 := vfs.BaseVirtualFolder{ Name: folder2, MappedPath: filepath.Clean(os.TempDir()), } f3 := vfs.BaseVirtualFolder{ Name: folder3, MappedPath: filepath.Clean(os.TempDir()), } _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) _, _, err = 
httpdtest.AddFolder(f3, http.StatusCreated) assert.NoError(t, err) u := getTestUser() u.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: f1, VirtualPath: "/" + folder1, }, { BaseVirtualFolder: f2, VirtualPath: "/" + folder2, }, { BaseVirtualFolder: f3, VirtualPath: "/" + folder3, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user.VirtualFolders, 3) { assert.Equal(t, folder1, user.VirtualFolders[0].Name) assert.Equal(t, folder2, user.VirtualFolders[1].Name) assert.Equal(t, folder3, user.VirtualFolders[2].Name) } // Update user.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: f2, VirtualPath: "/" + folder2, }, { BaseVirtualFolder: f1, VirtualPath: "/" + folder1, }, { BaseVirtualFolder: f3, VirtualPath: "/" + folder3, }, } user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user.VirtualFolders, 3) { assert.Equal(t, folder2, user.VirtualFolders[0].Name) assert.Equal(t, folder1, user.VirtualFolders[1].Name) assert.Equal(t, folder3, user.VirtualFolders[2].Name) } g := getTestGroup() g.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: f1, VirtualPath: "/" + folder1, }, { BaseVirtualFolder: f2, VirtualPath: "/" + folder2, }, { BaseVirtualFolder: f3, VirtualPath: "/" + folder3, }, } group, _, err := httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err) group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, group.VirtualFolders, 3) { assert.Equal(t, folder1, group.VirtualFolders[0].Name) assert.Equal(t, folder2, group.VirtualFolders[1].Name) assert.Equal(t, folder3, group.VirtualFolders[2].Name) } group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) 
group.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: f3, VirtualPath: "/" + folder3, }, { BaseVirtualFolder: f1, VirtualPath: "/" + folder1, }, { BaseVirtualFolder: f2, VirtualPath: "/" + folder2, }, } group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) assert.NoError(t, err) group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, group.VirtualFolders, 3) { assert.Equal(t, folder3, group.VirtualFolders[0].Name) assert.Equal(t, folder1, group.VirtualFolders[1].Name) assert.Equal(t, folder2, group.VirtualFolders[2].Name) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(f1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(f2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(f3, http.StatusOK) assert.NoError(t, err) } func TestSortRelatedGroups(t *testing.T) { name1 := util.GenerateUniqueID() name2 := util.GenerateUniqueID() name3 := util.GenerateUniqueID() g1 := getTestGroup() g1.Name = name1 g2 := getTestGroup() g2.Name = name2 g3 := getTestGroup() g3.Name = name3 group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated) assert.NoError(t, err) group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated) assert.NoError(t, err) group3, _, err := httpdtest.AddGroup(g3, http.StatusCreated) assert.NoError(t, err) u := getTestUser() u.Groups = []sdk.GroupMapping{ { Name: name1, }, } _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.Groups = []sdk.GroupMapping{ { Name: name1, Type: sdk.GroupTypePrimary, }, { Name: name2, Type: sdk.GroupTypeSecondary, }, { Name: name3, Type: sdk.GroupTypeMembership, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, 
user.Groups, 3) { assert.Equal(t, name1, user.Groups[0].Name) assert.Equal(t, name2, user.Groups[1].Name) assert.Equal(t, name3, user.Groups[2].Name) } user.Groups = []sdk.GroupMapping{ { Name: name2, Type: sdk.GroupTypeSecondary, }, { Name: name3, Type: sdk.GroupTypeMembership, }, { Name: name1, Type: sdk.GroupTypePrimary, }, } user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user.Groups, 3) { assert.Equal(t, name2, user.Groups[0].Name) assert.Equal(t, name3, user.Groups[1].Name) assert.Equal(t, name1, user.Groups[2].Name) } a := getTestAdmin() a.Username = altAdminUsername a.Groups = []dataprovider.AdminGroupMapping{ { Name: name3, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, }, }, { Name: name2, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, }, }, { Name: name1, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, }, }, } admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) assert.NoError(t, err) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, admin.Groups, 3) { assert.Equal(t, name3, admin.Groups[0].Name) assert.Equal(t, name2, admin.Groups[1].Name) assert.Equal(t, name1, admin.Groups[2].Name) } admin.Groups = []dataprovider.AdminGroupMapping{ { Name: name1, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, }, }, { Name: name3, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, }, }, { Name: name2, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, }, }, } admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) 
assert.NoError(t, err) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, admin.Groups, 3) { assert.Equal(t, name1, admin.Groups[0].Name) assert.Equal(t, name3, admin.Groups[1].Name) assert.Equal(t, name2, admin.Groups[2].Name) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group3, http.StatusOK) assert.NoError(t, err) } func TestBasicGroupHandling(t *testing.T) { g := getTestGroup() g.UserSettings.Filters.TLSCerts = []string{"invalid cert"} // ignored for groups group, _, err := httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err) assert.Greater(t, group.CreatedAt, int64(0)) assert.Greater(t, group.UpdatedAt, int64(0)) assert.Len(t, group.UserSettings.Filters.TLSCerts, 0) groupGet, _, err := httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, group, groupGet) groups, _, err := httpdtest.GetGroups(0, 0, http.StatusOK) assert.NoError(t, err) if assert.Len(t, groups, 1) { assert.Equal(t, group, groups[0]) } groups, _, err = httpdtest.GetGroups(0, 1, http.StatusOK) assert.NoError(t, err) assert.Len(t, groups, 0) _, _, err = httpdtest.GetGroupByName(group.Name+"_", http.StatusNotFound) assert.NoError(t, err) // adding the same group again should fail _, _, err = httpdtest.AddGroup(g, http.StatusConflict) assert.NoError(t, err) group.UserSettings.HomeDir = filepath.Join(os.TempDir(), "%username%") group.UserSettings.FsConfig.Provider = sdk.SFTPFilesystemProvider group.UserSettings.FsConfig.SFTPConfig.Endpoint = sftpServerAddr group.UserSettings.FsConfig.SFTPConfig.Username = defaultUsername group.UserSettings.FsConfig.SFTPConfig.Password = 
kms.NewPlainSecret(defaultPassword) group.UserSettings.Permissions = map[string][]string{ "/": {dataprovider.PermAny}, } group.UserSettings.Filters.AllowedIP = []string{"10.0.0.0/8"} group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) assert.NoError(t, err) groupGet, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, groupGet.UserSettings.Permissions, 1) assert.Len(t, groupGet.UserSettings.Filters.AllowedIP, 1) // update again and check that the password was preserved dbGroup, err := dataprovider.GroupExists(group.Name) assert.NoError(t, err) group.UserSettings.FsConfig.SFTPConfig.Password = kms.NewSecret( dbGroup.UserSettings.FsConfig.SFTPConfig.Password.GetStatus(), dbGroup.UserSettings.FsConfig.SFTPConfig.Password.GetPayload(), "", "") group.UserSettings.Permissions = nil group.UserSettings.Filters.AllowedIP = nil group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) assert.NoError(t, err) assert.Len(t, group.UserSettings.Permissions, 0) assert.Len(t, group.UserSettings.Filters.AllowedIP, 0) dbGroup, err = dataprovider.GroupExists(group.Name) assert.NoError(t, err) err = dbGroup.UserSettings.FsConfig.SFTPConfig.Password.Decrypt() assert.NoError(t, err) assert.Equal(t, defaultPassword, dbGroup.UserSettings.FsConfig.SFTPConfig.Password.GetPayload()) // check the group permissions groupGet, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, groupGet.UserSettings.Permissions, 0) group.UserSettings.HomeDir = "relative path" _, _, err = httpdtest.UpdateGroup(group, http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.UpdateGroup(group, http.StatusNotFound) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusNotFound) assert.NoError(t, err) } func TestGroupRelations(t *testing.T) { mappedPath1 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) 
folderName1 := filepath.Base(mappedPath1) mappedPath2 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) folderName2 := filepath.Base(mappedPath2) _, _, err := httpdtest.AddFolder(vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, FsConfig: vfs.Filesystem{ OSConfig: sdk.OSFsConfig{ ReadBufferSize: 3, WriteBufferSize: 5, }, }, }, http.StatusCreated) assert.NoError(t, err) g1 := getTestGroup() g1.Name += "_1" g1.VirtualFolders = append(g1.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir1", }) g2 := getTestGroup() g2.Name += "_2" g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir2", }) g3 := getTestGroup() g3.Name += "_3" g3.VirtualFolders = append(g3.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir3", }) _, _, err = httpdtest.AddGroup(g1, http.StatusCreated) assert.Error(t, err, "adding a group with a missing folder must fail") _, _, err = httpdtest.AddFolder(vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, }, http.StatusCreated) assert.NoError(t, err) group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Len(t, group1.VirtualFolders, 1) group2, resp, err := httpdtest.AddGroup(g2, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Len(t, group2.VirtualFolders, 1) group3, resp, err := httpdtest.AddGroup(g3, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Len(t, group3.VirtualFolders, 1) folder1, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder1.Groups, 3) folder2, _, err := httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder2.Groups, 0) group1.VirtualFolders = append(group1.VirtualFolders, vfs.VirtualFolder{ 
BaseVirtualFolder: folder2, VirtualPath: "/vfolder2", }) group1, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.VirtualFolders, 2) folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder2.Groups, 1) group1.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folder1.Name, MappedPath: folder1.MappedPath, }, VirtualPath: "/vpathmod", }, } group1, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.VirtualFolders, 1) folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder2.Groups, 0) group1.VirtualFolders = append(group1.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: folder2, VirtualPath: "/vfolder2mod", }) group1, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.VirtualFolders, 2) folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder2.Groups, 1) u := getTestUser() u.Groups = []sdk.GroupMapping{ { Name: group1.Name, Type: sdk.GroupTypePrimary, }, { Name: group2.Name, Type: sdk.GroupTypeSecondary, }, { Name: group3.Name, Type: sdk.GroupTypeSecondary, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) if assert.Len(t, user.Groups, 3) { for _, g := range user.Groups { if g.Name == group1.Name { assert.Equal(t, sdk.GroupTypePrimary, g.Type) } else { assert.Equal(t, sdk.GroupTypeSecondary, g.Type) } } } group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.Users, 1) group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group2.Users, 1) group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group3.Users, 1) 
user.Groups = []sdk.GroupMapping{ { Name: group3.Name, Type: sdk.GroupTypeSecondary, }, } user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Len(t, user.Groups, 1) group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.Users, 0) group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group2.Users, 0) group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group3.Users, 1) user.Groups = []sdk.GroupMapping{ { Name: group1.Name, Type: sdk.GroupTypePrimary, }, { Name: group2.Name, Type: sdk.GroupTypeSecondary, }, { Name: group3.Name, Type: sdk.GroupTypeSecondary, }, } user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Len(t, user.Groups, 3) group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.Users, 1) group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group2.Users, 1) group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group3.Users, 1) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(folder1, http.StatusOK) assert.NoError(t, err) group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group1.Users, 0) assert.Len(t, group1.VirtualFolders, 1) group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group2.Users, 0) assert.Len(t, group2.VirtualFolders, 0) group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group3.Users, 0) assert.Len(t, group3.VirtualFolders, 0) _, err = 
httpdtest.RemoveGroup(group1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group3, http.StatusOK) assert.NoError(t, err) folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder2.Groups, 0) _, err = httpdtest.RemoveFolder(folder2, http.StatusOK) assert.NoError(t, err) } func TestGroupValidation(t *testing.T) { group := getTestGroup() group.VirtualFolders = []vfs.VirtualFolder{ { BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: util.GenerateUniqueID(), MappedPath: filepath.Join(os.TempDir(), util.GenerateUniqueID()), }, }, } _, resp, err := httpdtest.AddGroup(group, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "virtual path is mandatory") group.VirtualFolders = nil group.UserSettings.FsConfig.Provider = sdk.SFTPFilesystemProvider _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "endpoint cannot be empty") group.UserSettings.FsConfig.Provider = sdk.LocalFilesystemProvider group.UserSettings.Permissions = map[string][]string{ "a": nil, } _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "cannot set permissions for non absolute path") group.UserSettings.Permissions = map[string][]string{ "/": nil, } _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "no permissions granted") group.UserSettings.Permissions = map[string][]string{ "/..": nil, } _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "cannot set permissions for invalid subdirectory") group.UserSettings.Permissions = map[string][]string{ "/": {dataprovider.PermAny}, } group.UserSettings.HomeDir = "relative" _, resp, err = httpdtest.AddGroup(group, 
http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "home_dir must be an absolute path") group.UserSettings.HomeDir = "" group.UserSettings.Filters.WebClient = []string{"invalid permission"} _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "invalid web client options") } func TestGroupSettingsOverride(t *testing.T) { mappedPath1 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) folderName1 := filepath.Base(mappedPath1) mappedPath2 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) folderName2 := filepath.Base(mappedPath2) mappedPath3 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) folderName3 := filepath.Base(mappedPath3) g1 := getTestGroup() g1.Name += "_1" g1.VirtualFolders = append(g1.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir1", }) g1.UserSettings.Permissions = map[string][]string{ "/dir1": {dataprovider.PermUpload}, "/dir2": {dataprovider.PermDownload, dataprovider.PermListItems}, } g1.UserSettings.FsConfig.OSConfig = sdk.OSFsConfig{ ReadBufferSize: 6, WriteBufferSize: 2, } g2 := getTestGroup() g2.Name += "_2" g2.UserSettings.Permissions = map[string][]string{ "/dir1": {dataprovider.PermAny}, "/dir3": {dataprovider.PermDownload, dataprovider.PermListItems, dataprovider.PermChtimes}, } g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir2", }) g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, }, VirtualPath: "/vdir3", }) g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName3, }, VirtualPath: "/vdir4", }) g2.UserSettings.Filters.AccessTime = []sdk.TimePeriod{ { DayOfWeek: int(time.Now().UTC().Weekday()), From: "00:00", To: "23:59", }, 
} f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, FsConfig: vfs.Filesystem{ OSConfig: sdk.OSFsConfig{ ReadBufferSize: 3, WriteBufferSize: 5, }, }, } _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) f3 := vfs.BaseVirtualFolder{ Name: folderName3, MappedPath: mappedPath3, FsConfig: vfs.Filesystem{ OSConfig: sdk.OSFsConfig{ ReadBufferSize: 1, WriteBufferSize: 2, }, }, } _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) assert.NoError(t, err) group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated) assert.NoError(t, err, string(resp)) group2, resp, err := httpdtest.AddGroup(g2, http.StatusCreated) assert.NoError(t, err, string(resp)) u := getTestUser() u.Groups = []sdk.GroupMapping{ { Name: group1.Name, Type: sdk.GroupTypePrimary, }, { Name: group2.Name, Type: sdk.GroupTypeSecondary, }, } r := getTestRole() role, _, err := httpdtest.AddRole(r, http.StatusCreated) assert.NoError(t, err) u.Role = role.Name user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Len(t, user.VirtualFolders, 0) assert.Len(t, user.Permissions, 1) user, err = dataprovider.CheckUserAndPass(defaultUsername, defaultPassword, "", common.ProtocolHTTP) assert.NoError(t, err) var folderNames []string if assert.Len(t, user.VirtualFolders, 4) { for _, f := range user.VirtualFolders { if !slices.Contains(folderNames, f.Name) { folderNames = append(folderNames, f.Name) } switch f.Name { case folderName1: assert.Equal(t, mappedPath1, f.MappedPath) assert.Equal(t, 3, f.FsConfig.OSConfig.ReadBufferSize) assert.Equal(t, 5, f.FsConfig.OSConfig.WriteBufferSize) assert.True(t, slices.Contains([]string{"/vdir1", "/vdir2"}, f.VirtualPath)) case folderName2: assert.Equal(t, mappedPath2, f.MappedPath) assert.Equal(t, "/vdir3", f.VirtualPath) 
assert.Equal(t, 0, f.FsConfig.OSConfig.ReadBufferSize) assert.Equal(t, 0, f.FsConfig.OSConfig.WriteBufferSize) case folderName3: assert.Equal(t, mappedPath3, f.MappedPath) assert.Equal(t, "/vdir4", f.VirtualPath) assert.Equal(t, 1, f.FsConfig.OSConfig.ReadBufferSize) assert.Equal(t, 2, f.FsConfig.OSConfig.WriteBufferSize) } } } assert.Len(t, folderNames, 3) assert.Contains(t, folderNames, folderName1) assert.Contains(t, folderNames, folderName2) assert.Contains(t, folderNames, folderName3) assert.Len(t, user.Permissions, 4) assert.Equal(t, g1.UserSettings.Permissions["/dir1"], user.Permissions["/dir1"]) assert.Equal(t, g1.UserSettings.Permissions["/dir2"], user.Permissions["/dir2"]) assert.Equal(t, g2.UserSettings.Permissions["/dir3"], user.Permissions["/dir3"]) assert.Equal(t, g1.UserSettings.FsConfig.OSConfig.ReadBufferSize, user.FsConfig.OSConfig.ReadBufferSize) assert.Equal(t, g1.UserSettings.FsConfig.OSConfig.WriteBufferSize, user.FsConfig.OSConfig.WriteBufferSize) assert.Len(t, user.Filters.AccessTime, 1) user, err = dataprovider.GetUserAfterIDPAuth(defaultUsername, "", common.ProtocolOIDC, nil) assert.NoError(t, err) assert.Len(t, user.VirtualFolders, 4) assert.Len(t, user.Filters.AccessTime, 1) user1, user2, err := dataprovider.GetUserVariants(defaultUsername, "") assert.NoError(t, err) assert.Len(t, user1.VirtualFolders, 0) assert.Len(t, user2.VirtualFolders, 4) assert.Equal(t, int64(0), user1.ExpirationDate) assert.Equal(t, int64(0), user2.ExpirationDate) assert.Len(t, user1.Filters.AccessTime, 0) assert.Len(t, user2.Filters.AccessTime, 1) group2.UserSettings.FsConfig = vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: defaultUsername, }, Password: kms.NewPlainSecret(defaultPassword), }, } group2.UserSettings.Permissions = map[string][]string{ "/": {dataprovider.PermListItems, dataprovider.PermDownload}, "/%username%": 
{dataprovider.PermListItems}, } group2.UserSettings.DownloadBandwidth = 128 group2.UserSettings.UploadBandwidth = 256 group2.UserSettings.Filters.PasswordStrength = 70 group2.UserSettings.Filters.WebClient = []string{sdk.WebClientInfoChangeDisabled, sdk.WebClientMFADisabled} _, _, err = httpdtest.UpdateGroup(group2, http.StatusOK) assert.NoError(t, err) user, err = dataprovider.CheckUserAndPass(defaultUsername, defaultPassword, "", common.ProtocolHTTP) assert.NoError(t, err) assert.Len(t, user.VirtualFolders, 4) assert.Equal(t, sdk.LocalFilesystemProvider, user.FsConfig.Provider) assert.Equal(t, int64(0), user.DownloadBandwidth) assert.Equal(t, int64(0), user.UploadBandwidth) assert.Equal(t, 0, user.Filters.PasswordStrength) assert.Equal(t, []string{dataprovider.PermAny}, user.GetPermissionsForPath("/")) assert.Equal(t, []string{dataprovider.PermListItems}, user.GetPermissionsForPath("/"+defaultUsername)) assert.Len(t, user.Filters.WebClient, 2) group1.UserSettings.FsConfig = vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: altAdminUsername, Prefix: "/dirs/%role%/%username%", }, Password: kms.NewPlainSecret(defaultPassword), }, } group1.UserSettings.MaxSessions = 2 group1.UserSettings.QuotaFiles = 1000 group1.UserSettings.UploadBandwidth = 512 group1.UserSettings.DownloadBandwidth = 1024 group1.UserSettings.TotalDataTransfer = 2048 group1.UserSettings.ExpiresIn = 15 group1.UserSettings.Filters.MaxUploadFileSize = 1024 * 1024 group1.UserSettings.Filters.StartDirectory = "/startdir/%username%" group1.UserSettings.Filters.PasswordStrength = 70 group1.UserSettings.Filters.WebClient = []string{sdk.WebClientInfoChangeDisabled} group1.UserSettings.Permissions = map[string][]string{ "/": {dataprovider.PermListItems, dataprovider.PermUpload}, "/sub/%username%": {dataprovider.PermRename}, "/%role%/%username%": {dataprovider.PermDelete}, } 
group1.UserSettings.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: "/sub2/%role%/%username%test", AllowedPatterns: []string{}, DeniedPatterns: []string{"*.jpg", "*.zip"}, }, } _, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) assert.NoError(t, err) user, err = dataprovider.CheckUserAndPass(defaultUsername, defaultPassword, "", common.ProtocolHTTP) assert.NoError(t, err) assert.Len(t, user.VirtualFolders, 4) assert.Equal(t, user.CreatedAt+int64(group1.UserSettings.ExpiresIn)*86400000, user.ExpirationDate) assert.Equal(t, group1.UserSettings.Filters.PasswordStrength, user.Filters.PasswordStrength) assert.Equal(t, sdk.SFTPFilesystemProvider, user.FsConfig.Provider) assert.Equal(t, altAdminUsername, user.FsConfig.SFTPConfig.Username) assert.Equal(t, "/dirs/"+role.Name+"/"+defaultUsername, user.FsConfig.SFTPConfig.Prefix) assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermUpload}, user.GetPermissionsForPath("/")) assert.Equal(t, []string{dataprovider.PermDelete}, user.GetPermissionsForPath(path.Join("/", role.Name, defaultUsername))) assert.Equal(t, []string{dataprovider.PermRename}, user.GetPermissionsForPath(path.Join("/sub", defaultUsername))) assert.Equal(t, group1.UserSettings.MaxSessions, user.MaxSessions) assert.Equal(t, group1.UserSettings.QuotaFiles, user.QuotaFiles) assert.Equal(t, group1.UserSettings.UploadBandwidth, user.UploadBandwidth) assert.Equal(t, group1.UserSettings.TotalDataTransfer, user.TotalDataTransfer) assert.Equal(t, group1.UserSettings.Filters.MaxUploadFileSize, user.Filters.MaxUploadFileSize) assert.Equal(t, "/startdir/"+defaultUsername, user.Filters.StartDirectory) if assert.Len(t, user.Filters.FilePatterns, 1) { assert.Equal(t, "/sub2/"+role.Name+"/"+defaultUsername+"test", user.Filters.FilePatterns[0].Path) //nolint:goconst } if assert.Len(t, user.Filters.WebClient, 2) { assert.Contains(t, user.Filters.WebClient, sdk.WebClientInfoChangeDisabled) assert.Contains(t, user.Filters.WebClient, 
sdk.WebClientMFADisabled) } // Attempt to create a user with a weak password and group1 as the primary group: this should fail u = getTestUser() u.Username = rand.Text() u.Password = defaultPassword u.Groups = []sdk.GroupMapping{ { Name: group1.Name, Type: sdk.GroupTypePrimary, }, } _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "insecure password") err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveRole(role, http.StatusOK) assert.NoError(t, err) } func TestConfigs(t *testing.T) { err := dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) configs, err := dataprovider.GetConfigs() assert.NoError(t, err) assert.Equal(t, int64(0), configs.UpdatedAt) assert.Nil(t, configs.SFTPD) assert.Nil(t, configs.SMTP) configs = dataprovider.Configs{ SFTPD: &dataprovider.SFTPDConfigs{}, SMTP: &dataprovider.SMTPConfigs{}, } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Greater(t, configs.UpdatedAt, int64(0)) configs = dataprovider.Configs{ SFTPD: &dataprovider.SFTPDConfigs{ Ciphers: []string{"unknown"}, }, SMTP: &dataprovider.SMTPConfigs{}, } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.ErrorIs(t, err, util.ErrValidation) configs = dataprovider.Configs{ SFTPD: &dataprovider.SFTPDConfigs{}, SMTP: 
&dataprovider.SMTPConfigs{ Host: "smtp.example.com", Port: -1, }, } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.ErrorIs(t, err, util.ErrValidation) configs = dataprovider.Configs{ SMTP: &dataprovider.SMTPConfigs{ Host: "mail.example.com", Port: 587, User: "test@example.com", AuthType: 3, Encryption: 2, OAuth2: dataprovider.SMTPOAuth2{ Provider: 1, Tenant: "", ClientID: "", }, }, } err = dataprovider.UpdateConfigs(&configs, "", "", "") if assert.ErrorIs(t, err, util.ErrValidation) { assert.Contains(t, err.Error(), "smtp oauth2: client id is required") } configs.SMTP.OAuth2 = dataprovider.SMTPOAuth2{ Provider: 1, ClientID: "client id", ClientSecret: kms.NewPlainSecret("client secret"), RefreshToken: kms.NewPlainSecret("refresh token"), } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Equal(t, 3, configs.SMTP.AuthType) assert.Equal(t, 1, configs.SMTP.OAuth2.Provider) err = dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) } func TestBasicIPListEntriesHandling(t *testing.T) { entry := dataprovider.IPListEntry{ IPOrNet: "::ffff:12.34.56.78", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Description: "test desc", } _, _, err := httpdtest.GetIPListEntry(entry.IPOrNet, -1, http.StatusBadRequest) assert.NoError(t, err) _, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusNotFound) assert.NoError(t, err) _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) assert.Error(t, err) // IPv4 address in IPv6 will be converted to standard IPv4 entry1, _, err := httpdtest.GetIPListEntry("12.34.56.78/32", dataprovider.IPListTypeAllowList, http.StatusOK) assert.NoError(t, err) entry = dataprovider.IPListEntry{ IPOrNet: "192.168.0.0/24", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, } entry2, _, err := httpdtest.AddIPListEntry(entry, http.StatusCreated) assert.NoError(t, err) // 
adding the same entry again should fail _, _, err = httpdtest.AddIPListEntry(entry, http.StatusConflict) assert.NoError(t, err) // adding an entry with an invalid IP should fail entry.IPOrNet = "not valid" _, _, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) // adding an entry with an incompatible mode should fail entry.IPOrNet = entry2.IPOrNet entry.Mode = -1 _, _, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) entry.Type = -1 _, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) entry = dataprovider.IPListEntry{ IPOrNet: "2001:4860:4860::8888/120", Type: dataprovider.IPListTypeRateLimiterSafeList, Mode: dataprovider.ListModeDeny, } _, _, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) entry.Mode = dataprovider.ListModeAllow _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) assert.NoError(t, err) entry.Protocols = 3 entry3, _, err := httpdtest.UpdateIPListEntry(entry, http.StatusOK) assert.NoError(t, err) entry.Mode = dataprovider.ListModeDeny _, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) for _, tt := range []dataprovider.IPListType{dataprovider.IPListTypeAllowList, dataprovider.IPListTypeDefender, dataprovider.IPListTypeRateLimiterSafeList} { entries, _, err := httpdtest.GetIPListEntries(tt, "", "", dataprovider.OrderASC, 0, http.StatusOK) assert.NoError(t, err) if assert.Len(t, entries, 1) { switch tt { case dataprovider.IPListTypeAllowList: assert.Equal(t, entry1, entries[0]) case dataprovider.IPListTypeDefender: assert.Equal(t, entry2, entries[0]) case dataprovider.IPListTypeRateLimiterSafeList: assert.Equal(t, entry3, entries[0]) } } } _, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "", "invalid order", 0, http.StatusBadRequest) assert.NoError(t, err) _, _, err = httpdtest.GetIPListEntries(-1, "", "", dataprovider.OrderASC, 
0, http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.RemoveIPListEntry(entry1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveIPListEntry(entry1, http.StatusNotFound) assert.NoError(t, err) _, err = httpdtest.RemoveIPListEntry(entry2, http.StatusOK) assert.NoError(t, err) entry2.Type = -1 _, err = httpdtest.RemoveIPListEntry(entry2, http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.RemoveIPListEntry(entry3, http.StatusOK) assert.NoError(t, err) } func TestSearchIPListEntries(t *testing.T) { entries := []dataprovider.IPListEntry{ { IPOrNet: "192.168.0.0/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 0, }, { IPOrNet: "192.168.0.1/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 0, }, { IPOrNet: "192.168.0.2/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 5, }, { IPOrNet: "192.168.0.3/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 8, }, { IPOrNet: "10.8.0.0/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 3, }, { IPOrNet: "10.8.1.0/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 8, }, { IPOrNet: "10.8.2.0/24", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 1, }, } for _, e := range entries { _, _, err := httpdtest.AddIPListEntry(e, http.StatusCreated) assert.NoError(t, err) } results, _, err := httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "", dataprovider.OrderASC, 20, http.StatusOK) assert.NoError(t, err) if assert.Equal(t, len(entries), len(results)) { assert.Equal(t, "10.8.0.0/24", results[0].IPOrNet) } results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "", dataprovider.OrderDESC, 20, http.StatusOK) assert.NoError(t, err) if assert.Equal(t, len(entries), len(results)) { 
assert.Equal(t, "192.168.0.3/24", results[0].IPOrNet) } results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "192.168.0.1/24", dataprovider.OrderASC, 1, http.StatusOK) assert.NoError(t, err) if assert.Equal(t, 1, len(results), results) { assert.Equal(t, "192.168.0.2/24", results[0].IPOrNet) } results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "10.8.2.0/24", dataprovider.OrderDESC, 1, http.StatusOK) assert.NoError(t, err) if assert.Equal(t, 1, len(results), results) { assert.Equal(t, "10.8.1.0/24", results[0].IPOrNet) } results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "10.", "", dataprovider.OrderASC, 20, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 3, len(results)) results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "192", "", dataprovider.OrderASC, 20, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 4, len(results)) results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "1", "", dataprovider.OrderASC, 20, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 7, len(results)) results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "108", "", dataprovider.OrderASC, 20, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, len(results)) for _, e := range entries { _, err := httpdtest.RemoveIPListEntry(e, http.StatusOK) assert.NoError(t, err) } } func TestIPListEntriesValidation(t *testing.T) { entry := dataprovider.IPListEntry{ IPOrNet: "::ffff:34.56.78.90/120", Type: -1, Mode: dataprovider.ListModeDeny, } _, resp, err := httpdtest.AddIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "invalid list type") entry.Type = dataprovider.IPListTypeRateLimiterSafeList _, resp, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "invalid list mode") entry.Type = 
dataprovider.IPListTypeDefender _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) assert.Error(t, err) entry.IPOrNet = "34.56.78.0/24" _, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK) assert.NoError(t, err) } func TestBasicActionRulesHandling(t *testing.T) { actionName := "test_action" a := dataprovider.BaseEventAction{ Name: actionName, Description: "test description", Type: dataprovider.ActionTypeBackup, Options: dataprovider.BaseEventActionOptions{}, } action, _, err := httpdtest.AddEventAction(a, http.StatusCreated) assert.NoError(t, err) // adding the same action should fail _, _, err = httpdtest.AddEventAction(a, http.StatusConflict) assert.NoError(t, err) actionGet, _, err := httpdtest.GetEventActionByName(actionName, http.StatusOK) assert.NoError(t, err) actions, _, err := httpdtest.GetEventActions(0, 0, http.StatusOK) assert.NoError(t, err) assert.Greater(t, len(actions), 0) found := false for _, ac := range actions { if ac.Name == actionName { assert.Equal(t, actionGet, ac) found = true } } assert.True(t, found) a.Description = "new description" a.Type = dataprovider.ActionTypeDataRetentionCheck a.Options = dataprovider.BaseEventActionOptions{ RetentionConfig: dataprovider.EventActionDataRetentionConfig{ Folders: []dataprovider.FolderRetention{ { Path: "/", Retention: 144, }, { Path: "/p1", Retention: 0, }, { Path: "/p2", Retention: 12, }, }, }, } _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) assert.NoError(t, err) a.Type = dataprovider.ActionTypeCommand a.Options = dataprovider.BaseEventActionOptions{ CmdConfig: dataprovider.EventActionCommandConfig{ Cmd: filepath.Join(os.TempDir(), "test_cmd"), Timeout: 20, EnvVars: []dataprovider.KeyValue{ { Key: "NAME", Value: "VALUE", }, }, }, } dataprovider.EnabledActionCommands = []string{a.Options.CmdConfig.Cmd} defer func() { dataprovider.EnabledActionCommands = nil }() _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) assert.NoError(t, err) // invalid type a.Type = 
1000 _, _, err = httpdtest.UpdateEventAction(a, http.StatusBadRequest) assert.NoError(t, err) a.Type = dataprovider.ActionTypeEmail a.Options = dataprovider.BaseEventActionOptions{ EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"email@example.com"}, Bcc: []string{"bcc@example.com"}, Subject: "Event: {{.Event}}", Body: "test mail body", Attachments: []string{"/{{.VirtualPath}}"}, }, } _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) assert.NoError(t, err) a.Type = dataprovider.ActionTypeUserInactivityCheck a.Options = dataprovider.BaseEventActionOptions{ UserInactivityConfig: dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 20, }, } _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) assert.NoError(t, err) a.Type = dataprovider.ActionTypeHTTP a.Options = dataprovider.BaseEventActionOptions{ HTTPConfig: dataprovider.EventActionHTTPConfig{ Endpoint: "https://localhost:1234", Username: defaultUsername, Password: kms.NewPlainSecret(defaultPassword), Headers: []dataprovider.KeyValue{ { Key: "Content-Type", Value: "application/json", }, }, Timeout: 10, SkipTLSVerify: true, Method: http.MethodPost, QueryParameters: []dataprovider.KeyValue{ { Key: "a", Value: "b", }, }, Body: `{"event":"{{.Event}}","name":"{{.Name}}"}`, }, } action, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusSecretBox, action.Options.HTTPConfig.Password.GetStatus()) assert.NotEmpty(t, action.Options.HTTPConfig.Password.GetPayload()) assert.Empty(t, action.Options.HTTPConfig.Password.GetKey()) assert.Empty(t, action.Options.HTTPConfig.Password.GetAdditionalData()) // update again and check that the password was preserved dbAction, err := dataprovider.EventActionExists(actionName) assert.NoError(t, err) action.Options.HTTPConfig.Password = kms.NewSecret( dbAction.Options.HTTPConfig.Password.GetStatus(), dbAction.Options.HTTPConfig.Password.GetPayload(), "", "") action, _, 
err = httpdtest.UpdateEventAction(action, http.StatusOK) assert.NoError(t, err) dbAction, err = dataprovider.EventActionExists(actionName) assert.NoError(t, err) err = dbAction.Options.HTTPConfig.Password.Decrypt() assert.NoError(t, err) assert.Equal(t, defaultPassword, dbAction.Options.HTTPConfig.Password.GetPayload()) r := dataprovider.EventRule{ Name: "test_rule_name", Status: 1, Description: "", Trigger: dataprovider.EventTriggerFsEvent, Conditions: dataprovider.EventConditions{ FsEvents: []string{"upload"}, Options: dataprovider.ConditionOptions{ EventStatuses: []int{2, 3}, MinFileSize: 1024 * 1024, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: actionName, }, Order: 1, Options: dataprovider.EventActionOptions{ IsFailureAction: false, StopOnFailure: true, ExecuteSync: true, }, }, }, } rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) assert.NoError(t, err) // adding the same rule should fail _, _, err = httpdtest.AddEventRule(r, http.StatusConflict) assert.NoError(t, err) rule.Description = "new rule desc" rule.Trigger = 1000 _, _, err = httpdtest.UpdateEventRule(rule, http.StatusBadRequest) assert.NoError(t, err) rule.Trigger = dataprovider.EventTriggerFsEvent rule, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK) assert.NoError(t, err) ruleGet, _, err := httpdtest.GetEventRuleByName(rule.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, ruleGet.Actions, 1) { if assert.NotNil(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password) { assert.Equal(t, sdkkms.SecretStatusSecretBox, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetStatus()) assert.NotEmpty(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetPayload()) assert.Empty(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetKey()) assert.Empty(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetAdditionalData()) } } rules, _, err := 
httpdtest.GetEventRules(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, len(rules), 0)
	found = false
	for _, ru := range rules {
		if ru.Name == rule.Name {
			assert.Equal(t, ruleGet, ru)
			found = true
		}
	}
	assert.True(t, found)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.UpdateEventRule(rule, http.StatusNotFound)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusNotFound)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusNotFound)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.UpdateEventAction(action, http.StatusNotFound)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetEventActionByName(actionName, http.StatusNotFound)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusNotFound)
	assert.NoError(t, err)
}

// TestActionRuleRelations verifies the relation between event actions and the
// rules that reference them:
//   - an action's Rules field lists the names of the rules using it;
//   - an action referenced by a rule cannot be removed (400 Bad Request);
//   - the references are updated when a rule's actions change and cleared
//     when the rules are removed.
func TestActionRuleRelations(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name:        "action1",
		Description: "test description",
		Type:        dataprovider.ActionTypeBackup,
		Options:     dataprovider.BaseEventActionOptions{},
	}
	a2 := dataprovider.BaseEventAction{
		Name:    "action2",
		Type:    dataprovider.ActionTypeTransferQuotaReset,
		Options: dataprovider.BaseEventActionOptions{},
	}
	a3 := dataprovider.BaseEventAction{
		Name: "action3",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients:  []string{"test@example.net"},
				ContentType: 1,
				Subject:     "test subject",
				Body:        "test body",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err)
	action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated)
	assert.NoError(t, err)

	r1 := dataprovider.EventRule{
		Name:        "rule1",
		Description: "",
		Trigger:     dataprovider.EventTriggerProviderEvent,
		Conditions: dataprovider.EventConditions{
			ProviderEvents: []string{"add"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action3.Name,
				},
				Order: 2,
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	// The actions are declared out of order; the assertions below expect the
	// returned rule to list them by Order, not by declaration order.
	if assert.Len(t, rule1.Actions, 2) {
		assert.Equal(t, action1.Name, rule1.Actions[0].Name)
		assert.Equal(t, 1, rule1.Actions[0].Order)
		assert.Equal(t, action3.Name, rule1.Actions[1].Name)
		assert.Equal(t, 2, rule1.Actions[1].Order)
		assert.True(t, rule1.Actions[1].Options.IsFailureAction)
	}

	r2 := dataprovider.EventRule{
		Name:        "rule2",
		Description: "",
		Trigger:     dataprovider.EventTriggerSchedule,
		Conditions: dataprovider.EventConditions{
			Schedules: []dataprovider.Schedule{
				{
					Hours:      "1",
					DayOfWeek:  "*",
					DayOfMonth: "*",
					Month:      "*",
				},
			},
			Options: dataprovider.ConditionOptions{
				RoleNames: []dataprovider.ConditionPattern{
					{
						Pattern:      "g*",
						InverseMatch: true,
					},
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action3.Name,
				},
				Order: 2,
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 1,
			},
		},
	}
	rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated)
	assert.NoError(t, err)
	// Guard on rule2's own actions: the body indexes rule2.Actions, so checking
	// another rule's slice would let a short result panic instead of failing.
	if assert.Len(t, rule2.Actions, 2) {
		assert.Equal(t, action2.Name, rule2.Actions[0].Name)
		assert.Equal(t, 1, rule2.Actions[0].Order)
		assert.Equal(t, action3.Name, rule2.Actions[1].Name)
		assert.Equal(t, 2, rule2.Actions[1].Order)
		assert.True(t, rule2.Actions[1].Options.IsFailureAction)
	}
	// check the references
	action1, _, err = httpdtest.GetEventActionByName(action1.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action1.Rules, 1)
	assert.True(t, slices.Contains(action1.Rules, rule1.Name))
	action2, _, err = httpdtest.GetEventActionByName(action2.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action2.Rules, 1)
	assert.True(t, slices.Contains(action2.Rules, rule2.Name))
	action3, _, err = httpdtest.GetEventActionByName(action3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action3.Rules, 2)
	assert.True(t, slices.Contains(action3.Rules, rule1.Name))
	assert.True(t, slices.Contains(action3.Rules, rule2.Name))
	// referenced actions cannot be removed
	_, err = httpdtest.RemoveEventAction(action1, http.StatusBadRequest)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusBadRequest)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action3, http.StatusBadRequest)
	assert.NoError(t, err)
	// remove action3 from rule2
	r2.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: action2.Name,
			},
			Order: 10,
		},
	}
	// The update is sent from r2; rule2 is replaced by the server's response.
	rule2, _, err = httpdtest.UpdateEventRule(r2, http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, rule2.Actions, 1) {
		assert.Equal(t, action2.Name, rule2.Actions[0].Name)
		assert.Equal(t, 10, rule2.Actions[0].Order)
	}
	// check the updated relation
	action3, _, err = httpdtest.GetEventActionByName(action3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action3.Rules, 1)
	assert.True(t, slices.Contains(action3.Rules, rule1.Name))

	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule2, http.StatusOK)
	assert.NoError(t, err)
	// no relations anymore
	action1, _, err = httpdtest.GetEventActionByName(action1.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action1.Rules, 0)
	action2, _, err = httpdtest.GetEventActionByName(action2.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action2.Rules, 0)
	action3, _, err = httpdtest.GetEventActionByName(action3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, action3.Rules, 0)

	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action3, http.StatusOK)
	assert.NoError(t, err)
}

// TestOnDemandEventRules verifies that an on-demand rule created with Status 1
// can be run (202 Accepted). After it is updated with Status 0, running it is
// rejected with 400 and an "is inactive" message. Once removed, running it
// returns 404.
func TestOnDemandEventRules(t *testing.T) {
	ruleName := "test_on_demand_rule"
	a := dataprovider.BaseEventAction{
		Name:    "a",
		Type:    dataprovider.ActionTypeBackup,
		Options: dataprovider.BaseEventActionOptions{},
	}
	action, _, err := httpdtest.AddEventAction(a, http.StatusCreated)
	assert.NoError(t, err)
	r := dataprovider.EventRule{
		Name:    ruleName,
		Status:  1,
		Trigger: dataprovider.EventTriggerOnDemand,
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: a.Name,
				},
			},
		},
	}
	rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.RunOnDemandRule(ruleName, http.StatusAccepted)
	assert.NoError(t, err)
	rule.Status = 0
	_, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	resp, err := httpdtest.RunOnDemandRule(ruleName, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "is inactive")
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RunOnDemandRule(ruleName, http.StatusNotFound)
	assert.NoError(t, err)
}

// TestIDPLoginEventRule adds an IDP account check action and an IDP login rule
// that uses it synchronously, updates the rule with Status 0, and then removes
// both.
func TestIDPLoginEventRule(t *testing.T) {
	ruleName := "test_IDP_login_rule"
	a := dataprovider.BaseEventAction{
		Name: "a",
		Type: dataprovider.ActionTypeIDPAccountCheck,
		Options: dataprovider.BaseEventActionOptions{
			IDPConfig: dataprovider.EventActionIDPAccountCheck{
				Mode:          1,
				TemplateUser:  `{"username": "user"}`,
				TemplateAdmin: `{"username": "admin"}`,
			},
		},
	}
	action, resp, err := httpdtest.AddEventAction(a, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	r := dataprovider.EventRule{
		Name:    ruleName,
		Status:  1,
		Trigger: dataprovider.EventTriggerIDPLogin,
		Conditions: dataprovider.EventConditions{
IDPLoginEvent: 1,
			Options: dataprovider.ConditionOptions{
				Names: []dataprovider.ConditionPattern{
					{
						Pattern: "username",
					},
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: a.Name,
				},
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated)
	assert.NoError(t, err)
	rule.Status = 0
	_, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
}

// TestEventActionValidation checks server-side validation of event actions.
// A single action value is mutated step by step. After each change, adding it
// must be rejected with 400 Bad Request and a response containing the expected
// message. Every step builds on the fields set by the previous ones, so the
// statement order matters.
func TestEventActionValidation(t *testing.T) {
	action := dataprovider.BaseEventAction{
		Name: "",
	}
	_, resp, err := httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "name is mandatory")
	action = dataprovider.BaseEventAction{
		Name: "n",
		Type: -1,
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid action type")
	// HTTP action
	action.Type = dataprovider.ActionTypeHTTP
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "HTTP endpoint is required")
	action.Options.HTTPConfig.Endpoint = "abc"
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid HTTP endpoint schema")
	action.Options.HTTPConfig.Endpoint = "http://localhost"
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid HTTP timeout")
	action.Options.HTTPConfig.Timeout = 20
	action.Options.HTTPConfig.Headers = []dataprovider.KeyValue{
		{
			Key:   "",
			Value: "",
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid HTTP headers")
	action.Options.HTTPConfig.Headers = []dataprovider.KeyValue{
		{
			Key:   "Content-Type",
			Value: "application/json",
		},
	}
	// a secret submitted in redacted status must be refused
	action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusRedacted, "payload", "", "")
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "cannot save HTTP configuration with a redacted secret")
	action.Options.HTTPConfig.Password = nil
	action.Options.HTTPConfig.Method = http.MethodTrace
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "unsupported HTTP method")
	action.Options.HTTPConfig.Method = http.MethodGet
	action.Options.HTTPConfig.QueryParameters = []dataprovider.KeyValue{
		{
			Key:   "a",
			Value: "",
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid HTTP query parameters")
	action.Options.HTTPConfig.QueryParameters = nil
	// multipart parts
	action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{
		{
			Name: "",
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "HTTP part name is required")
	action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{
		{
			Name: "p1",
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "HTTP part body is required if no file path is provided")
	action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{
		{
			Name:     "p1",
			Filepath: "p",
		},
	}
	action.Options.HTTPConfig.Body = "b"
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "multipart requests require no body")
	action.Options.HTTPConfig.Body = ""
	action.Options.HTTPConfig.Headers = []dataprovider.KeyValue{
		{
			Key:   "Content-Type",
			Value: "application/json",
		},
	}
	_, resp, err = httpdtest.AddEventAction(action,
http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "content type is automatically set for multipart requests")
	// command action
	action.Type = dataprovider.ActionTypeCommand
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "command is required")
	// The relative command is allow-listed; the expected failure is the
	// absolute-path check.
	action.Options.CmdConfig.Cmd = "relative"
	dataprovider.EnabledActionCommands = []string{action.Options.CmdConfig.Cmd}
	// restore the package-level allow list when the test returns
	defer func() {
		dataprovider.EnabledActionCommands = nil
	}()
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid command, it must be an absolute path")
	// absolute path, but the allow list still only contains "relative"
	action.Options.CmdConfig.Cmd = filepath.Join(os.TempDir(), "cmd")
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "is not allowed")
	dataprovider.EnabledActionCommands = []string{action.Options.CmdConfig.Cmd}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid command action timeout")
	action.Options.CmdConfig.Timeout = 30
	action.Options.CmdConfig.EnvVars = []dataprovider.KeyValue{
		{
			Key: "k",
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid command env vars")
	action.Options.CmdConfig.EnvVars = nil
	action.Options.CmdConfig.Args = []string{"arg1", ""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid command args")
	action.Options.CmdConfig.Args = nil
	// restrict commands: the allow list no longer contains the configured command
	if runtime.GOOS == osWindows {
		dataprovider.EnabledActionCommands = []string{"C:\\cmd.exe"}
	} else {
		dataprovider.EnabledActionCommands = []string{"/bin/sh"}
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "is not allowed")
	dataprovider.EnabledActionCommands = nil
	// email action
	action.Type = dataprovider.ActionTypeEmail
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least one email recipient is required")
	action.Options.EmailConfig.Recipients = []string{""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid email recipients")
	action.Options.EmailConfig.Recipients = []string{"a@a.com"}
	action.Options.EmailConfig.Bcc = []string{""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid email bcc")
	action.Options.EmailConfig.Bcc = nil
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "email subject is required")
	action.Options.EmailConfig.Subject = "subject"
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "email body is required")
	// data retention action
	action.Type = dataprovider.ActionTypeDataRetentionCheck
	action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{
		Folders: nil,
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "nothing to delete")
	// a zero retention is expected to be reported as "nothing to delete" as well
	action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/",
				Retention: 0,
			},
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "nothing to delete")
	// "../path" and "/path" are expected to be rejected as the same folder
	action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "../path",
				Retention: 1,
			},
			{
				Path:      "/path",
				Retention: 10,
			},
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "duplicated folder path")
	action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "p",
				Retention: -1,
			},
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid folder retention")
	// filesystem action
	action.Type = dataprovider.ActionTypeFilesystem
	action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{
		Type: dataprovider.FilesystemActionRename,
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "no path to rename specified")
	action.Options.FsConfig.Renames = []dataprovider.RenameConfig{
		{
			KeyValue: dataprovider.KeyValue{
				Key:   "",
				Value: "/adir",
			},
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid paths to rename")
	// "adir" and "/adir" are expected to be rejected as equal paths
	action.Options.FsConfig.Renames = []dataprovider.RenameConfig{
		{
			KeyValue: dataprovider.KeyValue{
				Key:   "adir",
				Value: "/adir",
			},
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "rename source and target cannot be equal")
	action.Options.FsConfig.Renames = []dataprovider.RenameConfig{
		{
			KeyValue: dataprovider.KeyValue{
				Key:   "/",
				Value: "/dir",
			},
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "renaming the root directory is not allowed")
	action.Options.FsConfig.Type = dataprovider.FilesystemActionMkdirs
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "no directory to create specified")
	action.Options.FsConfig.MkDirs = []string{"dir1", ""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid directory to create")
	action.Options.FsConfig.Type = dataprovider.FilesystemActionDelete
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "no path to delete specified")
	action.Options.FsConfig.Deletes = []string{"item1", ""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid path to delete")
	action.Options.FsConfig.Type = dataprovider.FilesystemActionExist
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "no path to check for existence specified")
	action.Options.FsConfig.Exist = []string{"item1", ""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid path to check for existence")
	action.Options.FsConfig.Type = dataprovider.FilesystemActionCompress
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "archive name is mandatory")
	action.Options.FsConfig.Compress.Name = "archive.zip"
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "no path to compress specified")
	action.Options.FsConfig.Compress.Paths = []string{"item1", ""}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid path to compress")
	// password expiration check action
	action.Type = dataprovider.ActionTypePasswordExpirationCheck
	action.Options.PwdExpirationConfig.Threshold = 0
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "threshold must be greater than 0")
	// IDP account check action
	action.Type = dataprovider.ActionTypeIDPAccountCheck
	_, resp, err = httpdtest.AddEventAction(action,
http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least a template must be set")
	action.Options.IDPConfig.TemplateAdmin = "{}"
	action.Options.IDPConfig.Mode = 100
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid account check mode")
	// user inactivity check action: the options are replaced as a whole here
	action.Type = dataprovider.ActionTypeUserInactivityCheck
	action.Options = dataprovider.BaseEventActionOptions{
		UserInactivityConfig: dataprovider.EventActionUserInactivity{
			DisableThreshold: 0,
			DeleteThreshold:  0,
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least a threshold must be defined")
	action.Options = dataprovider.BaseEventActionOptions{
		UserInactivityConfig: dataprovider.EventActionUserInactivity{
			DisableThreshold: 10,
			DeleteThreshold:  10,
		},
	}
	_, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "must be greater than deactivation threshold")
}

// TestEventRuleValidation checks server-side validation of event rules. A
// single rule value is mutated step by step, and each invalid state must be
// rejected with the expected message. The statement order matters.
func TestEventRuleValidation(t *testing.T) {
	rule := dataprovider.EventRule{
		Name: "",
	}
	_, resp, err := httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "name is mandatory")
	rule.Name = "r"
	rule.Status = 100
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid event rule status")
	rule.Status = 1
	rule.Trigger = 1000
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid event rule trigger")
	// filesystem event trigger
	rule.Trigger = dataprovider.EventTriggerFsEvent
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least one filesystem event is required")
	rule.Conditions.FsEvents = []string{""}
	_, resp, err = httpdtest.AddEventRule(rule,
http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "unsupported fs event")
	rule.Conditions.FsEvents = []string{"upload"}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least one action is required")
	// the second action has no name
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action1",
			},
			Order: 1,
		},
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "",
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "name not specified")
	// the same action is referenced twice
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action",
			},
			Order: 1,
		},
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action",
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "duplicated action")
	// two distinct actions share the same order
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action11",
			},
			Order: 1,
		},
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action12",
			},
			Order: 1,
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "duplicated order")
	// only failure actions
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action111",
			},
			Order: 1,
			Options: dataprovider.EventActionOptions{
				IsFailureAction: true,
			},
		},
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action112",
			},
			Order: 2,
			Options: dataprovider.EventActionOptions{
				IsFailureAction: true,
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least a non-failure action is required")
	// sync execution combined with a non-upload, non pre-* event
	rule.Conditions.FsEvents = []string{"upload", "download"}
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction:
dataprovider.BaseEventAction{
				Name: "action111",
			},
			Order: 1,
			Options: dataprovider.EventActionOptions{
				ExecuteSync: true,
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "sync execution is only supported for upload and pre-* events")
	// a pre-* event without any sync action
	rule.Conditions.FsEvents = []string{"pre-upload", "download"}
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action",
			},
			Order: 1,
			Options: dataprovider.EventActionOptions{
				ExecuteSync: false,
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "requires at least a sync action")
	// a sync action with "download" still in the event list
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action",
			},
			Order: 1,
			Options: dataprovider.EventActionOptions{
				ExecuteSync: true,
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "sync execution is only supported for upload and pre-* events")
	// NOTE(review): presumably 8 is the out-of-range status here — confirm
	rule.Conditions.FsEvents = []string{"download"}
	rule.Conditions.Options.EventStatuses = []int{3, 2, 8}
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action",
			},
			Order: 1,
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid event_status")
	// provider event trigger
	rule.Trigger = dataprovider.EventTriggerProviderEvent
	rule.Actions = []dataprovider.EventAction{
		{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: "action1234",
			},
			Order: 1,
			Options: dataprovider.EventActionOptions{
				IsFailureAction: false,
			},
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least one provider event is required")
	rule.Conditions.ProviderEvents = []string{""}
	_, resp, err = httpdtest.AddEventRule(rule,
http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "unsupported provider event")
	rule.Conditions.ProviderEvents = []string{"add"}
	rule.Conditions.Options.RoleNames = []dataprovider.ConditionPattern{
		{
			Pattern: "",
		},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "empty condition pattern not allowed")
	rule.Conditions.Options.RoleNames = nil
	// schedule trigger
	rule.Trigger = dataprovider.EventTriggerSchedule
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "at least one schedule is required")
	rule.Conditions.Schedules = []dataprovider.Schedule{
		{},
	}
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid schedule")
	rule.Conditions.Schedules = []dataprovider.Schedule{
		{
			Hours:      "3",
			DayOfWeek:  "*",
			DayOfMonth: "*",
			Month:      "*",
		},
	}
	// NOTE(review): no action named "action1234" is created in this test; the
	// expected 500 presumably stems from that missing reference — confirm.
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusInternalServerError)
	assert.NoError(t, err, string(resp))
	// IDP login trigger
	rule.Trigger = dataprovider.EventTriggerIDPLogin
	rule.Conditions.IDPLoginEvent = 100
	_, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid Identity Provider login event")
}

// TestUserBandwidthLimits checks per-source bandwidth limits for users.
// Invalid or missing sources are rejected on creation. The test then checks
// which upload/download values GetBandwidthForIP and NewBaseConnection
// resolve for matching, non-matching and unparsable client addresses, both
// before and after the limits are updated.
func TestUserBandwidthLimits(t *testing.T) {
	u := getTestUser()
	u.UploadBandwidth = 128
	u.DownloadBandwidth = 96
	u.Filters.BandwidthLimits = []sdk.BandwidthLimit{
		{
			Sources: []string{"1"},
		},
	}
	_, resp, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "Validation error: could not parse bandwidth limit source")
	u.Filters.BandwidthLimits = []sdk.BandwidthLimit{
		{
			Sources: nil,
		},
	}
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "Validation error: no bandwidth limit source specified")
u.Filters.BandwidthLimits = []sdk.BandwidthLimit{ { Sources: []string{"127.0.0.0/8", "::1/128"}, UploadBandwidth: 256, }, { Sources: []string{"10.0.0.0/8"}, UploadBandwidth: 512, DownloadBandwidth: 256, }, } user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Len(t, user.Filters.BandwidthLimits, 2) assert.Equal(t, u.Filters.BandwidthLimits, user.Filters.BandwidthLimits) connID := xid.New().String() localAddr := "127.0.0.1" up, down := user.GetBandwidthForIP("127.0.1.1", connID) assert.Equal(t, int64(256), up) assert.Equal(t, int64(0), down) conn := common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "127.0.1.1", user) assert.Equal(t, int64(256), conn.User.UploadBandwidth) assert.Equal(t, int64(0), conn.User.DownloadBandwidth) up, down = user.GetBandwidthForIP("10.1.2.3", connID) assert.Equal(t, int64(512), up) assert.Equal(t, int64(256), down) conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "10.2.1.4:1234", user) assert.Equal(t, int64(512), conn.User.UploadBandwidth) assert.Equal(t, int64(256), conn.User.DownloadBandwidth) up, down = user.GetBandwidthForIP("192.168.1.2", connID) assert.Equal(t, int64(128), up) assert.Equal(t, int64(96), down) conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "172.16.0.1", user) assert.Equal(t, int64(128), conn.User.UploadBandwidth) assert.Equal(t, int64(96), conn.User.DownloadBandwidth) up, down = user.GetBandwidthForIP("invalid", connID) assert.Equal(t, int64(128), up) assert.Equal(t, int64(96), down) conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "172.16.0", user) assert.Equal(t, int64(128), conn.User.UploadBandwidth) assert.Equal(t, int64(96), conn.User.DownloadBandwidth) user.Filters.BandwidthLimits = []sdk.BandwidthLimit{ { Sources: []string{"10.0.0.0/24"}, UploadBandwidth: 256, DownloadBandwidth: 512, }, } user, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, 
string(resp)) if assert.Len(t, user.Filters.BandwidthLimits, 1) { bwLimit := user.Filters.BandwidthLimits[0] assert.Equal(t, []string{"10.0.0.0/24"}, bwLimit.Sources) assert.Equal(t, int64(256), bwLimit.UploadBandwidth) assert.Equal(t, int64(512), bwLimit.DownloadBandwidth) } up, down = user.GetBandwidthForIP("10.1.2.3", connID) assert.Equal(t, int64(128), up) assert.Equal(t, int64(96), down) conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "172.16.0.2", user) assert.Equal(t, int64(128), conn.User.UploadBandwidth) assert.Equal(t, int64(96), conn.User.DownloadBandwidth) up, down = user.GetBandwidthForIP("10.0.0.26", connID) assert.Equal(t, int64(256), up) assert.Equal(t, int64(512), down) conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "10.0.0.28", user) assert.Equal(t, int64(256), conn.User.UploadBandwidth) assert.Equal(t, int64(512), conn.User.DownloadBandwidth) // this works if we remove the omitempty tag from BandwidthLimits /*user.Filters.BandwidthLimits = nil user, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(resp)) assert.Len(t, user.Filters.BandwidthLimits, 0)*/ err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestAccessTimeValidation(t *testing.T) { u := getTestUser() u.Filters.AccessTime = []sdk.TimePeriod{ { DayOfWeek: 8, From: "10:00", To: "18:00", }, } _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "invalid day of week") u.Filters.AccessTime = []sdk.TimePeriod{ { DayOfWeek: 6, From: "10:00", To: "18", }, } _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "invalid time of day") u.Filters.AccessTime = []sdk.TimePeriod{ { DayOfWeek: 6, From: "11:00", To: "10:58", }, } _, resp, err = httpdtest.AddUser(u, 
http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "The end time cannot be earlier than the start time") } func TestUserTimestamps(t *testing.T) { user, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err, string(resp)) createdAt := user.CreatedAt updatedAt := user.UpdatedAt assert.Equal(t, int64(0), user.LastLogin) assert.Equal(t, int64(0), user.FirstDownload) assert.Equal(t, int64(0), user.FirstUpload) assert.Greater(t, createdAt, int64(0)) assert.Greater(t, updatedAt, int64(0)) mappedPath := filepath.Join(os.TempDir(), "mapped_dir") folderName := filepath.Base(mappedPath) user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, }, VirtualPath: "/vdir", }) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err = httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) time.Sleep(10 * time.Millisecond) user, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(resp)) assert.Equal(t, int64(0), user.LastLogin) assert.Equal(t, int64(0), user.FirstDownload) assert.Equal(t, int64(0), user.FirstUpload) assert.Equal(t, createdAt, user.CreatedAt) assert.Greater(t, user.UpdatedAt, updatedAt) updatedAt = user.UpdatedAt // after a folder update or delete the user updated_at field should change folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) time.Sleep(10 * time.Millisecond) _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), user.LastLogin) assert.Equal(t, int64(0), user.FirstDownload) assert.Equal(t, int64(0), user.FirstUpload) assert.Equal(t, createdAt, user.CreatedAt) assert.Greater(t, user.UpdatedAt, 
updatedAt) updatedAt = user.UpdatedAt time.Sleep(10 * time.Millisecond) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), user.LastLogin) assert.Equal(t, int64(0), user.FirstDownload) assert.Equal(t, int64(0), user.FirstUpload) assert.Equal(t, createdAt, user.CreatedAt) assert.Greater(t, user.UpdatedAt, updatedAt) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestAdminTimestamps(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) createdAt := admin.CreatedAt updatedAt := admin.UpdatedAt assert.Equal(t, int64(0), admin.LastLogin) assert.Greater(t, createdAt, int64(0)) assert.Greater(t, updatedAt, int64(0)) time.Sleep(10 * time.Millisecond) admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), admin.LastLogin) assert.Equal(t, createdAt, admin.CreatedAt) assert.Greater(t, admin.UpdatedAt, updatedAt) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) } func TestHTTPUserAuthEmptyPassword(t *testing.T) { u := getTestUser() u.Password = "" user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, "") c := httpclient.GetHTTPClient() resp, err := c.Do(req) c.CloseIdleConnections() assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, "") if assert.Error(t, err) { assert.Contains(t, err.Error(), "unexpected status code 401") } _, err = 
httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHTTPAnonymousUser verifies anonymous users.
//
// AddUser is expected to report an error even though the user is stored, since
// the test reads it back with 200 OK. This is presumably because the stored
// user, with the restrictions applied to anonymous users, no longer matches the
// submitted one; NOTE(review): confirm in httpdtest.
//
// The stored user must:
//   - have list/download permissions on "/";
//   - deny SSH and HTTP;
//   - deny the listed login methods;
//   - be refused an API token with 403 Forbidden.
func TestHTTPAnonymousUser(t *testing.T) {
	u := getTestUser()
	u.Filters.IsAnonymous = true
	_, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.Error(t, err)
	user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.IsAnonymous)
	assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"])
	assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols)
	assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate,
		dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods)
	// HTTP is in DeniedProtocols, so even valid credentials are expected to get 403
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	c := httpclient.GetHTTPClient()
	resp, err := c.Do(req)
	c.CloseIdleConnections()
	assert.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unexpected status code 403")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHTTPUserAuthentication checks the user API token endpoint with basic auth.
//
// Valid credentials return a non-empty access_token. An empty password, missing
// credentials, a wrong password or a wrong username all get 401; the last two
// also carry "invalid credentials" in the body.
//
// It also checks token scoping:
//   - a user token is rejected (401) by the admin version API;
//   - an admin token is rejected (401) by the user logout endpoint;
//   - the user token can log out (200).
func TestHTTPUserAuthentication(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	c :=
httpclient.GetHTTPClient()
	resp, err := c.Do(req)
	c.CloseIdleConnections()
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// login with wrong credentials
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, "")
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// no credentials at all: basic auth is not set on this request
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// wrong password for the user added above
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, "wrong pwd")
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	respBody, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(respBody), "invalid credentials")
	err = resp.Body.Close()
	assert.NoError(t, err)
	// wrong username: the body carries the same "invalid credentials" message
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth("wrong username", defaultPassword)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	respBody, err = io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(respBody), "invalid credentials")
	err = resp.Body.Close()
	assert.NoError(t, err)
	// obtain an admin token from the admin token endpoint
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultTokenAuthUser, defaultTokenAuthPass)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder = make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	adminToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, adminToken)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// the admin token is accepted by the version API
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, versionPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", adminToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// using the user token should not work
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, versionPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", userToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// symmetrically, the admin token is rejected by the user logout endpoint
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userLogoutPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", adminToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// while the user token can log out
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userLogoutPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", userToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t,
err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestPermMFADisabled checks the sdk.WebClientMFADisabled web client restriction.
//
// While it is set, requiring two-factor authentication for a protocol is
// rejected (400) and saving a TOTP configuration is forbidden (403). After
// replacing it with sdk.WebClientWriteDisabled the TOTP configuration can be
// saved. The MFA-disabled restriction can then not be set again while the
// configuration is active (400).
//
// Finally TOTP is disabled via the save endpoint, and both GET and POST on the
// recovery codes endpoint answer 403.
func TestPermMFADisabled(t *testing.T) {
	u := getTestUser()
	u.Filters.WebClient = []string{sdk.WebClientMFADisabled}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// requiring 2FA conflicts with the MFA-disabled restriction
	user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH}
	_, resp, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "you cannot require two-factor authentication and at the same time disallow it")
	user.Filters.TwoFactorAuthProtocols = nil
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr) // MFA is disabled for this user
	// lift the MFA restriction (keeping a different web client restriction) and retry with a fresh token
	user.Filters.WebClient = []string{sdk.WebClientWriteDisabled}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now we cannot disable MFA for this user
	user.Filters.WebClient = []string{sdk.WebClientMFADisabled}
	_, resp, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
assert.NoError(t, err)
	assert.Contains(t, string(resp), "two-factor authentication cannot be disabled for a user with an active configuration")
	// disable the TOTP configuration through the save endpoint
	saveReq := make(map[string]bool)
	saveReq["enabled"] = false
	asJSON, err = json.Marshal(saveReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// GET and POST on the recovery codes endpoint are both expected to answer 403 now
	req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUpdateUserPassword checks dataprovider.UpdateUserPassword on a user whose
// primary group sets PasswordStrength and MaxSessions. It verifies that:
//   - the new password is accepted by dataprovider.CheckUserAndPass;
//   - RequirePasswordChange is cleared;
//   - LastPasswordChange changes;
//   - the group overrides are not saved into the user (both stay 0).
func TestUpdateUserPassword(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.PasswordStrength = 20
	g.UserSettings.MaxSessions = 10
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Filters.RequirePasswordChange = true
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	lastPwdChange := user.LastPasswordChange
	// pause so that the LastPasswordChange set below can differ from the initial one
	time.Sleep(100 * time.Millisecond)
	newPwd := "uaCooGh3pheiShooghah"
	err = dataprovider.UpdateUserPassword(user.Username, newPwd, "", "", "")
	assert.NoError(t, err)
	_, err = dataprovider.CheckUserAndPass(user.Username, newPwd, "", common.ProtocolHTTP)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, user.Filters.RequirePasswordChange)
	assert.NotEqual(t, lastPwdChange, user.LastPasswordChange)
	// check that we don't save group overrides
	assert.Equal(t, 0, user.MaxSessions)
	assert.Equal(t, 0, user.Filters.PasswordStrength)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginRedirectNext checks the WebClient login "next" redirect.
//
// An unauthenticated request to the files page redirects (302) to the login
// page with next=<original URI>. The rendered login form posts back to that
// URI, and a successful login redirects to the original URI.
//
// An external URL, or an unsupported "next" target (the profile page), makes
// the login fall back to the files page instead.
func TestLoginRedirectNext(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	uri := webClientFilesPath + "?path=%2F" //nolint:goconst
	req, err := http.NewRequest(http.MethodGet, uri, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	redirectURI := rr.Header().Get("Location")
	assert.Equal(t, webClientLoginPath+"?next="+url.QueryEscape(uri), redirectURI) //nolint:goconst
	// render the login page
	req, err = http.NewRequest(http.MethodGet, redirectURI, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", redirectURI))
	// now login the user and check the redirect
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	assert.NotEmpty(t, loginCookie)
	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
	req, err = http.NewRequest(http.MethodPost, redirectURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.RequestURI = redirectURI
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, uri, rr.Header().Get("Location"))
	// unsafe URI: an external absolute URL as the "next" target must not be followed
	loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	assert.NotEmpty(t, loginCookie)
	form = getLoginForm(defaultUsername, defaultPassword, csrfToken)
	unsafeURI := webClientLoginPath + "?next=" +
url.QueryEscape("http://example.net")
	req, err = http.NewRequest(http.MethodPost, unsafeURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.RequestURI = unsafeURI
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	// the external target is ignored and the files page is used instead
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	// an unsupported "next" target (the profile page) also falls back to the files page
	loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	assert.NotEmpty(t, loginCookie)
	form = getLoginForm(defaultUsername, defaultPassword, csrfToken)
	unsupportedURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientProfilePath)
	req, err = http.NewRequest(http.MethodPost, unsupportedURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.RequestURI = unsupportedURI
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMustChangePasswordRequirement checks users with RequirePasswordChange set.
//
// Such a user can obtain API and WebClient tokens, but gets 403 on both the
// REST files API and the WebClient files page.
//
// Changing the password via the REST API (userPwdPath) clears the flag, and
// tokens obtained with the new password work. The scenario is then repeated
// changing the password through the WebClient change password form, which
// redirects to the login page.
func TestMustChangePasswordRequirement(t *testing.T) {
	u := getTestUser()
	u.Filters.RequirePasswordChange = true
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.True(t, user.Filters.RequirePasswordChange)
	// tokens can be obtained, but file access is forbidden until the password changes
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, userFilesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Password change required. Please set a new password to continue to use your account")
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdRequired)
	// change pwd
	pwd := make(map[string]string)
	pwd["current_password"] = defaultPassword
	pwd["new_password"] = altAdminPassword
	asJSON, err := json.Marshal(pwd)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check that the change pwd bool is changed
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, user.Filters.RequirePasswordChange)
	// get a new token
	token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, altAdminPassword)
	assert.NoError(t, err)
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, altAdminPassword)
	assert.NoError(t, err)
	// the new token should work
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check the same as above but changing password from the WebClient UI
	user.Filters.RequirePasswordChange = true
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// get a CSRF token for the WebClient change password form
	csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, webToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("current_password", altAdminPassword)
	form.Set("new_password1", defaultPassword)
	form.Set("new_password2", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	// after the change the WebClient redirects to its login page
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// tokens obtained with the restored password must work again
	token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestTwoFactorRequirements checks user-level TwoFactorAuthProtocols (HTTP and
// FTP).
//
// Without a TOTP configuration, the REST API and the WebClient answer 403. A
// TOTP configuration covering only HTTP is rejected (400); one covering HTTP
// and FTP is accepted. A token requested with a valid passcode in the
// X-SFTPGO-OTP header can then access userDirsPath.
func TestTwoFactorRequirements(t *testing.T) {
	u := getTestUser()
	u.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolFTP}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols")
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError2FARequired)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	// a TOTP config covering only HTTP does not satisfy the HTTP+FTP requirement
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following protocols are required")
	userTOTPConfig.Protocols = []string{common.ProtocolHTTP, common.ProtocolFTP}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now get new tokens and check that the two factor requirements are now met
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
assert.NoError(t, err)
	// the TOTP passcode is sent in the X-SFTPGO-OTP header alongside basic auth
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userDirsPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestTwoFactorRequirementsGroupLevel checks the HTTP and FTP two-factor
// requirement when it is inherited from the user's primary group rather than
// set on the user.
//
// Without a TOTP configuration, the WebClient and the REST API answer 403. A
// TOTP configuration covering only HTTP is rejected (400); one covering FTP
// and HTTP is accepted. A token requested with a valid passcode can then
// access userDirsPath.
func TestTwoFactorRequirementsGroupLevel(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolFTP}
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError2FARequired)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols")
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following protocols are required")
	// listing FTP before HTTP is accepted: the group requirement is satisfied
	userTOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolFTP, common.ProtocolHTTP},
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now get new tokens and check that the two factor requirements are now met
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userDirsPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminMustChangePasswordRequirement checks admin RequirePasswordChange.
//
// Through the REST API, with httpdtest using the new admin's token, GetUsers
// and GetStatus answer 403 until the admin changes its password, which clears
// the flag.
//
// An admin updating itself cannot set RequirePasswordChange or
// RequireTwoFactor: UpdateAdmin reports "require password change mismatch" and
// only the description is stored.
//
// The flag is then set again with the default httpdtest credentials. After
// that, the WebAdmin users page answers 403 until the password is changed
// through the WebAdmin form.
func TestAdminMustChangePasswordRequirement(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Filters.RequirePasswordChange = true
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	// act as the new admin: the following httpdtest calls carry its token
	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.GetUsers(0, 0, http.StatusForbidden)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetStatus(http.StatusForbidden)
	assert.NoError(t, err)
	_, err = httpdtest.ChangeAdminPassword(altAdminPassword, defaultTokenAuthPass, http.StatusOK)
	assert.NoError(t, err)
	httpdtest.SetJWTToken("")
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)
	// get a new token
	token, _, err = httpdtest.GetToken(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.GetUsers(0, 0, http.StatusOK)
	assert.NoError(t, err)
	// self-update: the requirement flags must not be persisted (UpdateAdmin
	// reports a mismatch) while the new description is
	desc := xid.New().String()
	admin.Filters.RequirePasswordChange = true
	admin.Filters.RequireTwoFactor = true
	admin.Description = desc
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	if assert.Error(t, err) {
		assert.ErrorContains(t, err, "require password change mismatch")
	}
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)
	assert.False(t, admin.Filters.RequireTwoFactor)
	assert.Equal(t, desc, admin.Description)
	// back to the default httpdtest credentials (NOTE(review): presumably the
	// default admin) so the requirement is set by a different admin
	httpdtest.SetJWTToken("")
	admin.Filters.RequirePasswordChange = true
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// test the same for the WebAdmin
	webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webUsersPath
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// The change password page should be accessible, we get the CSRF from it.
	csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeAdminPwdPath, webToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("current_password", defaultTokenAuthPass)
	form.Set("new_password1", altAdminPassword)
	form.Set("new_password2", altAdminPassword)
	req, err = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)
	webToken, err = getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webUsersPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	_, err =
httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminTwoFactorRequirements checks admin RequireTwoFactor.
//
// Without TOTP, the REST server status API and the WebAdmin folders page
// answer 403. After a TOTP configuration is saved, a token requested with a
// valid passcode can access the server status.
//
// The admin cannot disable its own two-factor authentication: both the TOTP
// save endpoint and the 2fa/disable API answer 400. Another admin
// (defaultTokenAuthUser) can disable it.
func TestAdminTwoFactorRequirements(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Filters.RequireTwoFactor = true
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met")
	webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webFoldersPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webFoldersPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError2FARequiredGeneric)
	// add TOTP config
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], altAdminUsername)
	assert.NoError(t, err)
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	asJSON, err := json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	// the token obtained before 2FA was configured can still save the TOTP config
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(altAdminUsername, altAdminPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	token = responseHolder["access_token"].(string)
	assert.NotEmpty(t, token)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// the token obtained with the passcode satisfies the 2FA requirement
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, serverStatusPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// try to disable 2FA
	disableReq := map[string]any{
		"enabled": false,
	}
	asJSON, err = json.Marshal(disableReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, fmt.Sprintf("%v%v", httpBaseURL, adminTOTPSavePath), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
	bodyResp, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(bodyResp), "two-factor authentication must be enabled")
	err = resp.Body.Close()
	assert.NoError(t, err)
	// try to disable 2FA using the dedicated API
	req, err = http.NewRequest(http.MethodPut, fmt.Sprintf("%v%v", httpBaseURL, path.Join(adminPath, altAdminUsername, "2fa", "disable")), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
	bodyResp, err = io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(bodyResp), "two-factor authentication must be enabled")
	err = resp.Body.Close()
	assert.NoError(t, err)
	// disabling 2FA using another admin should work
	token, err = getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername, "2fa", "disable"), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.TOTPConfig.Enabled)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginUserAPITOTP checks REST API login for a user with TOTP enabled for
// HTTP.
//
// Once HTTP and SSH are required, TOTP cannot be disabled (400) and the TOTP
// configuration must cover both protocols. After that:
//   - a token request without a passcode gets 401;
//   - a request with a valid passcode returns an access token;
//   - repeating the request with the same passcode gets 401.
func TestLoginUserAPITOTP(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now require HTTP and SSH for TOTP
	user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolSSH}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// two factor auth cannot be disabled
	config := make(map[string]any)
	config["enabled"] = false
	asJSON, err = json.Marshal(config)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t,
rr.Body.String(), "two-factor authentication must be enabled") // all the required protocols must be enabled asJSON, err = json.Marshal(userTOTPConfig) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "the following protocols are required") // setting all the required protocols should work userTOTPConfig.Protocols = []string{common.ProtocolHTTP, common.ProtocolSSH} asJSON, err = json.Marshal(userTOTPConfig) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) resp, err := httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) passcode, err := generateTOTPPasscode(key.Secret()) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) assert.NoError(t, err) req.Header.Set("X-SFTPGO-OTP", passcode) req.SetBasicAuth(defaultUsername, defaultPassword) resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) userToken := responseHolder["access_token"].(string) assert.NotEmpty(t, userToken) err = resp.Body.Close() assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) assert.NoError(t, err) 
req.Header.Set("X-SFTPGO-OTP", passcode) req.SetBasicAuth(defaultUsername, defaultPassword) resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestLoginAdminAPITOTP(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], admin.Username) assert.NoError(t, err) altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) adminTOTPConfig := dataprovider.AdminTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), } asJSON, err := json.Marshal(adminTOTPConfig) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, altToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) assert.True(t, admin.Filters.TOTPConfig.Enabled) assert.Len(t, admin.Filters.RecoveryCodes, 12) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil) assert.NoError(t, err) req.SetBasicAuth(altAdminUsername, altAdminPassword) resp, err := httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil) assert.NoError(t, err) req.Header.Set("X-SFTPGO-OTP", "passcode") req.SetBasicAuth(altAdminUsername, 
altAdminPassword) resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) passcode, err := generateTOTPPasscode(key.Secret()) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil) assert.NoError(t, err) req.Header.Set("X-SFTPGO-OTP", passcode) req.SetBasicAuth(altAdminUsername, altAdminPassword) resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) adminToken := responseHolder["access_token"].(string) assert.NotEmpty(t, adminToken) err = resp.Body.Close() assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, versionPath), nil) assert.NoError(t, err) setBearerForReq(req, adminToken) resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) // get/set recovery codes req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // disable two-factor auth saveReq := make(map[string]bool) saveReq["enabled"] = false asJSON, err = json.Marshal(saveReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, 
http.StatusOK) assert.NoError(t, err) assert.False(t, admin.Filters.TOTPConfig.Enabled) assert.Len(t, admin.Filters.RecoveryCodes, 0) // get/set recovery codes will not work req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) } func TestHTTPStreamZipError(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) resp, err := httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) userToken := responseHolder["access_token"].(string) assert.NotEmpty(t, userToken) err = resp.Body.Close() assert.NoError(t, err) filesList := []string{"missing"} asJSON, err := json.Marshal(filesList) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, fmt.Sprintf("%v%v", httpBaseURL, userStreamZipPath), bytes.NewBuffer(asJSON)) assert.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", userToken)) resp, err = httpclient.GetHTTPClient().Do(req) if !assert.Error(t, err) { // the connection will be closed err = resp.Body.Close() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestBasicAdminHandling(t *testing.T) { // we have one admin by 
default admins, _, err := httpdtest.GetAdmins(0, 0, http.StatusOK) assert.NoError(t, err) assert.GreaterOrEqual(t, len(admins), 1) admin := getTestAdmin() // the default admin already exists _, _, err = httpdtest.AddAdmin(admin, http.StatusConflict) assert.NoError(t, err) admin.Username = altAdminUsername admin.Filters.Preferences.HideUserPageSections = 1 + 4 + 8 admin.Filters.Preferences.DefaultUsersExpiration = 30 admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) assert.True(t, admin.Filters.Preferences.HideGroups()) assert.False(t, admin.Filters.Preferences.HideFilesystem()) assert.True(t, admin.Filters.Preferences.HideVirtualFolders()) assert.True(t, admin.Filters.Preferences.HideProfile()) assert.False(t, admin.Filters.Preferences.HideACLs()) assert.False(t, admin.Filters.Preferences.HideDiskQuotaAndBandwidthLimits()) assert.False(t, admin.Filters.Preferences.HideAdvancedSettings()) admin.AdditionalInfo = "test info" admin.Filters.Preferences.HideUserPageSections = 16 + 32 + 64 admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) assert.NoError(t, err) assert.Equal(t, "test info", admin.AdditionalInfo) assert.False(t, admin.Filters.Preferences.HideGroups()) assert.False(t, admin.Filters.Preferences.HideFilesystem()) assert.False(t, admin.Filters.Preferences.HideVirtualFolders()) assert.False(t, admin.Filters.Preferences.HideProfile()) assert.True(t, admin.Filters.Preferences.HideACLs()) assert.True(t, admin.Filters.Preferences.HideDiskQuotaAndBandwidthLimits()) assert.True(t, admin.Filters.Preferences.HideAdvancedSettings()) admins, _, err = httpdtest.GetAdmins(1, 0, http.StatusOK) assert.NoError(t, err) assert.Len(t, admins, 1) assert.NotEqual(t, admin.Username, admins[0].Username) admins, _, err = httpdtest.GetAdmins(1, 1, http.StatusOK) assert.NoError(t, err) assert.Len(t, admins, 1) assert.Equal(t, admin.Username, 
admins[0].Username) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveAdmin(admin, http.StatusNotFound) assert.NoError(t, err) admin, _, err = httpdtest.GetAdminByUsername(admin.Username+"123", http.StatusNotFound) assert.NoError(t, err) admin.Username = defaultTokenAuthUser _, err = httpdtest.RemoveAdmin(admin, http.StatusBadRequest) assert.NoError(t, err) } func TestAdminGroups(t *testing.T) { group1 := getTestGroup() group1.Name += "_1" group1, _, err := httpdtest.AddGroup(group1, http.StatusCreated) assert.NoError(t, err) group2 := getTestGroup() group2.Name += "_2" group2, _, err = httpdtest.AddGroup(group2, http.StatusCreated) assert.NoError(t, err) group3 := getTestGroup() group3.Name += "_3" group3, _, err = httpdtest.AddGroup(group3, http.StatusCreated) assert.NoError(t, err) a := getTestAdmin() a.Username = altAdminUsername a.Groups = []dataprovider.AdminGroupMapping{ { Name: group1.Name, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, }, }, { Name: group2.Name, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, }, }, { Name: group3.Name, Options: dataprovider.AdminGroupMappingOptions{ AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, }, }, } admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) assert.NoError(t, err) assert.Len(t, admin.Groups, 3) groups, _, err := httpdtest.GetGroups(0, 0, http.StatusOK) assert.NoError(t, err) assert.Len(t, groups, 3) for _, g := range groups { if assert.Len(t, g.Admins, 1) { assert.Equal(t, admin.Username, g.Admins[0]) } } admin, _, err = httpdtest.UpdateAdmin(a, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group2, http.StatusOK) assert.NoError(t, err) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) 
// only group3 should still be mapped: group1 and group2 were removed above
	assert.Len(t, admin.Groups, 1)
	// try to add a missing group
	admin.Groups = []dataprovider.AdminGroupMapping{
		{
			Name: group1.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary,
			},
		},
		{
			Name: group2.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary,
			},
		},
	}
	group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group3.Admins, 1)
	// the update references removed groups; httpdtest is expected to report an
	// error (presumably a status code other than 200)
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.Error(t, err)
	// the rejected update must leave the group3 mapping untouched
	group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group3.Admins, 1)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// removing the admin must also drop it from the group
	group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group3.Admins, 0)
	_, err = httpdtest.RemoveGroup(group3, http.StatusOK)
	assert.NoError(t, err)
}

// TestChangeAdminPassword exercises the change-password API:
//   - A wrong current password is rejected with 400.
//   - Reusing the current password is rejected with 400.
//   - A real change succeeds.
//   - A follow-up change request is expected to get 401.
//
// The original password is then restored directly through the data provider,
// so later tests can still authenticate.
func TestChangeAdminPassword(t *testing.T) {
	_, err := httpdtest.ChangeAdminPassword("wrong", defaultTokenAuthPass, http.StatusBadRequest)
	assert.NoError(t, err)
	_, err = httpdtest.ChangeAdminPassword(defaultTokenAuthPass, defaultTokenAuthPass, http.StatusBadRequest)
	assert.NoError(t, err)
	_, err = httpdtest.ChangeAdminPassword(defaultTokenAuthPass, defaultTokenAuthPass+"1", http.StatusOK)
	assert.NoError(t, err)
	// NOTE(review): presumably 401 because the password change invalidated the client's token
	_, err = httpdtest.ChangeAdminPassword(defaultTokenAuthPass+"1", defaultTokenAuthPass, http.StatusUnauthorized)
	assert.NoError(t, err)
	admin, err := dataprovider.AdminExists(defaultTokenAuthUser)
	assert.NoError(t, err)
	admin.Password = defaultTokenAuthPass
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
}

// TestPasswordValidations reinitializes the data provider with minimum
// password entropy requirements: 50 for admins and 70 for users. It checks
// that the standard test credentials are rejected as "insecure password". The
// default configuration, including BackupsPath, is restored at the end.
// The test is skipped with the memory provider.
func TestPasswordValidations(t *testing.T) {
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		t.Skip("this test is not supported with the memory provider")
	}
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	// check the load result before reading the loaded provider configuration
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.PasswordValidation.Admins.MinEntropy = 50
	providerConf.PasswordValidation.Users.MinEntropy = 70
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	_, resp, err := httpdtest.AddAdmin(a, http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "insecure password")
	_, resp, err = httpdtest.AddUser(getTestUser(), http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "insecure password")

	// restore the default provider configuration for the following tests
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestAdminPasswordHashing switches password hashing to argon2id. It checks
// that the existing default admin keeps its "$2a$"-prefixed hash, while a
// newly created admin gets an "$argon2id$" hash and can still authenticate.
// The default configuration, including BackupsPath, is restored at the end.
// The test is skipped with the memory provider.
func TestAdminPasswordHashing(t *testing.T) {
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		t.Skip("this test is not supported with the memory provider")
	}
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	// check the load result before reading the loaded provider configuration
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.PasswordHashing.Algo = dataprovider.HashingAlgoArgon2ID
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	currentAdmin, err := dataprovider.AdminExists(defaultTokenAuthUser)
	assert.NoError(t, err)
	// pre-existing hashes are not rewritten by the algorithm change
	assert.True(t, strings.HasPrefix(currentAdmin.Password, "$2a$"))

	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	newAdmin, err := dataprovider.AdminExists(altAdminUsername)
	assert.NoError(t, err)
	assert.True(t, strings.HasPrefix(newAdmin.Password, "$argon2id$"))

	// the argon2id-hashed admin must be able to log in and call the API
	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)
	httpdtest.SetJWTToken("")
	_, _, err = httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	// restore the default provider configuration for the following tests
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestDefaultUsersExpiration checks the DefaultUsersExpiration preference of
// an admin (30 here). Users created with that admin's token and no explicit
// expiration get a non-zero expiration date. An explicit expiration is kept
// as-is. The user template page renders, both plain and with ?from=<user>.
// After the admin is removed, its token is expected to get 404 on user
// creation.
func TestDefaultUsersExpiration(t *testing.T) {
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	a.Filters.Preferences.DefaultUsersExpiration = 30
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	// NOTE(review): an error is expected even though the user is created; presumably
	// httpdtest's response check fails because the server added an expiration date
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.Error(t, err)
	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, user.ExpirationDate, int64(0))
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// an explicit expiration date must be preserved
	u := getTestUser()
	u.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute))
	_, _, err = httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, u.ExpirationDate, user.ExpirationDate)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	httpdtest.SetJWTToken("")
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	// render the user template page
	webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webTemplateUser, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webTemplateUser+fmt.Sprintf("?from=%s", user.Username), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// the removed admin's token must no longer be usable
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.AddUser(u, http.StatusNotFound)
	assert.NoError(t, err)
	httpdtest.SetJWTToken("")
}

// TestAdminInvalidCredentials checks that the admin token endpoint accepts the
// default credentials. A wrong password or a wrong username gets 401 with the
// generic ErrInvalidCredentials message.
func TestAdminInvalidCredentials(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultTokenAuthUser, defaultTokenAuthPass)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// wrong password
	req.SetBasicAuth(defaultTokenAuthUser, "wrong pwd")
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// compare the raw value: an unchecked .(string) type assertion would panic and
	// abort the whole test binary if the "error" key were missing
	assert.Equal(t, dataprovider.ErrInvalidCredentials.Error(), responseHolder["error"])
	// wrong username
	req.SetBasicAuth("wrong username", defaultTokenAuthPass)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	responseHolder = make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	err = resp.Body.Close()
	assert.NoError(t, err)
	assert.Equal(t, dataprovider.ErrInvalidCredentials.Error(), responseHolder["error"])
}

// TestAdminLastLogin checks that a new admin starts with LastLogin == 0 and
// that obtaining a token updates it to a positive value.
func TestAdminLastLogin(t *testing.T) {
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), admin.LastLogin)
	_, _, err = httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, admin.LastLogin, int64(0))
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminAllowList checks that an admin can authenticate until its
// AllowList is restricted to 10.6.6.0/32. After that, token requests from the
// test client are expected to be rejected with 401.
func TestAdminAllowList(t *testing.T) {
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)
	httpdtest.SetJWTToken("")
	admin.Password = altAdminPassword
	admin.Filters.AllowList = []string{"10.6.6.0/32"}
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.EqualError(t, err, "wrong status code: got 401 want 200")
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserStatus checks that user status values 3 (on add) and 2 (on update)
// are rejected with 400, while 0 and 1 are accepted.
func TestUserStatus(t *testing.T) {
	u := getTestUser()
	u.Status = 3
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Status = 0
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Status = 2
	_, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	user.Status = 1
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUidGidLimits checks that UID and GID equal to math.MaxInt32 are accepted
// and returned unchanged.
func TestUidGidLimits(t *testing.T) {
	u := getTestUser()
	u.UID = math.MaxInt32
	u.GID = math.MaxInt32
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
// both IDs are expected to round-trip unchanged at the int32 upper bound
	assert.Equal(t, math.MaxInt32, user.GetUID())
	assert.Equal(t, math.MaxInt32, user.GetGID())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestAddUserNoCredentials checks that a user can be created with neither a
// password nor public keys, and that this user cannot obtain an API token with
// an empty password.
func TestAddUserNoCredentials(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	u.PublicKeys = []string{}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// this user cannot login with an empty password but it still can use an SSH cert.
	// The login attempt must target the user just created: requesting an admin
	// token for defaultTokenAuthUser would not exercise this user at all.
	_, err = getJWTAPIUserTokenFromTestServer(user.Username, "")
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestAddUserNoUsername checks that a user with an empty username is rejected.
func TestAddUserNoUsername(t *testing.T) {
	u := getTestUser()
	u.Username = ""
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserNoHomeDir checks that a user with an empty home directory is rejected.
func TestAddUserNoHomeDir(t *testing.T) {
	u := getTestUser()
	u.HomeDir = ""
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidHomeDir checks that a relative home directory is rejected.
func TestAddUserInvalidHomeDir(t *testing.T) {
	u := getTestUser()
	u.HomeDir = "relative_path" //nolint:goconst
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserNoPerms checks that a user with no permissions map, or with an
// empty permission list for "/", is rejected.
func TestAddUserNoPerms(t *testing.T) {
	u := getTestUser()
	u.Permissions = make(map[string][]string)
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Permissions["/"] = []string{}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidEmail checks that a malformed email is rejected with an
// email validation error.
func TestAddUserInvalidEmail(t *testing.T) {
	u := getTestUser()
	u.Email = "invalid_email"
	_, body, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(body), "Validation error: email")
}

// TestAddUserInvalidPerms checks that the following are rejected:
//   - an unknown permission name;
//   - a missing root ("/") permission;
//   - a sub-path ("/subdir/..") that presumably resolves back to "/".
func TestAddUserInvalidPerms(t *testing.T) {
	u := getTestUser()
	u.Permissions["/"] = []string{"invalidPerm"}
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// permissions for root dir are mandatory
	u.Permissions["/"] = []string{}
	u.Permissions["/somedir"] = []string{dataprovider.PermAny}
	_, _, err = httpdtest.AddUser(u,
http.StatusBadRequest)
	assert.NoError(t, err)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	// NOTE(review): "/subdir/.." presumably cleans to "/" and so collides with the root entry
	u.Permissions["/subdir/.."] = []string{dataprovider.PermAny}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidFilters checks that invalid user filters are rejected with
// 400. It covers bad IPs or CIDRs, unknown or all-denied login methods,
// unknown or all-denied protocols, and invalid file patterns (relative path,
// empty patterns on "/", duplicated paths, bad glob, unknown deny policy).
// It also covers an unsupported TLS username attribute and unknown web client
// options. Each case resets the field it broke, so that only one invalid
// setting is active at a time.
func TestAddUserInvalidFilters(t *testing.T) {
	u := getTestUser()
	u.Filters.AllowedIP = []string{"192.168.1.0/24", "192.168.2.0"}
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.AllowedIP = []string{}
	u.Filters.DeniedIP = []string{"192.168.3.0/16", "invalid"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedIP = []string{}
	u.Filters.DeniedLoginMethods = []string{"invalid"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedLoginMethods = dataprovider.ValidLoginMethods
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificateAndPwd}
	u.Filters.DeniedProtocols = dataprovider.ValidProtocols
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "relative",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.zip"},
			DeniedPatterns:  []string{},
		},
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.rar"},
			DeniedPatterns:  []string{"*.jpg"},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// the same cases again, with nil instead of empty DeniedPatterns where omitted
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "relative",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.zip"},
		},
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.rar"},
			DeniedPatterns:  []string{"*.jpg"},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// malformed glob pattern (trailing escape)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"a\\"},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.*"},
			DenyPolicy:      100,
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// reset the invalid file patterns; otherwise the leftover DenyPolicy would make
	// every following case fail with 400 regardless of the field under test
	u.Filters.FilePatterns = nil
	u.Filters.DeniedProtocols = []string{"invalid"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = dataprovider.ValidProtocols
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = nil
	u.Filters.TLSUsername = "not a supported attribute"
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.TLSUsername = ""
	u.Filters.WebClient = []string{"not a valid web client options"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidFsConfig checks that invalid filesystem configurations are
// rejected with 400 for the S3, GCS, Azure Blob, crypted, SFTP and HTTP
// providers. Later cases also assert the specific validation message.
func TestAddUserInvalidFsConfig(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.S3FilesystemProvider
	u.FsConfig.S3Config.Bucket = ""
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.Bucket = "testbucket"
	u.FsConfig.S3Config.Region = "eu-west-1"     //nolint:goconst
	u.FsConfig.S3Config.AccessKey = "access-key" //nolint:goconst
// a redacted secret and a key prefix starting with "/" are both set here; the save must fail
	u.FsConfig.S3Config.AccessSecret = kms.NewSecret(sdkkms.SecretStatusRedacted, "access-secret", "", "")
	u.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?a=b"
	u.FsConfig.S3Config.StorageClass = "Standard" //nolint:goconst
	u.FsConfig.S3Config.KeyPrefix = "/adir/subdir/"
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// plain secret, but the key prefix still starts with "/"
	u.FsConfig.S3Config.AccessSecret.SetStatus(sdkkms.SecretStatusPlain)
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.KeyPrefix = ""
	u.FsConfig.S3Config.UploadPartSize = 3
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.UploadPartSize = 5001
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.UploadPartSize = 0
	u.FsConfig.S3Config.UploadConcurrency = -1
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.UploadConcurrency = 0
	u.FsConfig.S3Config.DownloadPartSize = -1
	_, resp, err := httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "download_part_size cannot be")
	}
	u.FsConfig.S3Config.DownloadPartSize = 5001
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "download_part_size cannot be")
	}
	u.FsConfig.S3Config.DownloadPartSize = 0
	u.FsConfig.S3Config.DownloadConcurrency = 100
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid download concurrency")
	}
	u.FsConfig.S3Config.DownloadConcurrency = -1
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid download concurrency")
	}
	// without a custom endpoint, a region is required
	u.FsConfig.S3Config.DownloadConcurrency = 0
	u.FsConfig.S3Config.Endpoint = ""
	u.FsConfig.S3Config.Region = ""
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "region cannot be empty")
	}
	// Google Cloud Storage
	u = getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = ""
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.GCSConfig.Bucket = "abucket"
	u.FsConfig.GCSConfig.StorageClass = "Standard"
	u.FsConfig.GCSConfig.KeyPrefix = "/somedir/subdir/"
	u.FsConfig.GCSConfig.Credentials = kms.NewSecret(sdkkms.SecretStatusRedacted, "test", "", "") //nolint:goconst
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.GCSConfig.Credentials.SetStatus(sdkkms.SecretStatusPlain)
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// no credentials and automatic credentials disabled
	u.FsConfig.GCSConfig.KeyPrefix = "somedir/subdir/" //nolint:goconst
	u.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret()
	u.FsConfig.GCSConfig.AutomaticCredentials = 0
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.GCSConfig.Credentials = kms.NewSecret(sdkkms.SecretStatusSecretBox, "invalid", "", "")
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// Azure Blob Storage
	// NOTE(review): AzBlobConfig.Container is never set in these cases. Once the SAS URL
	// is emptied, some rejections below may come from the missing container rather than
	// from the field under test — verify against the Azure config validator.
	u = getTestUser()
	u.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	u.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("http://foo\x7f.com/")
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.AzBlobConfig.SASURL = kms.NewSecret(sdkkms.SecretStatusRedacted, "key", "", "")
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.AzBlobConfig.SASURL = kms.NewEmptySecret()
	u.FsConfig.AzBlobConfig.AccountName = "name"
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.AzBlobConfig.AccountKey = kms.NewSecret(sdkkms.SecretStatusRedacted, "key", "", "")
	u.FsConfig.AzBlobConfig.KeyPrefix = "/amedir/subdir/"
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.AzBlobConfig.AccountKey.SetStatus(sdkkms.SecretStatusPlain)
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.AzBlobConfig.KeyPrefix = "amedir/subdir/"
	u.FsConfig.AzBlobConfig.UploadPartSize = -1
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.AzBlobConfig.UploadPartSize = 101
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// crypted filesystem: missing passphrase, then a redacted one
	u = getTestUser()
	u.FsConfig.Provider = sdk.CryptedFilesystemProvider
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.CryptConfig.Passphrase = kms.NewSecret(sdkkms.SecretStatusRedacted, "akey", "", "")
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// SFTP filesystem
	u = getTestUser()
	u.FsConfig.Provider = sdk.SFTPFilesystemProvider
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.SFTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusRedacted, "randompkey", "", "")
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret()
	u.FsConfig.SFTPConfig.PrivateKey = kms.NewSecret(sdkkms.SecretStatusRedacted, "keyforpkey", "", "")
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret("pk")
	u.FsConfig.SFTPConfig.Endpoint = "127.1.1.1:22"
	u.FsConfig.SFTPConfig.Username = defaultUsername
	u.FsConfig.SFTPConfig.BufferSize = -1
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid buffer_size")
	}
	u.FsConfig.SFTPConfig.BufferSize = 1000
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid buffer_size")
	}
	// HTTP filesystem
	u = getTestUser()
	u.FsConfig.Provider = sdk.HTTPFilesystemProvider
	u.FsConfig.HTTPConfig.Endpoint = "http://foo\x7f.com/"
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid endpoint")
	}
	u.FsConfig.HTTPConfig.Endpoint = "http://127.0.0.1:9999/api/v1"
	u.FsConfig.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "", "", "")
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid encrypted password")
	}
	u.FsConfig.HTTPConfig.Password = nil
	u.FsConfig.HTTPConfig.APIKey = kms.NewSecret(sdkkms.SecretStatusRedacted, redactedSecret, "", "")
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "cannot save a user with a redacted secret")
	}
	u.FsConfig.HTTPConfig.APIKey = nil
	u.FsConfig.HTTPConfig.Endpoint = "/api/v1"
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid endpoint schema")
	}
	// unix-socket endpoints must carry a valid socket_path
	u.FsConfig.HTTPConfig.Endpoint = "http://unix?api_prefix=v1"
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid unix domain socket path")
	}
	u.FsConfig.HTTPConfig.Endpoint = "http://unix?socket_path=test.sock"
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid unix domain socket path")
	}
}

// TestUserRedactedPassword checks that a user whose S3 access secret or SSE
// customer key has redacted status is rejected. The rejection happens both via
// the REST API and directly through dataprovider.AddUser. A virtual folder
// with a redacted crypted passphrase is also rejected by
// dataprovider.UpdateUser.
func TestUserRedactedPassword(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.S3FilesystemProvider
	u.FsConfig.S3Config.Bucket = "b"
	u.FsConfig.S3Config.Region = "eu-west-1"
	u.FsConfig.S3Config.AccessKey = "access-key"
	u.FsConfig.S3Config.RoleARN = "myRoleARN"
	u.FsConfig.S3Config.AccessSecret = kms.NewSecret(sdkkms.SecretStatusRedacted, "access-secret", "", "")
	u.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?k=m"
	u.FsConfig.S3Config.StorageClass = "Standard"
	u.FsConfig.S3Config.ACL = "bucket-owner-full-control"
	_, resp, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "cannot save a user with a redacted secret")
	// the same check must also apply when bypassing the REST API
	err = dataprovider.AddUser(&u, "", "", "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "cannot save a user with a redacted secret")
	}
	u.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("secret")
	u.FsConfig.S3Config.SSECustomerKey = kms.NewSecret(sdkkms.SecretStatusRedacted, "mysecretkey", "", "")
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "cannot save a user with a redacted secret")
	u.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("key")
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	// a virtual folder carrying a redacted passphrase must be rejected on update
	folderName := "folderName"
	vfolder := vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name:       folderName,
			MappedPath: filepath.Join(os.TempDir(), "crypted"),
			FsConfig: vfs.Filesystem{
				Provider: sdk.CryptedFilesystemProvider,
				CryptConfig: vfs.CryptFsConfig{
					Passphrase: kms.NewSecret(sdkkms.SecretStatusRedacted, "crypted-secret", "", ""),
				},
			},
		},
		VirtualPath: "/avpath",
	}
	user.Password = defaultPassword
	user.VirtualFolders = append(user.VirtualFolders, vfolder)
	err = dataprovider.UpdateUser(&user, "", "", "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "cannot save a user with a redacted secret")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserType checks that Filters.UserType is stored on creation (LDAP) and
// can be changed on update (OS).
func TestUserType(t *testing.T) {
	u := getTestUser()
	u.Filters.UserType = string(sdk.UserTypeLDAP)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, string(sdk.UserTypeLDAP), user.Filters.UserType)
	user.Filters.UserType = string(sdk.UserTypeOS)
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Equal(t, string(sdk.UserTypeOS), user.Filters.UserType)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

func
TestRetentionAPI(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) t.Cleanup(func() { _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) }) checks, _, err := httpdtest.GetRetentionChecks(http.StatusOK) assert.NoError(t, err) assert.Len(t, checks, 0) localFilePath := filepath.Join(user.HomeDir, "testdir", "testfile") err = os.MkdirAll(filepath.Dir(localFilePath), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(localFilePath, []byte("test data"), os.ModePerm) assert.NoError(t, err) folderRetention := []dataprovider.FolderRetention{ { Path: "/", Retention: 24, DeleteEmptyDirs: true, }, } check := common.RetentionCheck{ Folders: folderRetention, } c := common.RetentionChecks.Add(check, &user) require.NotNil(t, c) err = c.Start() require.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.FileExists(t, localFilePath) err = os.Chtimes(localFilePath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) assert.NoError(t, err) err = c.Start() require.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) assert.NoFileExists(t, localFilePath) assert.NoDirExists(t, filepath.Dir(localFilePath)) c = common.RetentionChecks.Add(check, &user) assert.NotNil(t, c) assert.Nil(t, common.RetentionChecks.Add(check, &user)) // a check for this user is already in progress checks, _, err = httpdtest.GetRetentionChecks(http.StatusOK) assert.NoError(t, err) assert.Len(t, checks, 1) err = c.Start() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.RetentionChecks.Get("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond) checks, _, err = httpdtest.GetRetentionChecks(http.StatusOK) assert.NoError(t, err) assert.Len(t, 
checks, 0) } func TestAddUserInvalidVirtualFolders(t *testing.T) { u := getTestUser() folderName := "fname" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), Name: folderName, }, VirtualPath: "/vdir", }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), Name: folderName + "1", }, VirtualPath: "/vdir", // invalid, already defined }) _, _, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), Name: folderName, }, VirtualPath: "/vdir1", }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), Name: folderName, // invalid, unique constraint (user.id, folder.id) violated }, VirtualPath: "/vdir2", }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), Name: folderName + "1", }, VirtualPath: "/vdir1/", QuotaSize: -1, QuotaFiles: 1, // invvalid, we cannot have -1 and > 0 }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), Name: folderName + "1", }, VirtualPath: "/vdir1/", QuotaSize: 1, QuotaFiles: -1, }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.VirtualFolders = nil u.VirtualFolders = 
append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), Name: folderName + "1", }, VirtualPath: "/vdir1/", QuotaSize: -2, // invalid QuotaFiles: 0, }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), Name: folderName + "1", }, VirtualPath: "/vdir1/", QuotaSize: 0, QuotaFiles: -2, // invalid }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), }, VirtualPath: "/vdir1", }) // folder name is mandatory _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) } func TestUserPublicKey(t *testing.T) { u := getTestUser() u.Password = "" invalidPubKey := "invalid" u.PublicKeys = []string{invalidPubKey} _, _, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.PublicKeys = []string{testPubKey} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) dbUser, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) assert.Empty(t, dbUser.Password) assert.False(t, dbUser.IsPasswordHashed()) user.PublicKeys = []string{testPubKey, invalidPubKey} _, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "") assert.NoError(t, err) user.PublicKeys = []string{testPubKey, testPubKey, testPubKey} user.Password = defaultPassword _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) dbUser, err = dataprovider.UserExists(u.Username, "") assert.NoError(t, err) assert.NotEmpty(t, dbUser.Password) assert.True(t, dbUser.IsPasswordHashed()) _, err = httpdtest.RemoveUser(user, 
http.StatusOK) assert.NoError(t, err) // DSA keys are not accepted u = getTestUser() u.Password = "" u.PublicKeys = []string{"ssh-dss AAAAB3NzaC1kc3MAAACBAK+BKLZs1Vd0cWYOquKfp++0ml9hkzB7UDRozT3nhRcyHcuwASsXiVTqsg96oGjBcUUy076CXlsfJEXE2P0dF6tt1wvABPMwKpOn+kIrfJ0j93X2c2KIZNlD4YuNUJjLHu1DvgQHw8NMps6l5D0M5NFCRdD3NYhI5zFVJJ4CzikrAAAAFQCRBagw7gEbs0gd8So7OLMcSVzs/wAAAIBjuo7U9q8npchQ3otgCvj0xIwsQ+Fi9bH0SBceqbCcVzFYY6JXSQ0XmwHs+0AuvRCPIGaBdfcm+w+9YOxREtdEVjcmkYlfJpTaVljjWcWFWTQddbiamZhQ/xLU9CNLK4oYLwIGLZjCcG7nRDdLtLQdBFuzP/faEi3TD2BK114QmAAAAIEAj1n34pH2WKwbSZhzmz/OG0VzqJICFWboiM44LZl2AqcRBvEEycdHlGe2IKaj5lEtLgBKJt9NSFhBIzWh7gcEzSMlkiDecdYSFlDc4snmTiXaoiIehV59nTY6gc8GLWCzuem+WdHxvJ4yOSWF9k+a+Y+/v/35shNLkfokViOlN7k="} _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) assert.Contains(t, string(resp), "DSA key format is insecure and it is not allowed") } func TestUpdateUserEmptyPassword(t *testing.T) { u := getTestUser() u.PublicKeys = []string{testPubKey} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) // the password is not empty dbUser, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) assert.NotEmpty(t, dbUser.Password) assert.True(t, dbUser.IsPasswordHashed()) // now update the user and set an empty password data, err := json.Marshal(dbUser) assert.NoError(t, err) var customUser map[string]any err = json.Unmarshal(data, &customUser) assert.NoError(t, err) customUser["password"] = "" asJSON, err := json.Marshal(customUser) assert.NoError(t, err) userNoPwd, _, err := httpdtest.UpdateUserWithJSON(user, http.StatusOK, "", asJSON) assert.NoError(t, err) assert.Equal(t, user.Password, userNoPwd.Password) // the password is hidden // check the password within the data provider dbUser, err = dataprovider.UserExists(u.Username, "") assert.NoError(t, err) assert.Empty(t, dbUser.Password) assert.False(t, dbUser.IsPasswordHashed()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, 
err) } func TestUpdateUserNoPassword(t *testing.T) { u := getTestUser() u.PublicKeys = []string{testPubKey} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) // the password is not empty dbUser, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) assert.NotEmpty(t, dbUser.Password) assert.True(t, dbUser.IsPasswordHashed()) // now update the user and remove the password field, old password should be preserved user.Password = "" // password has the omitempty tag _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) // the password is preserved dbUser, err = dataprovider.UserExists(u.Username, "") assert.NoError(t, err) assert.NotEmpty(t, dbUser.Password) assert.True(t, dbUser.IsPasswordHashed()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestUpdateUser(t *testing.T) { u := getTestUser() u.UsedQuotaFiles = 1 u.UsedQuotaSize = 2 u.Filters.TLSUsername = sdk.TLSUsernameCN u.Filters.Hooks.CheckPasswordDisabled = true user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) assert.Equal(t, 0, user.UsedQuotaFiles) assert.Equal(t, int64(0), user.UsedQuotaSize) user.HomeDir = filepath.Join(homeBasePath, "testmod") user.UID = 33 user.GID = 101 user.MaxSessions = 10 user.QuotaSize = 4096 user.QuotaFiles = 2 user.Permissions["/"] = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload} user.Permissions["/subdir"] = []string{dataprovider.PermListItems, dataprovider.PermUpload} user.Filters.AllowedIP = []string{"192.168.1.0/24", "192.168.2.0/24"} user.Filters.DeniedIP = []string{"192.168.3.0/24", "192.168.4.0/24"} user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} user.Filters.DeniedProtocols = []string{common.ProtocolWebDAV} user.Filters.TLSUsername = sdk.TLSUsernameNone user.Filters.Hooks.ExternalAuthDisabled = true user.Filters.Hooks.PreLoginDisabled = true 
user.Filters.Hooks.CheckPasswordDisabled = false user.Filters.DisableFsChecks = true user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ Path: "/subdir", AllowedPatterns: []string{"*.zip", "*.rar"}, DeniedPatterns: []string{"*.jpg", "*.png"}, DenyPolicy: sdk.DenyPolicyHide, }) user.Filters.MaxUploadFileSize = 4096 user.UploadBandwidth = 1024 user.DownloadBandwidth = 512 user.VirtualFolders = nil mappedPath1 := filepath.Join(os.TempDir(), "mapped_dir1") mappedPath2 := filepath.Join(os.TempDir(), "mapped_dir2") folderName1 := filepath.Base(mappedPath1) folderName2 := filepath.Base(mappedPath2) user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir1", }) user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, }, VirtualPath: "/vdir12/subdir", QuotaSize: 123, QuotaFiles: 2, }) f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, } _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) _, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "invalid") assert.NoError(t, err) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0") assert.NoError(t, err) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "1") assert.NoError(t, err) user.Permissions["/subdir"] = []string{} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Len(t, user.Permissions["/subdir"], 0) assert.Len(t, user.VirtualFolders, 2) for _, folder := range user.VirtualFolders { assert.Greater(t, folder.ID, int64(0)) if folder.VirtualPath == "/vdir12/subdir" { assert.Equal(t, 
int64(123), folder.QuotaSize) assert.Equal(t, 2, folder.QuotaFiles) } } folder, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user.Username) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) // removing the user must remove folder mapping folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 0) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 0) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) } func TestUpdateUserTransferQuotaUsage(t *testing.T) { u := getTestUser() usedDownloadDataTransfer := int64(2 * 1024 * 1024) usedUploadDataTransfer := int64(1024 * 1024) u.UsedDownloadDataTransfer = usedDownloadDataTransfer u.UsedUploadDataTransfer = usedUploadDataTransfer user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), user.UsedUploadDataTransfer) assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) _, err = httpdtest.UpdateTransferQuotaUsage(u, "invalid_mode", http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedUploadDataTransfer, user.UsedUploadDataTransfer) assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusBadRequest) assert.NoError(t, err, "user has no transfer quota restrictions add mode should fail") user.TotalDataTransfer = 100 user, _, err = 
httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2*usedUploadDataTransfer, user.UsedUploadDataTransfer) assert.Equal(t, 2*usedDownloadDataTransfer, user.UsedDownloadDataTransfer) u.UsedDownloadDataTransfer = -1 _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusBadRequest) assert.NoError(t, err) u.UsedDownloadDataTransfer = usedDownloadDataTransfer u.Username += "1" _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusNotFound) assert.NoError(t, err) u.Username = defaultUsername _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedUploadDataTransfer, user.UsedUploadDataTransfer) assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) u.UsedDownloadDataTransfer = 0 u.UsedUploadDataTransfer = 1 _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedUploadDataTransfer+1, user.UsedUploadDataTransfer) assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) u.UsedDownloadDataTransfer = 1 u.UsedUploadDataTransfer = 0 _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedUploadDataTransfer+1, user.UsedUploadDataTransfer) assert.Equal(t, usedDownloadDataTransfer+1, user.UsedDownloadDataTransfer) token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPut, 
path.Join(quotasBasePath, "users", u.Username, "transfer-usage"), bytes.NewBuffer([]byte(`not a json`))) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestUpdateUserQuotaUsage(t *testing.T) { u := getTestUser() usedQuotaFiles := 1 usedQuotaSize := int64(65535) u.UsedQuotaFiles = usedQuotaFiles u.UsedQuotaSize = usedQuotaSize user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, user.UsedQuotaFiles) assert.Equal(t, int64(0), user.UsedQuotaSize) _, err = httpdtest.UpdateQuotaUsage(u, "invalid_mode", http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, usedQuotaSize, user.UsedQuotaSize) _, err = httpdtest.UpdateQuotaUsage(u, "add", http.StatusBadRequest) assert.NoError(t, err, "user has no quota restrictions add mode should fail") user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, usedQuotaSize, user.UsedQuotaSize) user.QuotaFiles = 100 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) _, err = httpdtest.UpdateQuotaUsage(u, "add", http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2*usedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, 2*usedQuotaSize, user.UsedQuotaSize) u.UsedQuotaFiles = -1 _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusBadRequest) assert.NoError(t, err) u.UsedQuotaFiles 
= usedQuotaFiles u.Username = u.Username + "1" _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusNotFound) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestUserFolderMapping(t *testing.T) { mappedPath1 := filepath.Join(os.TempDir(), "mapped_dir1") mappedPath2 := filepath.Join(os.TempDir(), "mapped_dir2") folderName1 := filepath.Base(mappedPath1) folderName2 := filepath.Base(mappedPath2) u1 := getTestUser() u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: "/vdir", QuotaSize: -1, QuotaFiles: -1, }) f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, UsedQuotaFiles: 2, UsedQuotaSize: 123, LastQuotaUpdate: 456, } _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) assert.NoError(t, err) // virtual folder must be auto created folder, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user1.Username) assert.Equal(t, 2, folder.UsedQuotaFiles) assert.Equal(t, int64(123), folder.UsedQuotaSize) assert.Equal(t, int64(456), folder.LastQuotaUpdate) assert.Equal(t, 2, user1.VirtualFolders[0].UsedQuotaFiles) assert.Equal(t, int64(123), user1.VirtualFolders[0].UsedQuotaSize) assert.Equal(t, int64(456), user1.VirtualFolders[0].LastQuotaUpdate) u2 := getTestUser() u2.Username = defaultUsername + "2" u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, }, VirtualPath: "/vdir1", QuotaSize: 0, QuotaFiles: 0, }) u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{ 
BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, }, VirtualPath: "/vdir2", QuotaSize: -1, QuotaFiles: -1, }) user2, _, err := httpdtest.AddUser(u2, http.StatusCreated) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user2.Username) folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 2) assert.Contains(t, folder.Users, user1.Username) assert.Contains(t, folder.Users, user2.Username) // now update user2 removing mappedPath1 user2.VirtualFolders = nil user2.VirtualFolders = append(user2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, UsedQuotaFiles: 2, UsedQuotaSize: 123, }, VirtualPath: "/vdir", QuotaSize: 0, QuotaFiles: 0, }) user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user2.Username) assert.Equal(t, 0, folder.UsedQuotaFiles) assert.Equal(t, int64(0), folder.UsedQuotaSize) folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user1.Username) // add mappedPath1 again to user2 user2.VirtualFolders = append(user2.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, }, VirtualPath: "/vdir1", }) user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user2.Username) // removing virtual folders should clear 
relations on both side _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) assert.NoError(t, err) user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user2.VirtualFolders, 1) { folder := user2.VirtualFolders[0] assert.Equal(t, mappedPath1, folder.MappedPath) assert.Equal(t, folderName1, folder.Name) } user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user2.VirtualFolders, 1) { folder := user2.VirtualFolders[0] assert.Equal(t, mappedPath1, folder.MappedPath) } folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 2) // removing a user should clear virtual folder mapping _, err = httpdtest.RemoveUser(user1, http.StatusOK) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Contains(t, folder.Users, user2.Username) // removing a folder should clear mapping on the user side too _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) assert.NoError(t, err) user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) assert.NoError(t, err) assert.Len(t, user2.VirtualFolders, 0) _, err = httpdtest.RemoveUser(user2, http.StatusOK) assert.NoError(t, err) } func TestUserS3Config(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) user.FsConfig.Provider = sdk.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test" //nolint:goconst user.FsConfig.S3Config.AccessKey = "Server-Access-Key" user.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("Server-Access-Secret") user.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("SSE-encryption-key") user.FsConfig.S3Config.RoleARN = "myRoleARN" user.FsConfig.S3Config.Endpoint = 
"http://127.0.0.1:9000" user.FsConfig.S3Config.UploadPartSize = 8 user.FsConfig.S3Config.DownloadPartMaxTime = 60 user.FsConfig.S3Config.UploadPartMaxTime = 40 user.FsConfig.S3Config.ForcePathStyle = true user.FsConfig.S3Config.SkipTLSVerify = true user.FsConfig.S3Config.DownloadPartSize = 6 folderName := "vfolderName" user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: "/folderPath", }) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: filepath.Join(os.TempDir(), "folderName"), FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret("Crypted-Secret"), }, }, } _, _, err = httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) user, body, err := httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(body)) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.AccessSecret.GetStatus()) assert.NotEmpty(t, user.FsConfig.S3Config.AccessSecret.GetPayload()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.SSECustomerKey.GetStatus()) assert.NotEmpty(t, user.FsConfig.S3Config.SSECustomerKey.GetPayload()) assert.Empty(t, user.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) assert.Empty(t, user.FsConfig.S3Config.SSECustomerKey.GetKey()) assert.Equal(t, 60, user.FsConfig.S3Config.DownloadPartMaxTime) assert.Equal(t, 40, user.FsConfig.S3Config.UploadPartMaxTime) assert.True(t, user.FsConfig.S3Config.SkipTLSVerify) if assert.Len(t, user.VirtualFolders, 1) { folder := user.VirtualFolders[0] assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.CryptConfig.Passphrase.GetStatus()) assert.NotEmpty(t, folder.FsConfig.CryptConfig.Passphrase.GetPayload()) assert.Empty(t, 
folder.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetKey()) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.CryptConfig.Passphrase.GetStatus()) assert.NotEmpty(t, folder.FsConfig.CryptConfig.Passphrase.GetPayload()) assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetKey()) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 user.VirtualFolders = nil user.FsConfig.S3Config.SSECustomerKey = kms.NewEmptySecret() secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "Server-Access-Secret", "", "") user.FsConfig.S3Config.AccessSecret = secret _, _, err = httpdtest.AddUser(user, http.StatusCreated) assert.Error(t, err) user.FsConfig.S3Config.AccessSecret.SetStatus(sdkkms.SecretStatusPlain) user, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) initialSecretPayload := user.FsConfig.S3Config.AccessSecret.GetPayload() assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.AccessSecret.GetStatus()) assert.NotEmpty(t, initialSecretPayload) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) user.FsConfig.Provider = sdk.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test-bucket" user.FsConfig.S3Config.Region = "us-east-1" //nolint:goconst user.FsConfig.S3Config.AccessKey = "Server-Access-Key1" user.FsConfig.S3Config.Endpoint = "http://localhost:9000" user.FsConfig.S3Config.KeyPrefix = "somedir/subdir" //nolint:goconst user.FsConfig.S3Config.UploadConcurrency = 5 user.FsConfig.S3Config.DownloadConcurrency = 4 user, bb, 
err := httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(bb))
	// the update must keep the stored encrypted payload and must not expose key/additional data
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.AccessSecret.GetStatus())
	assert.Equal(t, initialSecretPayload, user.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData())
	assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey())
	// test user without access key and access secret (shared config state)
	user.FsConfig.Provider = sdk.S3FilesystemProvider
	user.FsConfig.S3Config.Bucket = "testbucket"
	user.FsConfig.S3Config.Region = "us-east-1"
	user.FsConfig.S3Config.AccessKey = ""
	user.FsConfig.S3Config.AccessSecret = kms.NewEmptySecret()
	user.FsConfig.S3Config.Endpoint = ""
	user.FsConfig.S3Config.KeyPrefix = "somedir/subdir"
	user.FsConfig.S3Config.UploadPartSize = 6
	user.FsConfig.S3Config.UploadConcurrency = 4
	user, body, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(body))
	assert.Nil(t, user.FsConfig.S3Config.AccessSecret)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.ID = 0
	user.CreatedAt = 0
	// shared credential test for add instead of update
	user, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.NoError(t, err)
	assert.Nil(t, user.FsConfig.S3Config.AccessSecret)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestHTTPFsConfig checks that the HTTPFs password and API key are returned
// encrypted (SecretBox status, non-empty payload, empty key and additional
// data), that secret metadata sent back by the client on update does not
// change the stored payloads, and that the same holds when the user is
// created directly with AddUser.
func TestHTTPFsConfig(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.HTTPFilesystemProvider
	user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint: "http://127.0.0.1/httpfs",
			Username: defaultUsername,
		},
		Password: kms.NewPlainSecret(defaultPassword),
		APIKey:   kms.NewPlainSecret(defaultTokenAuthUser),
	}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// remember the encrypted payloads: later updates must not change them
	initialPwdPayload := user.FsConfig.HTTPConfig.Password.GetPayload()
	initialAPIKeyPayload := user.FsConfig.HTTPConfig.APIKey.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.Password.GetStatus())
	assert.NotEmpty(t, initialPwdPayload)
	assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetKey())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.NotEmpty(t, initialAPIKeyPayload)
	assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetKey())
	// send back the secrets flagged as SecretBox with random key/additional data:
	// the response must still carry the original payloads and no key/additional data
	user.FsConfig.HTTPConfig.Password.SetStatus(sdkkms.SecretStatusSecretBox)
	user.FsConfig.HTTPConfig.Password.SetAdditionalData(util.GenerateUniqueID())
	user.FsConfig.HTTPConfig.Password.SetKey(util.GenerateUniqueID())
	user.FsConfig.HTTPConfig.APIKey.SetStatus(sdkkms.SecretStatusSecretBox)
	user.FsConfig.HTTPConfig.APIKey.SetAdditionalData(util.GenerateUniqueID())
	user.FsConfig.HTTPConfig.APIKey.SetKey(util.GenerateUniqueID())
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.Password.GetStatus())
	assert.Equal(t, initialPwdPayload, user.FsConfig.HTTPConfig.Password.GetPayload())
	assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetKey())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.Equal(t, initialAPIKeyPayload, user.FsConfig.HTTPConfig.APIKey.GetPayload())
	assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// also test AddUser
	u := getTestUser()
	u.FsConfig.Provider = sdk.HTTPFilesystemProvider
	u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint: "http://127.0.0.1/httpfs",
			Username: defaultUsername,
		},
		Password: kms.NewPlainSecret(defaultPassword),
		APIKey:   kms.NewPlainSecret(defaultTokenAuthUser),
	}
	user, _, err = httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.Password.GetStatus())
	assert.NotEmpty(t, user.FsConfig.HTTPConfig.Password.GetPayload())
	assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetKey())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.NotEmpty(t, user.FsConfig.HTTPConfig.APIKey.GetPayload())
	assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserAzureBlobConfig checks that the Azure Blob account key and SAS URL
// are returned encrypted without key or additional data, that secret metadata
// sent back on update does not change the stored payload, that AddUser fails
// when given a plain string flagged as SecretBox, and that switching to a SAS
// URL leaves the account key nil.
func TestUserAzureBlobConfig(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	user.FsConfig.AzBlobConfig.Container = "test"
	user.FsConfig.AzBlobConfig.AccountName = "Server-Account-Name"
	user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key")
	user.FsConfig.AzBlobConfig.Endpoint = "http://127.0.0.1:9000"
	user.FsConfig.AzBlobConfig.UploadPartSize = 8
	user.FsConfig.AzBlobConfig.DownloadPartSize = 6
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	initialPayload := user.FsConfig.AzBlobConfig.AccountKey.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey())
	// send back fake secret metadata: the stored payload must not change
	user.FsConfig.AzBlobConfig.AccountKey.SetStatus(sdkkms.SecretStatusSecretBox)
user.FsConfig.AzBlobConfig.AccountKey.SetAdditionalData("data")
	user.FsConfig.AzBlobConfig.AccountKey.SetKey("fake key")
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.Equal(t, initialPayload, user.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.ID = 0
	user.CreatedAt = 0
	// a plain string flagged as SecretBox is expected to make AddUser fail
	secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "Server-Account-Key", "", "")
	user.FsConfig.AzBlobConfig.AccountKey = secret
	_, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.Error(t, err)
	user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key-Test")
	user, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.NoError(t, err)
	initialPayload = user.FsConfig.AzBlobConfig.AccountKey.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey())
	// updating non-secret fields must preserve the stored account key payload
	user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	user.FsConfig.AzBlobConfig.Container = "test-container"
	user.FsConfig.AzBlobConfig.Endpoint = "http://localhost:9001"
	user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir"
	user.FsConfig.AzBlobConfig.UploadConcurrency = 5
	user.FsConfig.AzBlobConfig.DownloadConcurrency = 4
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Equal(t, initialPayload, user.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey())
	// test user without access key and access secret (SAS)
	user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	user.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("https://myaccount.blob.core.windows.net/pictures/profile.jpg?sv=2012-02-12&st=2009-02-09&se=2009-02-10&sr=c&sp=r&si=YWJjZGVmZw%3d%3d&sig=dD80ihBh5jfNpymO5Hg1IdiJIEvHcJpCMiCMnN%2fRnbI%3d")
	user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir"
	user.FsConfig.AzBlobConfig.AccountName = ""
	user.FsConfig.AzBlobConfig.AccountKey = kms.NewEmptySecret()
	user.FsConfig.AzBlobConfig.UploadPartSize = 6
	user.FsConfig.AzBlobConfig.UploadConcurrency = 4
	user.FsConfig.AzBlobConfig.DownloadPartSize = 3
	user.FsConfig.AzBlobConfig.DownloadConcurrency = 5
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Nil(t, user.FsConfig.AzBlobConfig.AccountKey)
	assert.NotNil(t, user.FsConfig.AzBlobConfig.SASURL)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.ID = 0
	user.CreatedAt = 0
	// sas test for add instead of update
	user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{
		BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
			Container: user.FsConfig.AzBlobConfig.Container,
		},
		SASURL: kms.NewPlainSecret("http://127.0.0.1/fake/sass/url"),
	}
	user, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.NoError(t, err)
	assert.Nil(t, user.FsConfig.AzBlobConfig.AccountKey)
	initialPayload = user.FsConfig.AzBlobConfig.SASURL.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.SASURL.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetAdditionalData())
	assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetKey())
	// send back fake SAS URL metadata: the stored payload must not change
	user.FsConfig.AzBlobConfig.SASURL.SetStatus(sdkkms.SecretStatusSecretBox)
	user.FsConfig.AzBlobConfig.SASURL.SetAdditionalData("data")
	user.FsConfig.AzBlobConfig.SASURL.SetKey("fake key")
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.SASURL.GetStatus())
	assert.Equal(t, initialPayload, user.FsConfig.AzBlobConfig.SASURL.GetPayload())
	assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetAdditionalData())
	assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserCryptFs checks that the crypt passphrase is returned encrypted
// without key or additional data, that secret metadata sent back on update
// does not change the stored payload, and that AddUser fails with a
// passphrase flagged as SecretBox whose payload is not a valid encrypted value.
func TestUserCryptFs(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.CryptedFilesystemProvider
	user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("crypt passphrase")
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	initialPayload := user.FsConfig.CryptConfig.Passphrase.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey())
	// send back fake secret metadata: the stored payload must not change
	user.FsConfig.CryptConfig.Passphrase.SetStatus(sdkkms.SecretStatusSecretBox)
	user.FsConfig.CryptConfig.Passphrase.SetAdditionalData("data")
	user.FsConfig.CryptConfig.Passphrase.SetKey("fake pass key")
	user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(bb))
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.Equal(t, initialPayload, user.FsConfig.CryptConfig.Passphrase.GetPayload())
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
	user.Password = defaultPassword
	user.ID = 0
	user.CreatedAt = 0
	// a SecretBox-flagged secret with a non-encrypted payload must make AddUser fail
	secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "invalid encrypted payload", "", "")
	user.FsConfig.CryptConfig.Passphrase = secret
	_, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.Error(t, err)
	user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("passphrase test")
	user, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.NoError(t, err)
	initialPayload = user.FsConfig.CryptConfig.Passphrase.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey())
	user.FsConfig.Provider = sdk.CryptedFilesystemProvider
	user.FsConfig.CryptConfig.Passphrase.SetKey("pass")
	user, bb, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(bb))
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.NotEmpty(t, initialPayload)
	assert.Equal(t, initialPayload, user.FsConfig.CryptConfig.Passphrase.GetPayload())
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserSFTPFs checks SFTPFs validation and secret handling:
//   - a malformed endpoint is rejected and the response mentions "invalid endpoint";
//   - an endpoint without a port is accepted and stored as "127.0.0.1:22";
//   - the prefix is returned as "/" when not set;
//   - password and private key are returned encrypted, without key or
//     additional data, and keep their payload across updates;
//   - AddUser fails when a SecretBox-flagged secret has an invalid payload.
func TestUserSFTPFs(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user.FsConfig.SFTPConfig.Endpoint = "[::1]:22:22" // invalid endpoint
	user.FsConfig.SFTPConfig.Username = "sftp_user"
	user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp_pwd")
	user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey)
	user.FsConfig.SFTPConfig.Fingerprints = []string{sftpPkeyFingerprint}
	user.FsConfig.SFTPConfig.BufferSize = 2
	user.FsConfig.SFTPConfig.EqualityCheckMode = 1
	_, resp, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "invalid endpoint")
	// an endpoint without a port is accepted, so expecting 400 here yields an error
	user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1"
	_, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.Error(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, "127.0.0.1:22", user.FsConfig.SFTPConfig.Endpoint)
	user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:2022"
	user.FsConfig.SFTPConfig.DisableCouncurrentReads = true
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Equal(t, "/", user.FsConfig.SFTPConfig.Prefix)
	assert.True(t, user.FsConfig.SFTPConfig.DisableCouncurrentReads)
	assert.Equal(t, int64(2), user.FsConfig.SFTPConfig.BufferSize)
	initialPwdPayload := user.FsConfig.SFTPConfig.Password.GetPayload()
	initialPkeyPayload := user.FsConfig.SFTPConfig.PrivateKey.GetPayload()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.Password.GetStatus())
	assert.NotEmpty(t, initialPwdPayload)
	assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetKey())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.NotEmpty(t, initialPkeyPayload)
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey())
	// send back fake secret metadata: the stored payloads must not change
	user.FsConfig.SFTPConfig.Password.SetStatus(sdkkms.SecretStatusSecretBox)
	user.FsConfig.SFTPConfig.Password.SetAdditionalData("adata")
	user.FsConfig.SFTPConfig.Password.SetKey("fake pwd key")
	user.FsConfig.SFTPConfig.PrivateKey.SetStatus(sdkkms.SecretStatusSecretBox)
	user.FsConfig.SFTPConfig.PrivateKey.SetAdditionalData("adata")
	user.FsConfig.SFTPConfig.PrivateKey.SetKey("fake key")
	user.FsConfig.SFTPConfig.DisableCouncurrentReads = false
	user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(bb))
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.Password.GetStatus())
	assert.Equal(t, initialPwdPayload, user.FsConfig.SFTPConfig.Password.GetPayload())
	assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetKey())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.Equal(t, initialPkeyPayload, user.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey())
	assert.False(t, user.FsConfig.SFTPConfig.DisableCouncurrentReads)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.ID = 0
	user.CreatedAt = 0
	// each secret with an invalid SecretBox payload must make AddUser fail
	secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "invalid encrypted payload", "", "")
	user.FsConfig.SFTPConfig.Password = secret
	_, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.Error(t, err)
	user.FsConfig.SFTPConfig.Password = kms.NewEmptySecret()
	user.FsConfig.SFTPConfig.PrivateKey = secret
	_, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.Error(t, err)
	user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey)
	user, _, err = httpdtest.AddUser(user, http.StatusCreated)
	assert.NoError(t, err)
	initialPkeyPayload = user.FsConfig.SFTPConfig.PrivateKey.GetPayload()
	// an empty password secret is returned as nil
	assert.Nil(t, user.FsConfig.SFTPConfig.Password)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.NotEmpty(t, initialPkeyPayload)
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey())
	user.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user.FsConfig.SFTPConfig.PrivateKey.SetKey("k")
	user, bb, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(bb))
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.NotEmpty(t, initialPkeyPayload)
	assert.Equal(t, initialPkeyPayload, user.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserHiddenFields creates one user per secret-bearing filesystem
// provider (S3, GCS, Azure Blob, crypt, SFTP, HTTP) and checks that the REST
// API never returns secret keys or additional data, while the data provider
// keeps them and can decrypt each secret back to the original plaintext.
func TestUserHiddenFields(t *testing.T) {
	// sensitive data must be hidden but not deleted from the dataprovider
	usernames := []string{"user1", "user2", "user3", "user4", "user5", "user6"}
	u1 := getTestUser()
	u1.Username = usernames[0]
	u1.FsConfig.Provider = sdk.S3FilesystemProvider
	u1.FsConfig.S3Config.Bucket = "test"
	u1.FsConfig.S3Config.Region = "us-east-1"
	u1.FsConfig.S3Config.AccessKey = "S3-Access-Key"
	u1.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("S3-Access-Secret")
	u1.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("SSE-secret-key")
	user1, _, err := httpdtest.AddUser(u1, http.StatusCreated)
	assert.NoError(t, err)
	u2 := getTestUser()
	u2.Username = usernames[1]
	u2.FsConfig.Provider = sdk.GCSFilesystemProvider
	u2.FsConfig.GCSConfig.Bucket = "test"
	u2.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials")
	u2.FsConfig.GCSConfig.ACL = "bucketOwnerRead"
	u2.FsConfig.GCSConfig.UploadPartSize = 5
	u2.FsConfig.GCSConfig.UploadPartMaxTime = 20
	user2, _, err := httpdtest.AddUser(u2, http.StatusCreated)
	assert.NoError(t, err)
	u3 := getTestUser()
	u3.Username = usernames[2]
	u3.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	u3.FsConfig.AzBlobConfig.Container = "test"
	u3.FsConfig.AzBlobConfig.AccountName = "Server-Account-Name"
	u3.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key")
	user3, _, err := httpdtest.AddUser(u3, http.StatusCreated)
	assert.NoError(t, err)
	u4 := getTestUser()
	u4.Username = usernames[3]
	u4.FsConfig.Provider =
sdk.CryptedFilesystemProvider
	u4.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("test passphrase")
	user4, _, err := httpdtest.AddUser(u4, http.StatusCreated)
	assert.NoError(t, err)
	u5 := getTestUser()
	u5.Username = usernames[4]
	u5.FsConfig.Provider = sdk.SFTPFilesystemProvider
	u5.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:2022"
	u5.FsConfig.SFTPConfig.Username = "sftp_user"
	u5.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("apassword")
	u5.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey)
	u5.FsConfig.SFTPConfig.Fingerprints = []string{sftpPkeyFingerprint}
	u5.FsConfig.SFTPConfig.Prefix = "/prefix"
	user5, _, err := httpdtest.AddUser(u5, http.StatusCreated)
	assert.NoError(t, err)
	u6 := getTestUser()
	u6.Username = usernames[5]
	u6.FsConfig.Provider = sdk.HTTPFilesystemProvider
	u6.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint: "http://127.0.0.1/api/v1",
			Username: defaultUsername,
		},
		Password: kms.NewPlainSecret(defaultPassword),
		APIKey:   kms.NewPlainSecret(defaultTokenAuthUser),
	}
	user6, _, err := httpdtest.AddUser(u6, http.StatusCreated)
	assert.NoError(t, err)
	users, _, err := httpdtest.GetUsers(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.GreaterOrEqual(t, len(users), 6)
	// the login password is never returned, only the HasPassword flag
	for _, username := range usernames {
		user, _, err := httpdtest.GetUserByUsername(username, http.StatusOK)
		assert.NoError(t, err)
		assert.Empty(t, user.Password)
		assert.True(t, user.HasPassword)
	}
	// API view: every secret keeps status and payload, key and additional data are hidden
	user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user1.Password)
	assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetKey())
	assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetAdditionalData())
	assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetStatus())
	assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetKey())
	assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetAdditionalData())
	assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetStatus())
	assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetPayload())
	user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user2.Password)
	assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetKey())
	assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetPayload())
	user3, _, err = httpdtest.GetUserByUsername(user3.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user3.Password)
	assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetKey())
	assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	user4, _, err = httpdtest.GetUserByUsername(user4.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user4.Password)
	assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetKey())
	assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetPayload())
	user5, _, err = httpdtest.GetUserByUsername(user5.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user5.Password)
	assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetKey())
	assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetStatus())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetPayload())
	assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetKey())
	assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	assert.Equal(t, "/prefix", user5.FsConfig.SFTPConfig.Prefix)
	user6, _, err = httpdtest.GetUserByUsername(user6.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user6.Password)
	// both HTTPFs secrets must be checked, like every other provider above:
	// previously only the password key/additional data and the API key
	// status/payload were verified
	assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetKey())
	assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetStatus())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetPayload())
	assert.Empty(t, user6.FsConfig.HTTPConfig.APIKey.GetKey())
	assert.Empty(t, user6.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetPayload())
	// finally check that we have all the data inside the data provider
	user1, err = dataprovider.UserExists(user1.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, user1.Password)
	assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetKey())
	assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetAdditionalData())
	assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetStatus())
	assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetKey())
	assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetAdditionalData())
	assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetStatus())
	assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetPayload())
	err = user1.FsConfig.S3Config.AccessSecret.Decrypt()
	assert.NoError(t, err)
	err = user1.FsConfig.S3Config.SSECustomerKey.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user1.FsConfig.S3Config.AccessSecret.GetStatus())
	assert.Equal(t, u1.FsConfig.S3Config.AccessSecret.GetPayload(), user1.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetKey())
	assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusPlain, user1.FsConfig.S3Config.SSECustomerKey.GetStatus())
	assert.Equal(t, u1.FsConfig.S3Config.SSECustomerKey.GetPayload(), user1.FsConfig.S3Config.SSECustomerKey.GetPayload())
	assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetKey())
	assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetAdditionalData())
	user2, err = dataprovider.UserExists(user2.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, user2.Password)
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetKey())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetPayload())
	err = user2.FsConfig.GCSConfig.Credentials.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user2.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.Equal(t, u2.FsConfig.GCSConfig.Credentials.GetPayload(), user2.FsConfig.GCSConfig.Credentials.GetPayload())
	assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetKey())
	assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData())
	user3, err = dataprovider.UserExists(user3.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, user3.Password)
	assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetKey())
	assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	err = user3.FsConfig.AzBlobConfig.AccountKey.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user3.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.Equal(t, u3.FsConfig.AzBlobConfig.AccountKey.GetPayload(), user3.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetKey())
	assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	user4, err = dataprovider.UserExists(user4.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, user4.Password)
	assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetKey())
	assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetPayload())
	err = user4.FsConfig.CryptConfig.Passphrase.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user4.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.Equal(t, u4.FsConfig.CryptConfig.Passphrase.GetPayload(), user4.FsConfig.CryptConfig.Passphrase.GetPayload())
	assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetKey())
	assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	user5, err = dataprovider.UserExists(user5.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, user5.Password)
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetKey())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetStatus())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetPayload())
	err = user5.FsConfig.SFTPConfig.Password.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user5.FsConfig.SFTPConfig.Password.GetStatus())
	assert.Equal(t, u5.FsConfig.SFTPConfig.Password.GetPayload(), user5.FsConfig.SFTPConfig.Password.GetPayload())
	assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetKey())
	assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetKey())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	err = user5.FsConfig.SFTPConfig.PrivateKey.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user5.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.Equal(t, u5.FsConfig.SFTPConfig.PrivateKey.GetPayload(), user5.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetKey())
	assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	user6, err = dataprovider.UserExists(user6.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, user6.Password)
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetKey())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetStatus())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetPayload())
	err = user6.FsConfig.HTTPConfig.Password.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user6.FsConfig.HTTPConfig.Password.GetStatus())
	assert.Equal(t, u6.FsConfig.HTTPConfig.Password.GetPayload(), user6.FsConfig.HTTPConfig.Password.GetPayload())
	assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetKey())
	assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetAdditionalData())
	// the HTTPFs API key is a secret too: it must be stored encrypted and
	// decrypt back to the value originally sent
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetKey())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetPayload())
	err = user6.FsConfig.HTTPConfig.APIKey.Decrypt()
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusPlain, user6.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.Equal(t, u6.FsConfig.HTTPConfig.APIKey.GetPayload(), user6.FsConfig.HTTPConfig.APIKey.GetPayload())
	assert.Empty(t, user6.FsConfig.HTTPConfig.APIKey.GetKey())
	assert.Empty(t, user6.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	// update the GCS user and check that the credentials are preserved
	user2.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret()
	user2.FsConfig.GCSConfig.ACL = "private"
	_, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "")
	assert.NoError(t, err)
	user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Empty(t, user2.Password)
	assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetKey())
	assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetPayload())
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user3, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user4, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user5, http.StatusOK)
	assert.NoError(t, err)
	_, err =
httpdtest.RemoveUser(user6, http.StatusOK)
	assert.NoError(t, err)
}

// TestSecretObject checks a local encrypt/decrypt round trip: encryption
// switches the status to SecretBox and sets a key, decryption restores the
// plain payload and clears the key.
func TestSecretObject(t *testing.T) {
	s := kms.NewPlainSecret("test data")
	s.SetAdditionalData("username")
	require.True(t, s.IsValid())
	err := s.Encrypt()
	require.NoError(t, err)
	require.Equal(t, sdkkms.SecretStatusSecretBox, s.GetStatus())
	require.NotEmpty(t, s.GetPayload())
	require.NotEmpty(t, s.GetKey())
	require.True(t, s.IsValid())
	err = s.Decrypt()
	require.NoError(t, err)
	require.Equal(t, sdkkms.SecretStatusPlain, s.GetStatus())
	require.Equal(t, "test data", s.GetPayload())
	require.Empty(t, s.GetKey())
}

// TestSecretObjectCompatibility checks that a secret encrypted without a
// master key (mode 0) can still be decrypted after a master key has been
// configured, and that re-encrypting it with the master key yields mode 1.
// The vault branches run only if SecretStatusVaultTransit is added to the
// loop below; they are kept for manual testing against a vault server.
func TestSecretObjectCompatibility(t *testing.T) {
	// this is manually tested against vault too
	testPayload := "test payload"
	s := kms.NewPlainSecret(testPayload)
	require.True(t, s.IsValid())
	err := s.Encrypt()
	require.NoError(t, err)
	localAsJSON, err := json.Marshal(s)
	assert.NoError(t, err)

	for _, secretStatus := range []string{sdkkms.SecretStatusSecretBox} {
		kmsConfig := config.GetKMSConfig()
		assert.Empty(t, kmsConfig.Secrets.MasterKeyPath)
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			os.Setenv("VAULT_SERVER_URL", "http://127.0.0.1:8200")
			os.Setenv("VAULT_SERVER_TOKEN", "s.9lYGq83MbgG5KR5kfebXVyhJ")
			kmsConfig.Secrets.URL = "hashivault://mykey"
		}
		err := kmsConfig.Initialize()
		assert.NoError(t, err)
		// encrypt without a master key
		secret := kms.NewPlainSecret(testPayload)
		secret.SetAdditionalData("add data")
		err = secret.Encrypt()
		assert.NoError(t, err)
		assert.Equal(t, 0, secret.GetMode())
		secretClone := secret.Clone()
		err = secretClone.Decrypt()
		assert.NoError(t, err)
		assert.Equal(t, testPayload, secretClone.GetPayload())
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			// decrypt the local secret now that the provider is vault
			secretLocal := kms.NewEmptySecret()
			err = json.Unmarshal(localAsJSON, secretLocal)
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 0, secretLocal.GetMode())
			err = secretLocal.Decrypt()
			assert.NoError(t, err)
			assert.Equal(t, testPayload, secretLocal.GetPayload())
			assert.Equal(t, sdkkms.SecretStatusPlain, secretLocal.GetStatus())
			err = secretLocal.Encrypt()
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 0, secretLocal.GetMode())
		}

		asJSON, err := json.Marshal(secret)
		assert.NoError(t, err)

		masterKeyPath := filepath.Join(os.TempDir(), "mkey")
		// the master key is a secret: keep the file private to the current user
		err = os.WriteFile(masterKeyPath, []byte("test key"), 0600)
		assert.NoError(t, err)
		// named masterKeyConfig so it does not shadow the imported config package
		masterKeyConfig := kms.Configuration{
			Secrets: kms.Secrets{
				MasterKeyPath: masterKeyPath,
			},
		}
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			masterKeyConfig.Secrets.URL = "hashivault://mykey"
		}
		err = masterKeyConfig.Initialize()
		assert.NoError(t, err)

		// now build the secret from JSON
		secret = kms.NewEmptySecret()
		err = json.Unmarshal(asJSON, secret)
		assert.NoError(t, err)
		assert.Equal(t, 0, secret.GetMode())
		err = secret.Decrypt()
		assert.NoError(t, err)
		assert.Equal(t, testPayload, secret.GetPayload())
		// re-encrypting with the master key configured switches to mode 1
		err = secret.Encrypt()
		assert.NoError(t, err)
		assert.Equal(t, 1, secret.GetMode())
		err = secret.Decrypt()
		assert.NoError(t, err)
		assert.Equal(t, testPayload, secret.GetPayload())
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			// decrypt the local secret encrypted without a master key now that
			// the provider is vault and a master key is set.
			// The provider will not change, the master key will be used
			secretLocal := kms.NewEmptySecret()
			err = json.Unmarshal(localAsJSON, secretLocal)
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 0, secretLocal.GetMode())
			err = secretLocal.Decrypt()
			assert.NoError(t, err)
			assert.Equal(t, testPayload, secretLocal.GetPayload())
			assert.Equal(t, sdkkms.SecretStatusPlain, secretLocal.GetStatus())
			err = secretLocal.Encrypt()
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 1, secretLocal.GetMode())
		}

		// restore the original KMS configuration and clean up the master key
		err = kmsConfig.Initialize()
		assert.NoError(t, err)
		err = os.Remove(masterKeyPath)
		assert.NoError(t, err)
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			os.Unsetenv("VAULT_SERVER_URL")
			os.Unsetenv("VAULT_SERVER_TOKEN")
		}
	}
}

// TestUpdateUserNoCredentials checks that an update with an empty password
// and no public keys succeeds.
func TestUpdateUserNoCredentials(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.Password = ""
	user.PublicKeys = []string{}
	// password and public key will be omitted from json serialization if empty and so they will remain unchanged
	// and no validation error will be raised
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateUserEmptyHomeDir checks that clearing the home dir is rejected with 400.
func TestUpdateUserEmptyHomeDir(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.HomeDir = ""
	_, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateUserInvalidHomeDir checks that a relative home dir is rejected with 400.
func TestUpdateUserInvalidHomeDir(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.HomeDir = "relative_path"
	_, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateNonExistentUser checks that updating an unknown user returns 404.
func TestUpdateNonExistentUser(t *testing.T) {
	_, _, err := httpdtest.UpdateUser(getTestUser(), http.StatusNotFound, "")
	assert.NoError(t, err)
}

// TestGetNonExistentUser checks that fetching an unknown user returns 404.
func TestGetNonExistentUser(t *testing.T) {
	_, _, err := httpdtest.GetUserByUsername("na", http.StatusNotFound)
	assert.NoError(t, err)
}

// TestDeleteNonExistentUser checks that deleting an unknown user returns 404.
func TestDeleteNonExistentUser(t *testing.T) {
	_, err := httpdtest.RemoveUser(getTestUser(), http.StatusNotFound)
	assert.NoError(t, err)
}

// TestAddDuplicateUser checks that adding the same user twice returns 409.
func TestAddDuplicateUser(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusConflict)
	assert.NoError(t, err)
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.Error(t, err, "adding a duplicate user must fail")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestGetUsers checks user listing: HasPassword is reported, and requesting
// a single user (first argument 1) returns exactly one result.
func TestGetUsers(t *testing.T) {
	user1, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Username = defaultUsername + "1"
	user2, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	users, _, err := httpdtest.GetUsers(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.GreaterOrEqual(t, len(users), 2)
	for _, user := range users {
		if u.Username == user.Username {
			assert.True(t, user.HasPassword)
		}
	}
	users, _, err = httpdtest.GetUsers(1, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, users, 1)
	users, _, err = httpdtest.GetUsers(1, 1, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, users, 1)
	// the listing succeeds, so expecting a 500 makes the helper return an error
	_, _, err = httpdtest.GetUsers(1, 1, http.StatusInternalServerError)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
	assert.NoError(t, err)
}

func TestGetQuotaScans(t *testing.T) {
	_, _, err := httpdtest.GetQuotaScans(http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetQuotaScans(http.StatusInternalServerError)
assert.Error(t, err) _, _, err = httpdtest.GetFoldersQuotaScans(http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetFoldersQuotaScans(http.StatusInternalServerError) assert.Error(t, err) } func TestStartQuotaScan(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) folder := vfs.BaseVirtualFolder{ Name: "vfolder", MappedPath: filepath.Join(os.TempDir(), "folder"), Description: "virtual folder", } _, _, err = httpdtest.AddFolder(folder, http.StatusCreated) assert.NoError(t, err) _, err = httpdtest.StartFolderQuotaScan(folder, http.StatusAccepted) assert.NoError(t, err) for { quotaScan, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK) if !assert.NoError(t, err, "Error getting active scans") { break } if len(quotaScan) == 0 { break } time.Sleep(100 * time.Millisecond) } _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) } func TestUpdateFolderQuotaUsage(t *testing.T) { f := vfs.BaseVirtualFolder{ Name: "vdir", MappedPath: filepath.Join(os.TempDir(), "folder"), } usedQuotaFiles := 1 usedQuotaSize := int64(65535) f.UsedQuotaFiles = usedQuotaFiles f.UsedQuotaSize = usedQuotaSize folder, _, err := httpdtest.AddFolder(f, http.StatusCreated) if assert.NoError(t, err) { assert.Equal(t, usedQuotaFiles, folder.UsedQuotaFiles) assert.Equal(t, usedQuotaSize, folder.UsedQuotaSize) } _, err = httpdtest.UpdateFolderQuotaUsage(folder, "invalid mode", http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.UpdateFolderQuotaUsage(f, "reset", http.StatusOK) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(f.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, usedQuotaFiles, folder.UsedQuotaFiles) assert.Equal(t, usedQuotaSize, folder.UsedQuotaSize) _, err = httpdtest.UpdateFolderQuotaUsage(f, "add", 
http.StatusOK) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(f.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2*usedQuotaFiles, folder.UsedQuotaFiles) assert.Equal(t, 2*usedQuotaSize, folder.UsedQuotaSize) f.UsedQuotaSize = -1 _, err = httpdtest.UpdateFolderQuotaUsage(f, "", http.StatusBadRequest) assert.NoError(t, err) f.UsedQuotaSize = usedQuotaSize f.Name = f.Name + "1" _, err = httpdtest.UpdateFolderQuotaUsage(f, "", http.StatusNotFound) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) } func TestGetVersion(t *testing.T) { _, _, err := httpdtest.GetVersion(http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetVersion(http.StatusInternalServerError) assert.Error(t, err, "get version request must succeed, we requested to check a wrong status code") } func TestGetStatus(t *testing.T) { _, _, err := httpdtest.GetStatus(http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetStatus(http.StatusBadRequest) assert.Error(t, err, "get provider status request must succeed, we requested to check a wrong status code") } func TestGetConnections(t *testing.T) { _, _, err := httpdtest.GetConnections(http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetConnections(http.StatusInternalServerError) assert.Error(t, err, "get sftp connections request must succeed, we requested to check a wrong status code") } func TestCloseActiveConnection(t *testing.T) { _, err := httpdtest.CloseConnection("non_existent_id", http.StatusNotFound) assert.NoError(t, err) user := getTestUser() c := common.NewBaseConnection("connID", common.ProtocolSFTP, "", "", user) fakeConn := &fakeConnection{ BaseConnection: c, } err = common.Connections.Add(fakeConn) assert.NoError(t, err) _, err = httpdtest.CloseConnection(c.GetID(), http.StatusOK) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func 
TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	// register two fake connections (FTP and SFTP) for the same user
	c := common.NewBaseConnection("connID", common.ProtocolFTP, "", "", user)
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err = common.Connections.Add(fakeConn)
	assert.NoError(t, err)
	c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, "", "", user)
	fakeConn1 := &fakeConnection{
		BaseConnection: c1,
	}
	err = common.Connections.Add(fakeConn1)
	assert.NoError(t, err)

	// an update with "0" keeps the connections open ...
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0")
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 2)

	// ... while an update with "1" closes them
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "1")
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)

	// deleting the user must close re-registered connections too
	err = common.Connections.Add(fakeConn)
	assert.NoError(t, err)
	err = common.Connections.Add(fakeConn1)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 2)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestAdminGenerateRecoveryCodesSaveError creates an admin with TOTP while the
// provider runs with NamingRules = 7. It then restarts the provider with the
// default configuration and checks that generating recovery codes is rejected
// with a "the following characters are allowed" validation error.
// NOTE(review): presumably the stored username no longer validates without
// the relaxed naming rules — confirm.
func TestAdminGenerateRecoveryCodesSaveError(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.NamingRules = 7
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	a := getTestAdmin()
	a.Username = "adMiN@example.com "
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], admin.Username)
	assert.NoError(t, err)
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	admin.Password = defaultTokenAuthPass
	// update directly through the data provider to enable TOTP
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(a.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)

	// obtain an API token (with passcode) while the relaxed rules are active
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	adminAPIToken, err := getJWTAPITokenFromTestServerWithPasscode(a.Username, defaultTokenAuthPass, passcode)
	assert.NoError(t, err)
	assert.NotEmpty(t, adminAPIToken)

	// restart the provider with the default configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	// NOTE(review): presumably the memory provider loses the admin on
	// re-initialization, so the rest of the test is skipped for it.
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		return
	}

	req, err := http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, adminAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminCredentialsWithSpaces creates an admin whose password has leading
// and trailing spaces and checks which login paths accept the untrimmed and
// the trimmed form of that password.
func TestAdminCredentialsWithSpaces(t *testing.T) {
	a := getTestAdmin()
	a.Username = xid.New().String()
	a.Password = " " + xid.New().String() + " "
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	// For admins the password is always trimmed.
_, err = getJWTAPITokenFromTestServer(a.Username, a.Password)
	assert.Error(t, err)
	_, err = getJWTAPITokenFromTestServer(a.Username, strings.TrimSpace(a.Password))
	assert.NoError(t, err)
	// The password sent from the WebAdmin UI is automatically trimmed
	_, err = getJWTWebToken(a.Username, a.Password)
	assert.NoError(t, err)
	_, err = getJWTWebToken(a.Username, strings.TrimSpace(a.Password))
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserCredentialsWithSpaces creates a user whose password has leading and
// trailing spaces and checks that, unlike admins, only the untrimmed password
// is accepted by the REST API, the WebClient and SFTP.
func TestUserCredentialsWithSpaces(t *testing.T) {
	u := getTestUser()
	u.Password = " " + xid.New().String() + " "
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// For users the password is not trimmed
	_, err = getJWTAPIUserTokenFromTestServer(u.Username, u.Password)
	assert.NoError(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(u.Username, strings.TrimSpace(u.Password))
	assert.Error(t, err)
	_, err = getJWTWebClientTokenFromTestServer(u.Username, u.Password)
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(u.Username, strings.TrimSpace(u.Password))
	assert.Error(t, err)

	user.Password = u.Password
	conn, sftpClient, err := getSftpClient(user)
	if assert.NoError(t, err) {
		conn.Close()
		sftpClient.Close()
	}
	user.Password = strings.TrimSpace(u.Password)
	_, _, err = getSftpClient(user)
	assert.Error(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// the SFTP login above may have created the home dir: remove it like the
	// other tests that log in do (RemoveAll is a no-op if it does not exist)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestNamingRules creates a user, an admin, a role and a folder while the
// provider runs with NamingRules = 7, and checks that names are normalized
// (trimmed/lower-cased) and can be looked up by their original form.
// It then restarts the provider with the default configuration and checks
// that profile updates, password resets, 2FA and API-key changes for the
// now-invalid names are rejected.
func TestNamingRules(t *testing.T) {
	// Always restore an empty SMTP configuration when the test exits. The
	// memory-provider early return below, or any require failure, would
	// otherwise leave this SMTP configuration active for later tests.
	defer func() {
		resetCfg := smtp.Config{}
		assert.NoError(t, resetCfg.Initialize(configDir, true))
	}()
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)

	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.NamingRules = 7
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	u := getTestUser()
	u.Username = " uSeR@user.me "
	u.Email = dataprovider.ConvertName(u.Username)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, "user@user.me", user.Username)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	user.Password = u.Password
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	// update and fetch using the original (non-normalized) username
	user.Username = u.Username
	user.AdditionalInfo = "info"
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(u.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.TOTPConfig.Enabled)

	r := getTestRole()
	r.Name = "role@mycompany"
	role, _, err := httpdtest.AddRole(r, http.StatusCreated)
	assert.NoError(t, err)

	a := getTestAdmin()
	a.Username = "admiN@example.com "
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, "admin@example.com", admin.Username)
	admin.Email = dataprovider.ConvertName(a.Username)
	admin.Username = a.Username
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(a.Username, http.StatusOK)
	assert.NoError(t, err)

	f := vfs.BaseVirtualFolder{
		Name:       "文件夹AB",
		MappedPath: filepath.Clean(os.TempDir()),
	}
	folder, resp, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	assert.Equal(t, "文件夹ab", folder.Name)
	folder.Name = f.Name
	folder.Description = folder.Name
	_, resp, err = httpdtest.UpdateFolder(folder, http.StatusOK)
	assert.NoError(t, err, string(resp))
	folder, resp, err = httpdtest.GetFolderByName(f.Name, http.StatusOK)
	assert.NoError(t, err, string(resp))
	assert.Equal(t, "文件夹AB", folder.Description)
	_, err = httpdtest.RemoveFolder(f, http.StatusOK)
	assert.NoError(t, err)

	token, err := getJWTWebClientTokenFromTestServer(u.Username, defaultPassword)
	assert.NoError(t, err)
	assert.NotEmpty(t, token)

	adminAPIToken, err := getJWTAPITokenFromTestServer(a.Username, defaultTokenAuthPass)
	assert.NoError(t, err)
	assert.NotEmpty(t, adminAPIToken)

	// restart the provider with the default naming rules
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		return
	}

	token, err = getJWTWebClientTokenFromTestServer(user.Username, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	req, err := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser)
	// test user reset password. Setting the new password will fail because the username is not valid
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("username", user.Username)
	form.Set(csrfFormToken, csrfToken)
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("code", lastResetCode)
	form.Set("password", defaultPassword)
	form.Set("confirm_password", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser)

	adminAPIToken, err = getJWTAPITokenFromTestServer(admin.Username, defaultTokenAuthPass)
	assert.NoError(t, err)
	userAPIToken, err := getJWTAPIUserTokenFromTestServer(user.Username, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, userPath+"/"+user.Username+"/2fa/disable", nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, adminAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")

	req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, userAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")

	apiKeyAuthReq := make(map[string]bool)
	apiKeyAuthReq["allow_api_key_auth"] = true
	asJSON, err := json.Marshal(apiKeyAuthReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	token, err = getJWTWebTokenFromTestServer(admin.Username, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminProfilePath, token)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser)

	req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidName)

	apiKeyAuthReq = make(map[string]bool)
	apiKeyAuthReq["allow_api_key_auth"] = true
	asJSON, err = json.Marshal(apiKeyAuthReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")
	// test admin reset password
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("username", admin.Username)
	form.Set(csrfFormToken, csrfToken)
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("code", lastResetCode)
	form.Set("password", defaultPassword)
	form.Set("confirm_password", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserPassword checks how the user password is handled on updates:
// omitting the "password" field keeps the current password, while sending an
// empty "password" field removes it.
func TestUserPassword(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.False(t, user.HasPassword)
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.False(t, user.HasPassword)
	user.Password =
defaultPassword
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.True(t, user.HasPassword)

	// send a raw JSON update that does not include the "password" key at all
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	rawUser := map[string]any{
		"username": user.Username,
		"home_dir": filepath.Join(homeBasePath, defaultUsername),
		"permissions": map[string][]string{
			"/": {"*"},
		},
	}
	userAsJSON, err := json.Marshal(rawUser)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the previous password must be preserved
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.HasPassword)
	// update the user with an empty password field, the password will be unset
	rawUser["password"] = ""
	userAsJSON, err = json.Marshal(rawUser)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, user.HasPassword)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSaveErrors creates a user, an admin (both with TOTP and one recovery
// code) and a role while the provider runs with NamingRules = 1, then restarts
// the provider with the default configuration. After the restart:
//   - updating the role fails with a "the following characters are allowed" error;
//   - using a recovery code in the WebAdmin/WebClient 2FA flow returns 500.
// The recovery-code failures are presumably caused by the provider refusing to
// save the entity under the default naming rules; that cause is not
// established here — TODO confirm.
func TestSaveErrors(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.NamingRules = 1
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	recCode := "recovery code"
	recoveryCodes := []dataprovider.RecoveryCode{
		{
			Secret: kms.NewPlainSecret(recCode),
			Used:   false,
		},
	}

	u := getTestUser()
	u.Username = "user@example.com"
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = u.Password
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH, common.ProtocolHTTP},
	}
	user.Filters.RecoveryCodes = recoveryCodes
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.TOTPConfig.Enabled)
	assert.Len(t, user.Filters.RecoveryCodes, 1)

	a := getTestAdmin()
	a.Username = "admin@example.com"
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	admin.Email = admin.Username
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	admin.Password = a.Password
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	admin.Filters.RecoveryCodes = recoveryCodes
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 1)

	r := getTestRole()
	r.Name = "role@mycompany"
	role, _, err := httpdtest.AddRole(r, http.StatusCreated)
	assert.NoError(t, err)

	// restart the provider with the default naming rules
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		return
	}

	_, resp, err := httpdtest.UpdateRole(role, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "the following characters are allowed")

	// WebAdmin: password login redirects to the 2FA page, then the recovery
	// code is submitted
	loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form := getLoginForm(a.Username, a.Password, csrfToken)
	req, err := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr := executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location"))
	cookie, err := getCookieFromResponse(rr)
	assert.NoError(t, err)

	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("recovery_code", recCode)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	// WebClient: same flow for the user
	// NOTE(review): the CSRF token/cookie are taken from webLoginPath but
	// posted to webClientLoginPath — confirm this is intentional.
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = getLoginForm(u.Username, u.Password, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location"))
	cookie, err = getCookieFromResponse(rr)
	assert.NoError(t, err)

	csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("recovery_code", recCode)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserBaseDir restarts the provider with UsersBaseDir set and checks that
// a user added without a home dir gets UsersBaseDir/<username> as home dir.
// httpdtest reports "home dir mismatch", presumably because the returned home
// dir differs from the empty one that was requested.
func TestUserBaseDir(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.UsersBaseDir = homeBasePath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	u := getTestUser()
	u.HomeDir = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	if assert.Error(t, err) {
		assert.EqualError(t, err, "home dir mismatch")
	}
	assert.Equal(t, filepath.Join(providerConf.UsersBaseDir, u.Username), user.HomeDir)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// restore the default provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestQuotaTrackingDisabled restarts the provider with TrackQuota = 0 and
// checks that quota scans and quota usage updates for both users and folders
// are rejected with 403.
func TestQuotaTrackingDisabled(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.TrackQuota = 0
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	// user quota scan must fail
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.StartQuotaScan(user, http.StatusForbidden)
	assert.NoError(t, err)
	_, err = httpdtest.UpdateQuotaUsage(user, "", http.StatusForbidden)
	assert.NoError(t, err)
	_, err = httpdtest.UpdateTransferQuotaUsage(user, "", http.StatusForbidden)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// folder quota scan must fail
	folder := vfs.BaseVirtualFolder{
		Name:       "folder_quota_test",
		MappedPath: filepath.Clean(os.TempDir()),
	}
	folder, resp, err := httpdtest.AddFolder(folder, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	_, err = httpdtest.StartFolderQuotaScan(folder, http.StatusForbidden)
	assert.NoError(t, err)
	_, err = httpdtest.UpdateFolderQuotaUsage(folder, "", http.StatusForbidden)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)

	// restore the default provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

func TestProviderErrors(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	userAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userWebToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	token, _, err := httpdtest.GetToken(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	testServerToken, err :=
getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) httpdtest.SetJWTToken(token) err = dataprovider.Close() assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername("na", http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetUsers(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetGroups(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetAdmins(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetAPIKeys(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetEventActions(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetEventRules(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeDefender, "", "", dataprovider.OrderASC, 10, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetRoles(1, 0, http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.UpdateRole(getTestRole(), http.StatusInternalServerError) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, userAPIToken) rr := executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, userAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) // password reset errors loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) form.Set("username", "username") form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr 
setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser) getJSONShares := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err := http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil) assert.NoError(t, err) setJWTCookieForReq(req, userWebToken) executeRequest(req) } getJSONShares() req, err = http.NewRequest(http.MethodGet, webClientSharePath, nil) assert.NoError(t, err) setJWTCookieForReq(req, userWebToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webClientSharePath+"/shareID", nil) assert.NoError(t, err) setJWTCookieForReq(req, userWebToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/shareID", nil) assert.NoError(t, err) setJWTCookieForReq(req, userWebToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) _, _, err = httpdtest.UpdateUser(dataprovider.User{BaseUser: sdk.BaseUser{Username: "auser"}}, http.StatusInternalServerError, "") assert.NoError(t, err) _, err = httpdtest.RemoveUser(dataprovider.User{BaseUser: sdk.BaseUser{Username: "auser"}}, http.StatusInternalServerError) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: "aname"}, http.StatusInternalServerError) assert.NoError(t, err) status, _, err := httpdtest.GetStatus(http.StatusOK) if assert.NoError(t, err) { assert.False(t, status.DataProvider.IsActive) } _, _, err = httpdtest.Dumpdata("backup.json", "", "", http.StatusInternalServerError) assert.NoError(t, err) _, _, err = httpdtest.GetFolders(0, 0, http.StatusInternalServerError) assert.NoError(t, err) user = getTestUser() user.ID = 1 backupData := dataprovider.BackupData{ 
Version: dataprovider.DumpVersion, } backupData.Configs = &dataprovider.Configs{} backupData.Users = append(backupData.Users, user) backupContent, err := json.Marshal(backupData) assert.NoError(t, err) backupFilePath := filepath.Join(backupsPath, "backup.json") err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData.Configs = nil backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData.Folders = append(backupData.Folders, vfs.BaseVirtualFolder{Name: "testFolder", MappedPath: filepath.Clean(os.TempDir())}) backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData.Users = nil backupData.Folders = nil backupData.Groups = append(backupData.Groups, getTestGroup()) backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData.Groups = nil backupData.Admins = append(backupData.Admins, getTestAdmin()) backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData.Users = nil backupData.Folders = nil backupData.Admins = nil backupData.APIKeys = append(backupData.APIKeys, 
dataprovider.APIKey{ Name: "name", KeyID: util.GenerateUniqueID(), Key: fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()), Scope: dataprovider.APIKeyScopeUser, }) backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData.APIKeys = nil backupData.Shares = append(backupData.Shares, dataprovider.Share{ Name: util.GenerateUniqueID(), ShareID: util.GenerateUniqueID(), Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, Username: defaultUsername, }) backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, resp, err := httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err, string(resp)) backupData = dataprovider.BackupData{ EventActions: []dataprovider.BaseEventAction{ { Name: "quota_reset", Type: dataprovider.ActionTypeFolderQuotaReset, }, }, Version: dataprovider.DumpVersion, } backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData = dataprovider.BackupData{ EventRules: []dataprovider.EventRule{ { Name: "quota_reset", Trigger: dataprovider.EventTriggerSchedule, Conditions: dataprovider.EventConditions{ Schedules: []dataprovider.Schedule{ { Hours: "2", DayOfWeek: "1", DayOfMonth: "2", Month: "3", }, }, }, Actions: []dataprovider.EventAction{ { BaseEventAction: dataprovider.BaseEventAction{ Name: "unknown action", }, Order: 1, }, }, }, }, Version: dataprovider.DumpVersion, } backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = 
os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData = dataprovider.BackupData{ Roles: []dataprovider.Role{ { Name: "role1", }, }, Version: dataprovider.DumpVersion, } backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) backupData = dataprovider.BackupData{ IPLists: []dataprovider.IPListEntry{ { IPOrNet: "192.168.1.1/24", Type: dataprovider.IPListTypeRateLimiterSafeList, Mode: dataprovider.ListModeAllow, }, }, Version: dataprovider.DumpVersion, } backupContent, err = json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) assert.NoError(t, err) err = os.Remove(backupFilePath) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webUserPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webTemplateUser+"?from=auser", nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webGroupPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webAdminPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = 
http.NewRequest(http.MethodGet, webTemplateUser, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webGroupPath, "groupname"), nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, "grpname"), nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webTemplateFolder+"?from=afolder", nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventActionPath, "actionname"), nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, "actionname"), bytes.NewBuffer(nil)) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) getJSONActions := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err := http.NewRequest(http.MethodGet, webAdminEventActionsPath+jsonAPISuffix, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) executeRequest(req) } getJSONActions() req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventRulePath, "rulename"), nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, 
path.Join(webAdminEventRulePath, "rulename"), bytes.NewBuffer(nil)) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) getJSONRules := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err := http.NewRequest(http.MethodGet, webAdminEventRulesPath+jsonAPISuffix, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) executeRequest(req) } getJSONRules() req, err = http.NewRequest(http.MethodGet, webAdminEventRulePath, nil) assert.NoError(t, err) setJWTCookieForReq(req, testServerToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) httpdtest.SetJWTToken("") } func TestFolders(t *testing.T) { folder := vfs.BaseVirtualFolder{ Name: "name", MappedPath: "relative path", Users: []string{"1", "2", "3"}, FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret("asecret"), }, }, } _, _, err := httpdtest.AddFolder(folder, http.StatusBadRequest) assert.NoError(t, err) folder.MappedPath = filepath.Clean(os.TempDir()) folder1, resp, err := httpdtest.AddFolder(folder, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Equal(t, folder.Name, folder1.Name) assert.Equal(t, folder.MappedPath, folder1.MappedPath) assert.Equal(t, 0, folder1.UsedQuotaFiles) assert.Equal(t, int64(0), folder1.UsedQuotaSize) assert.Equal(t, int64(0), folder1.LastQuotaUpdate) assert.Equal(t, sdkkms.SecretStatusSecretBox, folder1.FsConfig.CryptConfig.Passphrase.GetStatus()) assert.NotEmpty(t, folder1.FsConfig.CryptConfig.Passphrase.GetPayload()) assert.Empty(t, 
folder1.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) assert.Empty(t, folder1.FsConfig.CryptConfig.Passphrase.GetKey()) assert.Len(t, folder1.Users, 0) // adding a duplicate folder must fail _, _, err = httpdtest.AddFolder(folder, http.StatusCreated) assert.Error(t, err) folder.MappedPath = filepath.Join(os.TempDir(), "vfolder") folder.Name = filepath.Base(folder.MappedPath) folder.UsedQuotaFiles = 1 folder.UsedQuotaSize = 345 folder.LastQuotaUpdate = 10 folder2, _, err := httpdtest.AddFolder(folder, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Equal(t, 1, folder2.UsedQuotaFiles) assert.Equal(t, int64(345), folder2.UsedQuotaSize) assert.Equal(t, int64(10), folder2.LastQuotaUpdate) assert.Len(t, folder2.Users, 0) folders, _, err := httpdtest.GetFolders(0, 0, http.StatusOK) assert.NoError(t, err) numResults := len(folders) assert.GreaterOrEqual(t, numResults, 2) found := false for _, f := range folders { if f.Name == folder1.Name { found = true assert.Equal(t, folder1.MappedPath, f.MappedPath) assert.Equal(t, sdkkms.SecretStatusSecretBox, f.FsConfig.CryptConfig.Passphrase.GetStatus()) assert.NotEmpty(t, f.FsConfig.CryptConfig.Passphrase.GetPayload()) assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetKey()) assert.Len(t, f.Users, 0) } } assert.True(t, found) folders, _, err = httpdtest.GetFolders(0, 1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folders, numResults-1) folders, _, err = httpdtest.GetFolders(1, 0, http.StatusOK) assert.NoError(t, err) assert.Len(t, folders, 1) f, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, folder1.Name, f.Name) assert.Equal(t, folder1.MappedPath, f.MappedPath) assert.Equal(t, sdkkms.SecretStatusSecretBox, f.FsConfig.CryptConfig.Passphrase.GetStatus()) assert.NotEmpty(t, f.FsConfig.CryptConfig.Passphrase.GetPayload()) assert.Empty(t, 
f.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetKey()) assert.Len(t, f.Users, 0) f, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, folder2.Name, f.Name) assert.Equal(t, folder2.MappedPath, f.MappedPath) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{ Name: "invalid", }, http.StatusNotFound) assert.NoError(t, err) _, _, err = httpdtest.UpdateFolder(vfs.BaseVirtualFolder{Name: "notfound"}, http.StatusNotFound) assert.NoError(t, err) folder1.MappedPath = "a/relative/path" _, _, err = httpdtest.UpdateFolder(folder1, http.StatusBadRequest) assert.NoError(t, err) folder1.MappedPath = filepath.Join(os.TempDir(), "updated") folder1.Description = "updated folder description" f, resp, err = httpdtest.UpdateFolder(folder1, http.StatusOK) assert.NoError(t, err, string(resp)) assert.Equal(t, folder1.MappedPath, f.MappedPath) assert.Equal(t, folder1.Description, f.Description) _, err = httpdtest.RemoveFolder(folder1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(folder2, http.StatusOK) assert.NoError(t, err) } func TestFolderRelations(t *testing.T) { mappedPath := filepath.Join(os.TempDir(), "mapped_path") name := filepath.Base(mappedPath) u := getTestUser() u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: name, }, VirtualPath: "/mountu", }) _, resp, err := httpdtest.AddUser(u, http.StatusInternalServerError) assert.NoError(t, err, string(resp)) g := getTestGroup() g.VirtualFolders = append(g.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: name, }, VirtualPath: "/mountg", }) _, resp, err = httpdtest.AddGroup(g, http.StatusInternalServerError) assert.NoError(t, err, string(resp)) f := vfs.BaseVirtualFolder{ Name: name, MappedPath: mappedPath, } folder, _, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) assert.Len(t, 
folder.Users, 0) assert.Len(t, folder.Groups, 0) user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) group, resp, err := httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err, string(resp)) folder, _, err = httpdtest.GetFolderByName(folder.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Len(t, folder.Groups, 1) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user.VirtualFolders, 1) { assert.Equal(t, mappedPath, user.VirtualFolders[0].MappedPath) } group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, group.VirtualFolders, 1) { assert.Equal(t, mappedPath, group.VirtualFolders[0].MappedPath) } // update the folder and check the modified field on user and group mappedPath = filepath.Join(os.TempDir(), "mapped_path") folder.MappedPath = mappedPath _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user.VirtualFolders, 1) { assert.Equal(t, mappedPath, user.VirtualFolders[0].MappedPath) } group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) if assert.Len(t, group.VirtualFolders, 1) { assert.Equal(t, mappedPath, group.VirtualFolders[0].MappedPath) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folder.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 0) assert.Len(t, folder.Groups, 0) user, resp, err = httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) assert.Len(t, user.VirtualFolders, 1) group, resp, err = httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err, string(resp)) 
assert.Len(t, group.VirtualFolders, 1) folder, _, err = httpdtest.GetFolderByName(folder.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, folder.Users, 1) assert.Len(t, folder.Groups, 1) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.VirtualFolders, 0) group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, group.VirtualFolders, 0) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) } func TestDumpdata(t *testing.T) { err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) _, rawResp, err := httpdtest.Dumpdata("", "", "", http.StatusBadRequest) assert.NoError(t, err, string(rawResp)) _, _, err = httpdtest.Dumpdata(filepath.Join(backupsPath, "backup.json"), "", "", http.StatusBadRequest) assert.NoError(t, err) _, rawResp, err = httpdtest.Dumpdata("../backup.json", "", "", http.StatusBadRequest) assert.NoError(t, err, string(rawResp)) _, rawResp, err = httpdtest.Dumpdata("backup.json", "", "0", http.StatusOK) assert.NoError(t, err, string(rawResp)) response, _, err := httpdtest.Dumpdata("", "1", "0", http.StatusOK) assert.NoError(t, err) _, ok := response["admins"] assert.True(t, ok) _, ok = response["users"] assert.True(t, ok) _, ok = response["groups"] assert.True(t, ok) _, ok = response["folders"] assert.True(t, ok) _, ok = response["api_keys"] assert.True(t, ok) _, ok = response["shares"] assert.True(t, ok) _, ok = response["version"] assert.True(t, ok) _, rawResp, err = httpdtest.Dumpdata("backup.json", "", "1", http.StatusOK) assert.NoError(t, err, 
string(rawResp)) err = os.Remove(filepath.Join(backupsPath, "backup.json")) assert.NoError(t, err) if runtime.GOOS != osWindows { err = os.Chmod(backupsPath, 0001) assert.NoError(t, err) _, _, err = httpdtest.Dumpdata("bck.json", "", "", http.StatusForbidden) assert.NoError(t, err) // subdir cannot be created _, _, err = httpdtest.Dumpdata(filepath.Join("subdir", "bck.json"), "", "", http.StatusForbidden) assert.NoError(t, err) err = os.Chmod(backupsPath, 0755) assert.NoError(t, err) } err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } func TestDefenderAPI(t *testing.T) { oldConfig := config.GetCommonConfig() drivers := []string{common.DefenderDriverMemory} if isDbDefenderSupported() { drivers = append(drivers, common.DefenderDriverProvider) } for _, driver := range drivers { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Driver = driver cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 cfg.DefenderConfig.ScoreNoAuth = 0 err := common.Initialize(cfg, 0) assert.NoError(t, err) ip := "::1" hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) assert.Len(t, hosts, 0) _, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound) assert.NoError(t, err) common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventNoLoginTried) hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) assert.Len(t, hosts, 0) common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound) hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) if assert.Len(t, hosts, 1) { host := hosts[0] assert.Empty(t, host.GetBanTime()) assert.Equal(t, 2, host.Score) assert.Equal(t, ip, host.IP) } host, _, err := 
httpdtest.GetDefenderHostByIP(ip, http.StatusOK) assert.NoError(t, err) assert.Empty(t, host.GetBanTime()) assert.Equal(t, 2, host.Score) common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound) hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) if assert.Len(t, hosts, 1) { host := hosts[0] assert.NotEmpty(t, host.GetBanTime()) assert.Equal(t, 0, host.Score) assert.Equal(t, ip, host.IP) } host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusOK) assert.NoError(t, err) assert.NotEmpty(t, host.GetBanTime()) assert.Equal(t, 0, host.Score) _, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound) assert.NoError(t, err) common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound) common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound) hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) assert.Len(t, hosts, 1) _, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusOK) assert.NoError(t, err) host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound) assert.NoError(t, err) _, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound) assert.NoError(t, err) host, _, err = httpdtest.GetDefenderHostByIP("invalid_ip", http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.RemoveDefenderHostByIP("invalid_ip", http.StatusBadRequest) assert.NoError(t, err) if driver == common.DefenderDriverProvider { err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Hour))) assert.NoError(t, err) } } err := common.Initialize(oldConfig, 0) require.NoError(t, err) } func TestDefenderAPIErrors(t *testing.T) { if isDbDefenderSupported() { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Driver = common.DefenderDriverProvider err := 
common.Initialize(cfg, 0) require.NoError(t, err) token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, defenderHosts, nil) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = common.Initialize(oldConfig, 0) require.NoError(t, err) } } func TestRestoreShares(t *testing.T) { // shares should be restored preserving the UsedTokens, CreatedAt, LastUseAt, UpdatedAt, // and ExpiresAt, so an expired share can be restored while we cannot create an already // expired share user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) share := dataprovider.Share{ ShareID: shortuuid.New(), Name: "share name", Description: "share description", Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, Username: user.Username, CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-144 * time.Hour)), UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-96 * time.Hour)), LastUseAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-64 * time.Hour)), ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-48 * time.Hour)), MaxTokens: 10, UsedTokens: 8, AllowFrom: []string{"127.0.0.0/8"}, } backupData := dataprovider.BackupData{ Version: dataprovider.DumpVersion, } backupData.Shares = append(backupData.Shares, share) backupContent, err := json.Marshal(backupData) assert.NoError(t, err) _, _, err = httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusOK) assert.NoError(t, err) shareGet, err := dataprovider.ShareExists(share.ShareID, user.Username) assert.NoError(t, err) assert.Equal(t, share, 
shareGet) share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-142 * time.Hour)) share.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-92 * time.Hour)) share.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-62 * time.Hour)) share.UsedTokens = 6 backupData.Shares = []dataprovider.Share{share} backupContent, err = json.Marshal(backupData) assert.NoError(t, err) _, _, err = httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusOK) assert.NoError(t, err) shareGet, err = dataprovider.ShareExists(share.ShareID, user.Username) assert.NoError(t, err) assert.Equal(t, share, shareGet) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestLoaddataFromPostBody(t *testing.T) { mappedPath := filepath.Join(os.TempDir(), "restored_folder") folderName := filepath.Base(mappedPath) role := getTestRole() role.ID = 1 role.Name = "test_restored_role" group := getTestGroup() group.ID = 1 group.Name = "test_group_restored" user := getTestUser() user.ID = 1 user.Username = "test_user_restored" user.Groups = []sdk.GroupMapping{ { Name: group.Name, Type: sdk.GroupTypePrimary, }, } user.Role = role.Name admin := getTestAdmin() admin.ID = 1 admin.Username = "test_admin_restored" admin.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers, dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers} admin.Role = role.Name backupData := dataprovider.BackupData{ Version: dataprovider.DumpVersion, } backupData.Users = append(backupData.Users, user) backupData.Groups = append(backupData.Groups, group) backupData.Admins = append(backupData.Admins, admin) backupData.Roles = append(backupData.Roles, role) backupData.Folders = []vfs.BaseVirtualFolder{ { Name: folderName, MappedPath: mappedPath, UsedQuotaSize: 123, UsedQuotaFiles: 456, LastQuotaUpdate: 789, Users: []string{"user"}, }, { Name: folderName, MappedPath: mappedPath + "1", }, } backupData.APIKeys = append(backupData.APIKeys, 
dataprovider.APIKey{}) backupData.Shares = append(backupData.Shares, dataprovider.Share{}) backupContent, err := json.Marshal(backupData) assert.NoError(t, err) _, _, err = httpdtest.LoaddataFromPostBody(nil, "0", "0", http.StatusBadRequest) assert.NoError(t, err) _, _, err = httpdtest.LoaddataFromPostBody(backupContent, "a", "0", http.StatusBadRequest) assert.NoError(t, err) _, _, err = httpdtest.LoaddataFromPostBody([]byte("invalid content"), "0", "0", http.StatusBadRequest) assert.NoError(t, err) _, _, err = httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusInternalServerError) assert.NoError(t, err) keyID := util.GenerateUniqueID() backupData.APIKeys = []dataprovider.APIKey{ { Name: "test key", Scope: dataprovider.APIKeyScopeAdmin, KeyID: keyID, Key: fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()), }, } backupData.Shares = []dataprovider.Share{ { ShareID: keyID, Name: keyID, Scope: dataprovider.ShareScopeWrite, Paths: []string{"/"}, Username: user.Username, }, } backupContent, err = json.Marshal(backupData) assert.NoError(t, err) _, resp, err := httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusOK) assert.NoError(t, err, string(resp)) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, role.Name, user.Role) if assert.Len(t, user.Groups, 1) { assert.Equal(t, sdk.GroupTypePrimary, user.Groups[0].Type) assert.Equal(t, group.Name, user.Groups[0].Name) } role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) assert.Len(t, role.Admins, 1) assert.Len(t, role.Users, 1) _, err = dataprovider.ShareExists(keyID, user.Username) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) admin, _, err = 
httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, admin.Role)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Admins, 0)
	assert.Len(t, role.Users, 0)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
	apiKey, _, err := httpdtest.GetAPIKeyByID(keyID, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath+"1", folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Len(t, folder.Users, 0)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoaddata writes a version 14 backup containing a user, a group, a role,
// an admin, two folder entries, an API key, a share, an event action and rule,
// an IP list entry and configs. It checks that invalid load requests are
// rejected, restores the backup twice, verifies the restored objects and the
// Dumpdata output, then checks that oversized and malformed files are rejected.
func TestLoaddata(t *testing.T) {
	err := dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "restored_folder")
	folderName := filepath.Base(mappedPath)
	folderDesc := "restored folder desc"
	user := getTestUser()
	user.ID = 1
	user.Username = "test_user_restore"
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: "/vuserpath",
	})
	group := getTestGroup()
	group.ID = 1
	group.Name = "test_group_restore"
	group.VirtualFolders = append(group.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: "/vgrouppath",
	})
	role := getTestRole()
	role.ID = 1
	role.Name = "test_role_restore"
	user.Groups = append(user.Groups, sdk.GroupMapping{
		Name: group.Name,
		Type: sdk.GroupTypePrimary,
	})
	admin := getTestAdmin()
	admin.ID = 1
	admin.Username = "test_admin_restore"
	admin.Groups = []dataprovider.AdminGroupMapping{
		{
			Name: group.Name,
		},
	}
	ipListEntry := dataprovider.IPListEntry{
		IPOrNet:     "172.16.2.4/32",
		Description: "entry desc",
		Type:        dataprovider.IPListTypeDefender,
		Mode:        dataprovider.ListModeDeny,
		Protocols:   3,
	}
	apiKey := dataprovider.APIKey{
		Name:  util.GenerateUniqueID(),
		Scope: dataprovider.APIKeyScopeAdmin,
		KeyID: util.GenerateUniqueID(),
		Key:   fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()),
	}
	share := dataprovider.Share{
		ShareID:  util.GenerateUniqueID(),
		Name:     util.GenerateUniqueID(),
		Scope:    dataprovider.ShareScopeRead,
		Paths:    []string{"/"},
		Username: user.Username,
	}
	// The HTTP action password is supplied as a plain secret; the assertions
	// after the restore expect it to come back with SecretBox status.
	action := dataprovider.BaseEventAction{
		ID:   81,
		Name: "test_restore_action",
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint:      "https://localhost:4567/action",
				Username:      defaultUsername,
				Password:      kms.NewPlainSecret(defaultPassword),
				Timeout:       10,
				SkipTLSVerify: true,
				Method:        http.MethodPost,
				Body:          `{"event":"{{.Event}}","name":"{{.Name}}"}`,
			},
		},
	}
	rule := dataprovider.EventRule{
		ID:          100,
		Name:        "test_rule_restore",
		Description: "",
		Trigger:     dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"download"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	configs := dataprovider.Configs{
		SFTPD: &dataprovider.SFTPDConfigs{
			HostKeyAlgos:   []string{ssh.KeyAlgoRSA, ssh.CertAlgoRSAv01},
			PublicKeyAlgos: []string{ssh.InsecureKeyAlgoDSA}, //nolint:staticcheck
		},
		SMTP: &dataprovider.SMTPConfigs{
			Host: "mail.example.com",
			Port: 587,
			From: "from@example.net",
		},
	}
	backupData := dataprovider.BackupData{
		Version: 14,
	}
	backupData.Configs = &configs
	backupData.Users = append(backupData.Users, user)
	backupData.Roles = append(backupData.Roles, role)
	backupData.Groups = append(backupData.Groups, group)
	backupData.Admins = append(backupData.Admins, admin)
	// Both folder entries use folderName. After the restore the test expects
	// MappedPath and Description from the second entry and the quota fields
	// from the first one.
	backupData.Folders = []vfs.BaseVirtualFolder{
		{
			Name:            folderName,
			MappedPath:      mappedPath + "1",
			UsedQuotaSize:   123,
			UsedQuotaFiles:  456,
			LastQuotaUpdate: 789,
			Users:           []string{"user"},
		},
		{
			MappedPath:  mappedPath,
			Name:        folderName,
			Description: folderDesc,
		},
	}
	backupData.APIKeys = append(backupData.APIKeys, apiKey)
	backupData.Shares = append(backupData.Shares, share)
	backupData.EventActions = append(backupData.EventActions, action)
	backupData.EventRules = append(backupData.EventRules, rule)
	backupData.IPLists = append(backupData.IPLists, ipListEntry)
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	backupFilePath := filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	// NOTE(review): the two string arguments after the path look like the
	// scan-quota and restore-mode query values; confirm against httpdtest.Loaddata.
	// Non-numeric values, a relative path and a missing file must all get 400.
	_, _, err = httpdtest.Loaddata(backupFilePath, "a", "", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "", "a", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata("backup.json", "1", "", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath+"a", "1", "", http.StatusBadRequest)
	assert.NoError(t, err)
	if runtime.GOOS != osWindows {
		// An execute-only file cannot be read; permissions are restored
		// afterwards so the loads below can succeed.
		err = os.Chmod(backupFilePath, 0111)
		assert.NoError(t, err)
		_, _, err = httpdtest.Loaddata(backupFilePath, "1", "", http.StatusBadRequest)
		assert.NoError(t, err)
		err = os.Chmod(backupFilePath, 0644)
		assert.NoError(t, err)
	}
	// add objects from backup
	_, resp, err := httpdtest.Loaddata(backupFilePath, "1", "", http.StatusOK)
	assert.NoError(t, err, string(resp))
	// update from backup
	_, _, err = httpdtest.Loaddata(backupFilePath, "2", "", http.StatusOK)
	assert.NoError(t, err)
	// Only ssh.KeyAlgoRSA is expected to survive in HostKeyAlgos even though
	// the backup also listed ssh.CertAlgoRSAv01.
	configsGet, err := dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Equal(t, configs.SMTP, configsGet.SMTP)
	assert.Equal(t, []string{ssh.KeyAlgoRSA}, configsGet.SFTPD.HostKeyAlgos)
	assert.Equal(t, []string{ssh.InsecureKeyAlgoDSA}, configsGet.SFTPD.PublicKeyAlgos) //nolint:staticcheck
	assert.Len(t, configsGet.SFTPD.KexAlgorithms, 0)
assert.Len(t, configsGet.SFTPD.Ciphers, 0)
	assert.Len(t, configsGet.SFTPD.MACs, 0)
	assert.Greater(t, configsGet.UpdatedAt, int64(0))
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, user.VirtualFolders, 1)
	assert.Len(t, user.Groups, 1)
	_, err = dataprovider.ShareExists(share.ShareID, user.Username)
	assert.NoError(t, err)
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group.VirtualFolders, 1)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, admin.Groups, 1)
	apiKey, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)
	action, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK)
	assert.NoError(t, err)
	entry, _, err := httpdtest.GetIPListEntry(ipListEntry.IPOrNet, ipListEntry.Type, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, entry.CreatedAt, int64(0))
	assert.Greater(t, entry.UpdatedAt, int64(0))
	assert.Equal(t, ipListEntry.Description, entry.Description)
	assert.Equal(t, ipListEntry.Protocols, entry.Protocols)
	assert.Equal(t, ipListEntry.Mode, entry.Mode)
	rule, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 1, rule.Status)
	// The plain-text password from the backup must be stored as a SecretBox
	// secret with a payload but no key or additional data.
	if assert.Len(t, rule.Actions, 1) {
		if assert.NotNil(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password) {
			assert.Equal(t, sdkkms.SecretStatusSecretBox, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetStatus())
			assert.NotEmpty(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetPayload())
			assert.Empty(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetKey())
			assert.Empty(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetAdditionalData())
		}
	}
	// Dump everything and check that the restored objects are present.
	response, _, err := httpdtest.Dumpdata("", "1", "0", http.StatusOK)
	assert.NoError(t, err)
	var dumpedData dataprovider.BackupData
	// Round-trip the generic response through JSON to get a typed BackupData.
	data, err := json.Marshal(response)
	assert.NoError(t, err)
	err = json.Unmarshal(data, &dumpedData)
	assert.NoError(t, err)
	found := false
	if assert.GreaterOrEqual(t, len(dumpedData.Users), 1) {
		for _, u := range dumpedData.Users {
			if u.Username == user.Username {
				found = true
				assert.Equal(t, len(user.VirtualFolders), len(u.VirtualFolders))
				assert.Equal(t, len(user.Groups), len(u.Groups))
			}
		}
	}
	assert.True(t, found)
	found = false
	if assert.GreaterOrEqual(t, len(dumpedData.Admins), 1) {
		for _, a := range dumpedData.Admins {
			if a.Username == admin.Username {
				found = true
				assert.Equal(t, len(admin.Groups), len(a.Groups))
			}
		}
	}
	assert.True(t, found)
	if assert.Len(t, dumpedData.Groups, 1) {
		assert.Equal(t, len(group.VirtualFolders), len(dumpedData.Groups[0].VirtualFolders))
	}
	if assert.Len(t, dumpedData.EventActions, 1) {
		assert.Equal(t, action.Name, dumpedData.EventActions[0].Name)
	}
	if assert.Len(t, dumpedData.EventRules, 1) {
		assert.Equal(t, rule.Name, dumpedData.EventRules[0].Name)
		assert.Len(t, dumpedData.EventRules[0].Actions, 1)
	}
	found = false
	for _, r := range dumpedData.Roles {
		if r.Name == role.Name {
			found = true
		}
	}
	assert.True(t, found)
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath, folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Equal(t, folderDesc, folder.Description)
	assert.Len(t, folder.Users, 1)
	// A dump restricted to the users scope must omit every other object type.
	response, _, err = httpdtest.Dumpdata("", "1", "0", http.StatusOK, dataprovider.DumpScopeUsers)
	assert.NoError(t, err)
	dumpedData = dataprovider.BackupData{}
	data, err = json.Marshal(response)
	assert.NoError(t, err)
	err = json.Unmarshal(data, &dumpedData)
	assert.NoError(t, err)
	assert.Greater(t, len(dumpedData.Users), 0)
	assert.Len(t, dumpedData.Admins, 0)
	assert.Len(t, dumpedData.Folders, 0)
	assert.Len(t, dumpedData.Groups, 0)
	assert.Len(t, dumpedData.Roles, 0)
	assert.Len(t, dumpedData.EventRules, 0)
	assert.Len(t, dumpedData.EventActions, 0)
	assert.Len(t, dumpedData.IPLists, 0)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	// A file of 20 MiB + 1 byte and a 65535-byte file must both be rejected.
	// NOTE(review): assumes createTestFile writes non-JSON content; it is
	// defined elsewhere in the package.
	err = createTestFile(backupFilePath, 20*1048576+1)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "1", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	err = createTestFile(backupFilePath, 65535)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "1", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	// presumably resets the stored configs for the tests that follow
	err = dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
}

// TestLoaddataConvertActions checks event action placeholder conversion.
// Restoring a version 16 backup rewrites placeholders written without the
// leading dot (e.g. {{Event}}) into template syntax (e.g. {{.Event}}).
// Restoring the same actions from a backup tagged with
// dataprovider.DumpVersion leaves them unchanged.
func TestLoaddataConvertActions(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name: xid.New().String(),
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"failure@example.com"},
				Subject:    `Failed "{{Event}}" from "{{Name}}"`,
				Body:       "Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}",
			},
		},
	}
	a2 := dataprovider.BaseEventAction{
		Name: xid.New().String(),
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionRename,
				Renames: []dataprovider.RenameConfig{
					{
						KeyValue: dataprovider.KeyValue{
							Key:   "/{{VirtualDirPath}}/{{ObjectName}}",
							Value: "/{{ObjectName}}_renamed",
						},
					},
				},
			},
		},
	}
	backupData := dataprovider.BackupData{
		EventActions: []dataprovider.BaseEventAction{a1, a2},
		Version:      16,
	}
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	backupFilePath := filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	_, resp, err := httpdtest.Loaddata(backupFilePath, "1", "2", http.StatusOK)
	assert.NoError(t, err, string(resp))
	// Check that actions are migrated.
	action1, _, err := httpdtest.GetEventActionByName(a1.Name, http.StatusOK)
	assert.NoError(t, err)
	action2, _, err := httpdtest.GetEventActionByName(a2.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, `Failed "{{.Event}}" from "{{.Name}}"`, action1.Options.EmailConfig.Subject)
	assert.Equal(t, `Object name: {{.ObjectName}} object type: {{.ObjectType}}, IP: {{.IP}}`, action1.Options.EmailConfig.Body)
	assert.Equal(t, `/{{.VirtualDirPath}}/{{.ObjectName}}`, action2.Options.FsConfig.Renames[0].Key)
	assert.Equal(t, `/{{.ObjectName}}_renamed`, action2.Options.FsConfig.Renames[0].Value)
	// If we restore a backup from the current version actions are not migrated.
backupData = dataprovider.BackupData{
		EventActions: []dataprovider.BaseEventAction{a1, a2},
		Version:      dataprovider.DumpVersion,
	}
	backupContent, err = json.Marshal(backupData)
	assert.NoError(t, err)
	// backupFilePath still points to the backup file written above; overwrite
	// it in place with the current-version content.
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	_, resp, err = httpdtest.Loaddata(backupFilePath, "1", "2", http.StatusOK)
	assert.NoError(t, err, string(resp))
	// The placeholders must be stored exactly as written in the backup.
	action1, _, err = httpdtest.GetEventActionByName(a1.Name, http.StatusOK)
	assert.NoError(t, err)
	action2, _, err = httpdtest.GetEventActionByName(a2.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, `Failed "{{Event}}" from "{{Name}}"`, action1.Options.EmailConfig.Subject)
	assert.Equal(t, `Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}`, action1.Options.EmailConfig.Body)
	assert.Equal(t, `/{{VirtualDirPath}}/{{ObjectName}}`, action2.Options.FsConfig.Renames[0].Key)
	assert.Equal(t, `/{{ObjectName}}_renamed`, action2.Options.FsConfig.Renames[0].Value)
	// Cleanup.
_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	actions, _, err := httpdtest.GetEventActions(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, actions, 0)
	// Remove the backup file, as the other Loaddata tests do.
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
}

// TestLoaddataMode restores the same backup file with the mode argument set to
// "0", "1" and "2" and checks how objects modified after the first restore are
// handled:
//   - after the "1" load the local changes (descriptions, bandwidth, configs)
//     are still present;
//   - after the "2" load the user's upload bandwidth and the SFTPD configs
//     match the backup again, and the open connection is closed.
func TestLoaddataMode(t *testing.T) {
	err := dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "restored_fold")
	folderName := filepath.Base(mappedPath)
	configs := dataprovider.Configs{
		SFTPD: &dataprovider.SFTPDConfigs{
			PublicKeyAlgos: []string{ssh.KeyAlgoRSA},
		},
	}
	role := getTestRole()
	role.ID = 1
	role.Name = "test_role_load"
	role.Description = ""
	user := getTestUser()
	user.ID = 1
	user.Username = "test_user_restore"
	user.Role = role.Name
	group := getTestGroup()
	group.ID = 1
	group.Name = "test_group_restore"
	user.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	admin := getTestAdmin()
	admin.ID = 1
	admin.Username = "test_admin_restore"
	apiKey := dataprovider.APIKey{
		Name:        util.GenerateUniqueID(),
		Scope:       dataprovider.APIKeyScopeAdmin,
		KeyID:       util.GenerateUniqueID(),
		Key:         fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()),
		Description: "desc",
	}
	share := dataprovider.Share{
		ShareID:  util.GenerateUniqueID(),
		Name:     util.GenerateUniqueID(),
		Scope:    dataprovider.ShareScopeRead,
		Paths:    []string{"/"},
		Username: user.Username,
	}
	action := dataprovider.BaseEventAction{
		ID:          81,
		Name:        "test_restore_action_data_mode",
		Description: "action desc",
		Type:        dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint:      "https://localhost:4567/mode",
				Username:      defaultUsername,
				Password:      kms.NewPlainSecret(defaultPassword),
				Timeout:       10,
				SkipTLSVerify: true,
				Method:        http.MethodPost,
				Body:          `{"event":"{{.Event}}","name":"{{.Name}}"}`,
			},
		},
	}
	rule := dataprovider.EventRule{
		ID:          100,
		Name:        "test_rule_restore_data_mode",
		Description: "rule desc",
		Trigger:     dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"mkdir"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	ipListEntry := dataprovider.IPListEntry{
		IPOrNet:     "10.8.3.9/32",
		Description: "note",
		Type:        dataprovider.IPListTypeDefender,
		Mode:        dataprovider.ListModeDeny,
		Protocols:   7,
	}
	backupData := dataprovider.BackupData{
		Version: dataprovider.DumpVersion,
	}
	backupData.Configs = &configs
	backupData.Users = append(backupData.Users, user)
	backupData.Groups = append(backupData.Groups, group)
	backupData.Admins = append(backupData.Admins, admin)
	backupData.EventActions = append(backupData.EventActions, action)
	backupData.EventRules = append(backupData.EventRules, rule)
	backupData.Roles = append(backupData.Roles, role)
	backupData.Folders = []vfs.BaseVirtualFolder{
		{
			Name:            folderName,
			MappedPath:      mappedPath,
			UsedQuotaSize:   123,
			UsedQuotaFiles:  456,
			LastQuotaUpdate: 789,
			Users:           []string{"user"},
		},
		{
			MappedPath: mappedPath + "1",
			Name:       folderName,
		},
	}
	backupData.APIKeys = append(backupData.APIKeys, apiKey)
	backupData.Shares = append(backupData.Shares, share)
	backupData.IPLists = append(backupData.IPLists, ipListEntry)
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	backupFilePath := filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "0", "0", http.StatusOK)
	assert.NoError(t, err)
	configs, err = dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1)
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath+"1", folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Len(t, folder.Users, 0)
	// Modify every restored object locally so the following loads can show
	// whether those changes are kept or overwritten.
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, user.Role)
	oldUploadBandwidth := user.UploadBandwidth
	user.UploadBandwidth = oldUploadBandwidth + 128
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Users, 1)
	assert.Len(t, role.Admins, 0)
	assert.Empty(t, role.Description)
	role.Description = "role desc"
	_, _, err = httpdtest.UpdateRole(role, http.StatusOK)
	assert.NoError(t, err)
	role.Description = ""
	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group.Users, 1)
	oldGroupDesc := group.Description
	group.Description = "new group description"
	group, _, err = httpdtest.UpdateGroup(group, http.StatusOK)
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	oldInfo := admin.AdditionalInfo
	oldDesc := admin.Description
	admin.AdditionalInfo = "newInfo"
	admin.Description = "newDesc"
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	apiKey, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)
	oldAPIKeyDesc := apiKey.Description
	apiKey.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now())
	apiKey.Description = "new desc"
	apiKey, _, err = httpdtest.UpdateAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)
	share.Description = "test desc"
	err = dataprovider.UpdateShare(&share, "", "", "")
	assert.NoError(t, err)
	action, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK)
	assert.NoError(t, err)
	oldActionDesc := action.Description
	action.Description = "new action description"
	action, _, err = httpdtest.UpdateEventAction(action, http.StatusOK)
	assert.NoError(t, err)
	rule, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 0, rule.Status)
	oldRuleDesc := rule.Description
	rule.Description = "new rule description"
	rule, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	entry, _, err := httpdtest.GetIPListEntry(ipListEntry.IPOrNet, ipListEntry.Type, http.StatusOK)
	assert.NoError(t, err)
	oldEntryDesc := entry.Description
	entry.Description = "new note"
	entry, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)
	configs.SFTPD.PublicKeyAlgos = append(configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck
	err = dataprovider.UpdateConfigs(&configs, "", "", "")
	assert.NoError(t, err)
	// The backup file on disk is deliberately left unchanged: it still holds a
	// single public key algo, which the mode "2" assertion below relies on.
	_, _, err = httpdtest.Loaddata(backupFilePath, "0", "1", http.StatusOK)
	assert.NoError(t, err)
	configs, err = dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Len(t, configs.SFTPD.PublicKeyAlgos, 2)
	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldGroupDesc, group.Description)
	folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath+"1", folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Len(t, folder.Users, 0)
	action, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldActionDesc, action.Description)
	rule, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldRuleDesc, rule.Description)
	entry, _, err = httpdtest.GetIPListEntry(ipListEntry.IPOrNet, ipListEntry.Type, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldEntryDesc, entry.Description)
	// Register a fake connection for the user so the mode "2" load can be
	// checked to close it.
	c := common.NewBaseConnection("connID", common.ProtocolFTP, "", "", user)
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err = common.Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 1)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldUploadBandwidth, user.UploadBandwidth)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldInfo, admin.AdditionalInfo)
	assert.NotEqual(t, oldDesc, admin.Description)
	apiKey, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, int64(0), apiKey.ExpiresAt)
	assert.NotEqual(t, oldAPIKeyDesc, apiKey.Description)
	share, err = dataprovider.ShareExists(share.ShareID, user.Username)
	assert.NoError(t, err)
	assert.NotEmpty(t, share.Description)
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Users, 1)
	assert.Len(t, role.Admins, 0)
	assert.NotEmpty(t, role.Description)
	_, _, err = httpdtest.Loaddata(backupFilePath, "0", "2", http.StatusOK)
	assert.NoError(t, err)
	// mode 2 will update the user and close the previous connection
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, oldUploadBandwidth, user.UploadBandwidth)
	configs, err = dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1)
	// the group is referenced
	_, err = httpdtest.RemoveGroup(group, http.StatusBadRequest)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	err = dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
}

// TestRateLimiter enables an HTTP rate limiter (Average 1, Period 1000,
// Burst 1). The first healthz request must succeed. The following requests to
// healthz and to both login pages must get 429 with "Retry-After: 1" and a
// non-empty X-Retry-In header. The previous common config is restored when the
// test returns, including when a require check aborts it.
func TestRateLimiter(t *testing.T) {
	oldConfig := config.GetCommonConfig()
	cfg := config.GetCommonConfig()
	cfg.RateLimitersConfig = []common.RateLimiterConfig{
		{
			Average:   1,
			Period:    1000,
			Burst:     1,
			Type:      1,
			Protocols: []string{common.ProtocolHTTP},
		},
	}
	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)
	// Deferred so the limiter configured above is not left active for later
	// tests if this one fails early (require.* stops via runtime.Goexit, which
	// still runs deferred calls).
	defer func() {
		assert.NoError(t, common.Initialize(oldConfig, 0))
	}()

	client := &http.Client{
		Timeout: 5 * time.Second,
	}
	resp, err := client.Get(httpBaseURL + healthzPath)
	// require: resp is dereferenced right below and is nil on error.
	require.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	for _, limitedPath := range []string{healthzPath, webLoginPath, webClientLoginPath} {
		resp, err = client.Get(httpBaseURL + limitedPath)
		require.NoError(t, err)
		assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
		assert.Equal(t, "1", resp.Header.Get("Retry-After"))
		assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
		err = resp.Body.Close()
		assert.NoError(t, err)
	}
}

// TestHTTPSConnection expects a TLS certificate verification error when the
// default client connects to the HTTPS listener on localhost:8443.
func TestHTTPSConnection(t *testing.T) {
	client :=
&http.Client{ Timeout: 5 * time.Second, } resp, err := client.Get("https://localhost:8443" + healthzPath) if assert.Error(t, err) { if !strings.Contains(err.Error(), "certificate is not valid") && !strings.Contains(err.Error(), "certificate signed by unknown authority") && !strings.Contains(err.Error(), "certificate is not standards compliant") { assert.Fail(t, err.Error()) } } else { resp.Body.Close() } } // test using mock http server func TestBasicUserHandlingMock(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusConflict, rr) user.MaxSessions = 10 user.UploadBandwidth = 128 user.Permissions["/"] = []string{dataprovider.PermAny, dataprovider.PermDelete, dataprovider.PermDownload} userAsJSON = getUserAsJSON(t, user) req, _ = http.NewRequest(http.MethodPut, userPath+"/"+user.Username, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, userPath+"/"+user.Username, nil) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var updatedUser dataprovider.User err = render.DecodeJSON(rr.Body, &updatedUser) assert.NoError(t, err) assert.Equal(t, user.MaxSessions, updatedUser.MaxSessions) assert.Equal(t, user.UploadBandwidth, updatedUser.UploadBandwidth) assert.Equal(t, 1, len(updatedUser.Permissions["/"])) assert.True(t, 
slices.Contains(updatedUser.Permissions["/"], dataprovider.PermAny)) req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+user.Username, nil) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) } func TestAddUserNoUsernameMock(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) user := getTestUser() user.Username = "" userAsJSON := getUserAsJSON(t, user) req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) } func TestAddUserInvalidHomeDirMock(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) user := getTestUser() user.HomeDir = "relative_path" userAsJSON := getUserAsJSON(t, user) req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) } func TestAddUserInvalidPermsMock(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) user := getTestUser() user.Permissions["/"] = []string{} userAsJSON := getUserAsJSON(t, user) req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) } func TestAddFolderInvalidJsonMock(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer([]byte("invalid json"))) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) } func TestAddEventRuleInvalidJsonMock(t *testing.T) { token, err := 
getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, eventActionsPath, bytes.NewBuffer([]byte("invalid json")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPost, eventRulesPath, bytes.NewBuffer([]byte("invalid json")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
}

// TestAddRoleInvalidJsonMock checks that posting a malformed JSON body to the
// roles endpoint is rejected with 400 Bad Request.
func TestAddRoleInvalidJsonMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, rolesPath, bytes.NewBuffer([]byte("{")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
}

// TestIPListEntriesErrorsMock covers error paths of the IP list API: invalid
// list types in the URL and malformed JSON bodies for both add and update.
func TestIPListEntriesErrorsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, ipListsPath+"/a/b", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "invalid list type")
	req, err = http.NewRequest(http.MethodGet, ipListsPath+"/invalid", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "invalid list type")
	// A bytes.Buffer is drained by the first request that reads it, so every
	// request gets its own reader: otherwise later requests would send an empty
	// body instead of the malformed JSON this test means to exercise.
	invalidBody := []byte("{")
	req, err = http.NewRequest(http.MethodPost, ipListsPath+"/2", bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	entry := dataprovider.IPListEntry{
		IPOrNet:   "172.120.1.1/32",
		Type:      dataprovider.IPListTypeAllowList,
		Mode:      dataprovider.ListModeAllow,
		Protocols: 0,
	}
	_, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(ipListsPath, "1", url.PathEscape(entry.IPOrNet)),
		bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	_, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)
}

// TestRoleErrorsMock covers role API error paths: an invalid pagination limit,
// a malformed JSON update, an update of a missing role and a second removal of
// an already removed role.
func TestRoleErrorsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	// Each request below wraps this slice in a fresh reader so none of them
	// sees a body already drained by a previous request.
	invalidBody := []byte("{")
	req, err := http.NewRequest(http.MethodGet, rolesPath+"?limit=a", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	role, _, err := httpdtest.AddRole(getTestRole(), http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(rolesPath, role.Name), bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPut, path.Join(rolesPath, "missing_role"), bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusNotFound)
	assert.NoError(t, err)
}

// TestEventRuleErrorsMock covers event action/rule API error paths: invalid
// pagination limits, malformed JSON updates, and updating a rule so that it
// references a missing action (expected to fail with 500).
func TestEventRuleErrorsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	// Shared payload; every request wraps it in a fresh reader so the body is
	// never empty because a previous request already consumed it.
	invalidBody := []byte("invalid json body")
	req, err := http.NewRequest(http.MethodGet, eventActionsPath+"?limit=b", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, eventRulesPath+"?limit=c", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	a := dataprovider.BaseEventAction{
		Name:        "action_name",
		Description: "test description",
		Type:        dataprovider.ActionTypeBackup,
		Options:     dataprovider.BaseEventActionOptions{},
	}
	action, _, err := httpdtest.AddEventAction(a, http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(eventActionsPath, action.Name), bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	r := dataprovider.EventRule{
		Name:    "test_event_rule",
		Trigger: dataprovider.EventTriggerSchedule,
		Conditions: dataprovider.EventConditions{
			Schedules: []dataprovider.Schedule{
				{
					Hours:      "2",
					DayOfWeek:  "*",
					DayOfMonth: "*",
					Month:      "*",
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(eventRulesPath, rule.Name), bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// point the rule at an action that does not exist
	rule.Actions[0].Name = "misssing action name"
	asJSON, err := json.Marshal(rule)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(eventRulesPath, rule.Name), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
}

// TestGroupErrorsMock covers group API error paths: malformed JSON on add and
// update, and an invalid pagination limit.
func TestGroupErrorsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	// Wrapped in a fresh reader per request: a shared bytes.Buffer would be
	// empty for the PUT below after the POST consumed it.
	invalidBody := []byte("not a json string")
	req, err := http.NewRequest(http.MethodPost, groupPath, bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, groupPath+"?limit=d", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	group, _, err := httpdtest.AddGroup(getTestGroup(), http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(groupPath, group.Name), bytes.NewReader(invalidBody))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateFolderInvalidJsonMock checks that updating an existing virtual
// folder with a malformed JSON body returns 400 Bad Request.
func TestUpdateFolderInvalidJsonMock(t *testing.T) {
	folder := vfs.BaseVirtualFolder{
		Name:       "name",
		MappedPath: filepath.Clean(os.TempDir()),
	}
	folder, resp, err := httpdtest.AddFolder(folder, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPut, path.Join(folderPath, folder.Name), bytes.NewBuffer([]byte("not a json")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestAddUserInvalidJsonMock checks that adding a user with a malformed JSON
// body returns 400 Bad Request.
func TestAddUserInvalidJsonMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer([]byte("invalid json")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
}

// TestAddAdminInvalidJsonMock checks that adding an admin with a malformed JSON
// body returns 400 Bad Request.
func TestAddAdminInvalidJsonMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer([]byte("...")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
}

// TestAddAdminNoPasswordMock checks that adding an admin with an empty password
// is rejected with 400 and a "please set a password" message.
func TestAddAdminNoPasswordMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Password = ""
	asJSON, err := json.Marshal(admin)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "please set a password")
}

// TestAdminTwoFactorLogin exercises the web admin two-factor login flow: TOTP
// passcodes and recovery codes, CSRF checks, rejection of reused cookies and
// recovery codes, disabling 2FA through the REST API, and the web pages'
// behavior once 2FA is disabled and the admin is removed.
func TestAdminTwoFactorLogin(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	// admin1 gets every valid permission except PermAdminAny and
	// PermAdminDisableMFA; it is used below to check that disabling another
	// admin's 2FA is forbidden without them.
	admin1 := getTestAdmin()
	admin1.Username = altAdminUsername + "1"
	admin1.Password = altAdminPassword
	var permissions []string
	for _, p := range admin1.GetValidPerms() {
		if p != dataprovider.PermAdminAny && p != dataprovider.PermAdminDisableMFA {
			permissions = append(permissions, p)
		}
	}
	admin1.Permissions = permissions
	admin1, _, err = httpdtest.AddAdmin(admin1, http.StatusCreated)
	assert.NoError(t, err)
	// enable two factor authentication
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], admin.Username)
	assert.NoError(t, err)
	altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	asJSON, err := json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var recCodes []recoveryCode
	err = json.Unmarshal(rr.Body.Bytes(), &recCodes)
	assert.NoError(t, err)
	assert.Len(t, recCodes, 12)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, admin.Filters.RecoveryCodes, 12)
	for _, c := range admin.Filters.RecoveryCodes {
		assert.Empty(t, c.Secret.GetAdditionalData())
		assert.Empty(t, c.Secret.GetKey())
		assert.Equal(t, sdkkms.SecretStatusSecretBox, c.Secret.GetStatus())
		assert.NotEmpty(t, c.Secret.GetPayload())
	}
	// a fully authenticated web token cannot access the 2FA pages
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorRecoveryPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form := getLoginForm(altAdminUsername, altAdminPassword, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location"))
	cookie, err := getCookieFromResponse(rr)
	assert.NoError(t, err)
	// without a cookie
	req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorRecoveryPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorRecoveryPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// any other page will be redirected to the two factor auth page
	req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location"))
	// a partial token cannot be used for user pages
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("passcode", passcode)
	// no csrf
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	form.Set("passcode", "invalid_passcode")
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	form.Set("passcode", "")
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	form.Set("passcode", passcode)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	// the same cookie cannot be reused
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusNotFound, rr.Code)
	// get a new cookie and login using a recovery code
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location"))
	cookie, err = getCookieFromResponse(rr)
	assert.NoError(t, err)
	form = make(url.Values)
	recoveryCode := recCodes[0].Code
	form.Set("recovery_code", recoveryCode)
	// no csrf
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	form.Set("recovery_code", "")
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	form.Set("recovery_code", recoveryCode)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	authenticatedCookie, err := getCookieFromResponse(rr)
	assert.NoError(t, err)
	// render MFA page
	req, err = http.NewRequest(http.MethodGet, webAdminMFAPath, nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, authenticatedCookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check that the recovery code was marked as used
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	recCodes = nil
	err = json.Unmarshal(rr.Body.Bytes(), &recCodes)
	assert.NoError(t, err)
	assert.Len(t, recCodes, 12)
	found := false
	for _, rc := range recCodes {
		if rc.Code == recoveryCode {
			found = true
			assert.True(t, rc.Used)
		} else {
			assert.False(t, rc.Used)
		}
	}
	assert.True(t, found)
	// the same recovery code cannot be reused
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location"))
	cookie, err = getCookieFromResponse(rr)
	assert.NoError(t, err)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("recovery_code", recoveryCode)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	form.Set("recovery_code", "invalid_recovery_code")
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	// the login form is rejected without the login cookie, accepted with it
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location"))
	cookie, err = getCookieFromResponse(rr)
	assert.NoError(t, err)
	// disable TOTP
	altToken1, err := getJWTAPITokenFromTestServer(admin1.Username, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken1)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "two-factor authentication is not enabled")
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("recovery_code", recoveryCode)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18n2FADisabled)
	form.Set("passcode", passcode)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18n2FADisabled)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin1, http.StatusOK)
	assert.NoError(t, err)
	// once the admin is removed, the 2FA forms report invalid credentials
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	// the previously authenticated cookie can no longer render the MFA page
	req, err = http.NewRequest(http.MethodGet, webAdminMFAPath, nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, authenticatedCookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
}

// TestAdminTOTP exercises the admin TOTP REST API: TOTPConfig is ignored on
// add, secrets are generated, validated (passcodes cannot be reused) and saved,
// recovery codes can be read/regenerated only with the admin's own token, 2FA
// cannot be disabled via the update admin API but only via the dedicated
// endpoint, and every endpoint returns 404 once the admin is deleted.
func TestAdminTOTP(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	// TOTPConfig will be ignored on add
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: "config",
		Secret:     kms.NewEmptySecret(),
	}
	asJSON, err := json.Marshal(admin)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 0)
	altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, adminTOTPConfigsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var configs []mfa.TOTPConfig
	err = json.Unmarshal(rr.Body.Bytes(), &configs)
	assert.NoError(t, err, rr.Body.String())
	assert.Len(t, configs, len(mfa.GetAvailableTOTPConfigs()))
	totpConfig := configs[0]
	totpReq := generateTOTPRequest{
		ConfigName: totpConfig.Name,
	}
	asJSON, err = json.Marshal(totpReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPGeneratePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var totpGenResp generateTOTPResponse
	err = json.Unmarshal(rr.Body.Bytes(), &totpGenResp)
	assert.NoError(t, err)
	assert.NotEmpty(t, totpGenResp.Secret)
	assert.NotEmpty(t, totpGenResp.QRCode)
	passcode, err := generateTOTPPasscode(totpGenResp.Secret)
	assert.NoError(t, err)
	validateReq := validateTOTPRequest{
		ConfigName: totpGenResp.ConfigName,
		Passcode:   passcode,
		Secret:     totpGenResp.Secret,
	}
	asJSON, err = json.Marshal(validateReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPValidatePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the same passcode cannot be reused
	req, err = http.NewRequest(http.MethodPost, adminTOTPValidatePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "this passcode was already used")
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: totpGenResp.ConfigName,
		Secret:     kms.NewPlainSecret(totpGenResp.Secret),
	}
	asJSON, err = json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Equal(t, totpGenResp.ConfigName, admin.Filters.TOTPConfig.ConfigName)
	assert.Empty(t, admin.Filters.TOTPConfig.Secret.GetKey())
	assert.Empty(t, admin.Filters.TOTPConfig.Secret.GetAdditionalData())
	assert.NotEmpty(t, admin.Filters.TOTPConfig.Secret.GetPayload())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, admin.Filters.TOTPConfig.Secret.GetStatus())
	// the update admin API must preserve the existing TOTP config and codes
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    false,
		ConfigName: util.GenerateUniqueID(),
		Secret:     kms.NewEmptySecret(),
	}
	admin.Filters.RecoveryCodes = []dataprovider.RecoveryCode{
		{
			Secret: kms.NewEmptySecret(),
		},
	}
	admin, resp, err := httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err, string(resp))
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 12)
	// if we use token we cannot get or generate recovery codes
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// now the same but with altToken
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var recCodes []recoveryCode
	err = json.Unmarshal(rr.Body.Bytes(), &recCodes)
	assert.NoError(t, err)
	assert.Len(t, recCodes, 12)
	// regenerate recovery codes
	req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check that recovery codes are different
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var newRecCodes []recoveryCode
	err = json.Unmarshal(rr.Body.Bytes(), &newRecCodes)
	assert.NoError(t, err)
	assert.Len(t, newRecCodes, 12)
	assert.NotEqual(t, recCodes, newRecCodes)
	// disable 2FA, the update admin API should not work
	admin.Filters.TOTPConfig.Enabled = false
	admin.Filters.RecoveryCodes = nil
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, altAdminUsername, admin.Username)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 12)
	// use the dedicated API
	req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 0)
	req, err = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil)
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the admin no longer exists: every 2FA endpoint returns 404
	req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestChangeAdminPwdInvalidJsonMock checks that the change admin password
// endpoint rejects a malformed JSON body with 400 Bad Request.
func TestChangeAdminPwdInvalidJsonMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer([]byte("{")))
	require.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
}

// TestSMTPConfig exercises the web admin SMTP test endpoint: CSRF header
// enforcement, malformed JSON, a missing recipient, redacted passwords resolved
// from the stored configs, and OAuth2 validation. The SMTP configuration and
// stored configs are reset at the end.
func TestSMTPConfig(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	smtpTestURL := path.Join(webConfigsPath, "smtp", "test")
	tokenHeader := "X-CSRF-TOKEN"
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer([]byte("{")))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	// without the CSRF header the request is forbidden
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	testReq := make(map[string]any)
	testReq["host"] = smtpCfg.Host
	testReq["port"] = 3525
	testReq["from"] = "from@example.com"
	asJSON, err := json.Marshal(testReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	testReq["recipient"] = "example@example.com"
	asJSON, err = json.Marshal(testReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	configs := dataprovider.Configs{
		SMTP: &dataprovider.SMTPConfigs{
			Host:     "127.0.0.1",
			Port:     3535,
			User:     "user@example.com",
			Password: kms.NewPlainSecret(defaultPassword),
		},
	}
	err = dataprovider.UpdateConfigs(&configs, "", "", "")
	assert.NoError(t, err)
	testReq["password"] = redactedSecret
	asJSON, err = json.Marshal(testReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	assert.Contains(t, rr.Body.String(), "server does not support SMTP AUTH")
	testReq["password"] = ""
	testReq["auth_type"] = 3
	testReq["oauth2"] = smtp.OAuth2Config{
		ClientSecret: redactedSecret,
		RefreshToken: redactedSecret,
	}
	asJSON, err = json.Marshal(testReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "smtp oauth2: client id is required")
	err = dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestOAuth2TokenRequest exercises the web admin OAuth2 token endpoint: CSRF
// header enforcement, malformed JSON, a missing base redirect URL and a valid
// request.
func TestOAuth2TokenRequest(t *testing.T) {
	tokenHeader := "X-CSRF-TOKEN"
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer([]byte("{")))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	testReq := make(map[string]any)
	testReq["client_secret"] = redactedSecret
	asJSON, err := json.Marshal(testReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "base redirect url is required")
	testReq["base_redirect_url"] = "http://localhost:8081"
	asJSON, err = json.Marshal(testReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set(tokenHeader, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestMFAPermission checks that the web client MFA page is reachable by default
// and forbidden once the user's web client filters include
// sdk.WebClientMFADisabled.
func TestMFAPermission(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webClientMFAPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientMFAPath
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	user.Filters.WebClient = []string{sdk.WebClientMFADisabled}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// a new token is required so that it reflects the updated filters
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientMFAPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientMFAPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestWebUserTwoFactorLogin(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// enable two factor authentication
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	adminToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var recCodes []recoveryCode
	err = json.Unmarshal(rr.Body.Bytes(), &recCodes)
	assert.NoError(t, err)
	assert.Len(t, recCodes, 12)
	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, user.Filters.RecoveryCodes, 12)
	for _, c := range user.Filters.RecoveryCodes {
		assert.Empty(t, c.Secret.GetAdditionalData())
		assert.Empty(t, c.Secret.GetKey())
		assert.Equal(t, sdkkms.SecretStatusSecretBox, c.Secret.GetStatus())
		assert.NotEmpty(t, c.Secret.GetPayload())
	}
	req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, webClientTwoFactorRecoveryPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
	// CSRF verification fails if there is no cookie
	req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location"))
	cookie, err := getCookieFromResponse(rr)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// without a cookie
	req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil)
	assert.NoError(t, err)
rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) // invalid IP address req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = "6.7.8.9:4567" rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, webClientTwoFactorRecoveryPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, cookie) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // any other page will be redirected to the two factor auth page req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, cookie) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) // a partial token cannot be used for admin pages req, err = http.NewRequest(http.MethodGet, webUsersPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, cookie) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webLoginPath, rr.Header().Get("Location")) passcode, err := generateTOTPPasscode(key.Secret()) assert.NoError(t, err) form = make(url.Values) form.Set("passcode", passcode) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorPath, cookie) assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "invalid_user_passcode") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) 
setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) form.Set("passcode", "") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) form.Set("passcode", passcode) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) // the same cookie cannot be reused req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusNotFound, rr.Code) // get a new cookie and login using a recovery code loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = 
executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) form = make(url.Values) recoveryCode := recCodes[0].Code form.Set("recovery_code", recoveryCode) // no csrf req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) form.Set("recovery_code", recoveryCode) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) authenticatedCookie, err := getCookieFromResponse(rr) assert.NoError(t, err) //render MFA page req, err = http.NewRequest(http.MethodGet, webClientMFAPath, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, authenticatedCookie) rr = executeRequest(req) 
checkResponseCode(t, http.StatusOK, rr) // get MFA qrcode req, err = http.NewRequest(http.MethodGet, path.Join(webClientMFAPath, "qrcode?url="+url.QueryEscape(key.URL())), nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, authenticatedCookie) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, "image/png", rr.Header().Get("Content-Type")) // invalid MFA url req, err = http.NewRequest(http.MethodGet, path.Join(webClientMFAPath, "qrcode?url="+url.QueryEscape("http://foo\x7f.eu")), nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, authenticatedCookie) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) // check that the recovery code was marked as used req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) recCodes = nil err = json.Unmarshal(rr.Body.Bytes(), &recCodes) assert.NoError(t, err) assert.Len(t, recCodes, 12) found := false for _, rc := range recCodes { if rc.Code == recoveryCode { found = true assert.True(t, rc.Used) } else { assert.False(t, rc.Used) } } assert.True(t, found) // the same recovery code cannot be reused loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) csrfToken, err = 
getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recoveryCode) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) form.Set("recovery_code", "invalid_user_recovery_code") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) // disable TOTP req, err = http.NewRequest(http.MethodPut, userPath+"/"+user.Username+"/2fa/disable", nil) assert.NoError(t, err) setBearerForReq(req, adminToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodPut, 
userPath+"/"+user.Username+"/2fa/disable", nil) assert.NoError(t, err) setBearerForReq(req, adminToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "two-factor authentication is not enabled") csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recoveryCode) form.Set("passcode", passcode) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18n2FADisabled) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18n2FADisabled) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) 
req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) req, err = http.NewRequest(http.MethodGet, webClientMFAPath, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, authenticatedCookie) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) } func TestWebUserTwoFactoryLoginRedirect(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) userTOTPConfig := dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), Protocols: []string{common.ProtocolHTTP}, } asJSON, err := json.Marshal(userTOTPConfig) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(defaultUsername, defaultPassword, csrfToken) uri := webClientFilesPath + "?path=%2F" loginURI := webClientLoginPath + "?next=" + url.QueryEscape(uri) expectedURI := webClientTwoFactorPath + "?next=" + url.QueryEscape(uri) req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = loginURI setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, 
rr.Code) assert.Equal(t, expectedURI, rr.Header().Get("Location")) cookie, err := getCookieFromResponse(rr) assert.NoError(t, err) // test unsafe redirects loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) externalURI := webClientLoginPath + "?next=" + url.QueryEscape("https://example.com") req, err = http.NewRequest(http.MethodPost, externalURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = externalURI setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) internalURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientMFAPath) req, err = http.NewRequest(http.MethodPost, internalURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = internalURI setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) // render two factor page req, err = http.NewRequest(http.MethodGet, expectedURI, nil) assert.NoError(t, err) req.RequestURI = expectedURI setJWTCookieForReq(req, cookie) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", expectedURI)) // login with the passcode csrfToken, err = getCSRFTokenFromInternalPageMock(expectedURI, cookie) assert.NoError(t, err) passcode, err := generateTOTPPasscode(key.Secret()) 
assert.NoError(t, err) form = make(url.Values) form.Set("passcode", passcode) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, expectedURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.RequestURI = expectedURI req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, uri, rr.Header().Get("Location")) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSearchEvents(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, fsEventsPath+"?limit=10&order=ASC&fs_provider=0", nil) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) events := make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) if assert.Len(t, events, 1) { ev := events[0] for _, field := range []string{"id", "timestamp", "action", "username", "fs_path", "status", "protocol", "ip", "session_id", "fs_provider", "bucket", "endpoint", "open_flags", "role", "instance_id"} { _, ok := ev[field] assert.True(t, ok, field) } } req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?limit=10&order=ASC&role=role1", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) events = nil err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) assert.Len(t, events, 1) // CSV export req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?limit=10&order=ASC&csv_export=true", nil) assert.NoError(t, err) 
setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) // the test eventsearcher plugin returns error if start_timestamp < 0 req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?start_timestamp=-1&end_timestamp=123456&statuses=1,2", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) // CSV export with error exportFunc := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?start_timestamp=-2&csv_export=true", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) } exportFunc() req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?limit=e", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, providerEventsPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) events = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) if assert.Len(t, events, 1) { ev := events[0] for _, field := range []string{"id", "timestamp", "action", "username", "object_type", "object_name", "object_data", "role", "instance_id"} { _, ok := ev[field] assert.True(t, ok, field) } } req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?omit_object_data=true", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) events = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) if assert.Len(t, events, 1) { ev := events[0] field := "object_data" _, ok := ev[field] assert.False(t, ok, field) } // CSV export req, err = 
http.NewRequest(http.MethodGet, providerEventsPath+"?csv_export=true", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) // the test eventsearcher plugin returns error if start_timestamp < 0 req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?start_timestamp=-1", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) // CSV export with error exportFunc = func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?start_timestamp=-1&csv_export=true", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) } exportFunc() req, err = http.NewRequest(http.MethodGet, logEventsPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) events = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) if assert.Len(t, events, 1) { ev := events[0] for _, field := range []string{"id", "timestamp", "event", "ip", "message", "role", "instance_id"} { _, ok := ev[field] assert.True(t, ok, field) } } req, err = http.NewRequest(http.MethodGet, logEventsPath+"?events=a,1", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // CSV export req, err = http.NewRequest(http.MethodGet, logEventsPath+"?csv_export=true", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) // the test eventsearcher plugin returns error if start_timestamp < 0 req, err = http.NewRequest(http.MethodGet, logEventsPath+"?start_timestamp=-1", nil) assert.NoError(t, err) 
setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) // CSV export with error exportFunc = func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err = http.NewRequest(http.MethodGet, logEventsPath+"?start_timestamp=-1&csv_export=true", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) } exportFunc() req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?limit=2000", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, logEventsPath+"?limit=2000", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?start_timestamp=a", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?end_timestamp=a", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?order=ASSC", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?statuses=a,b", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?fs_provider=a", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, webEventsPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) 
checkResponseCode(t, http.StatusOK, rr)
}

// TestMFAErrors checks the validation errors returned by the user and admin
// TOTP endpoints:
//   - an unknown config name on generate;
//   - malformed JSON on generate/save/validate;
//   - an unknown config name, missing secret, missing or invalid protocol on
//     user save;
//   - a missing config name or secret on admin save;
//   - a redacted secret when no secret was previously stored.
func TestMFAErrors(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	assert.False(t, user.Filters.TOTPConfig.Enabled)
	userToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	adminToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	// invalid config name
	totpReq := generateTOTPRequest{
		ConfigName: "invalid config name",
	}
	asJSON, err := json.Marshal(totpReq)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userTOTPGeneratePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// invalid JSON
	invalidJSON := []byte("not a JSON")
	req, err = http.NewRequest(http.MethodPost, userTOTPGeneratePath, bytes.NewBuffer(invalidJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(invalidJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(invalidJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPost, adminTOTPValidatePath, bytes.NewBuffer(invalidJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// invalid TOTP config name
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: "missing name",
		Secret:     kms.NewPlainSecret(xid.New().String()),
		Protocols:  []string{common.ProtocolSSH},
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: config name")
	// invalid TOTP secret
	userTOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: mfa.GetAvailableTOTPConfigNames()[0],
		Secret:     nil,
		Protocols:  []string{common.ProtocolSSH},
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: secret is mandatory")
	// no protocol
	userTOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: mfa.GetAvailableTOTPConfigNames()[0],
		Secret:     kms.NewPlainSecret(xid.New().String()),
		Protocols:  nil,
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: specify at least one protocol")
	// invalid protocol
	userTOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: mfa.GetAvailableTOTPConfigNames()[0],
		Secret:     kms.NewPlainSecret(xid.New().String()),
		Protocols:  []string{common.ProtocolWebDAV},
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: invalid protocol")
	// admin TOTP: empty config name
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: "",
		Secret:     kms.NewPlainSecret("secret"),
	}
	asJSON, err = json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: config name is mandatory")
	// admin TOTP: missing secret
	adminTOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: mfa.GetAvailableTOTPConfigNames()[0],
		Secret:     nil,
	}
	asJSON, err = json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: secret is mandatory")
	// invalid TOTP secret status
	userTOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: mfa.GetAvailableTOTPConfigNames()[0],
		Secret:     kms.NewSecret(sdkkms.SecretStatusRedacted, "", "", ""),
		Protocols:  []string{common.ProtocolSSH},
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// previous secret will be preserved and we have no secret saved
	assert.Contains(t, rr.Body.String(), "totp: secret is mandatory")
	// the same redacted-secret payload sent to the admin save endpoint
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "totp: secret is mandatory")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err) } func TestMFAInvalidSecret(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) userToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) user.Password = defaultPassword user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: mfa.GetAvailableTOTPConfigNames()[0], Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), Protocols: []string{common.ProtocolSSH, common.ProtocolHTTP}, } user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ Used: false, Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), }) err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, userToken) rr := executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) assert.Contains(t, rr.Body.String(), "Unable to decrypt recovery codes") loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err := getCookieFromResponse(rr) assert.NoError(t, err) csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorPath, cookie) assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "123456") req, err = http.NewRequest(http.MethodPost, 
webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusInternalServerError, rr.Code) csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "RC-123456") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusInternalServerError, rr.Code) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) assert.NoError(t, err) req.Header.Set("X-SFTPGO-OTP", "authcode") req.SetBasicAuth(defaultUsername, defaultPassword) resp, err := httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) admin.Password = altAdminPassword admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{ Enabled: true, ConfigName: mfa.GetAvailableTOTPConfigNames()[0], Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), } admin.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ Used: false, Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", 
user.Username), }) err = dataprovider.UpdateAdmin(&admin, "", "", "") assert.NoError(t, err) loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "123456") req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusInternalServerError, rr.Code) csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "RC-123456") req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusInternalServerError, rr.Code) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil) assert.NoError(t, err) req.Header.Set("X-SFTPGO-OTP", "auth-code") 
req.SetBasicAuth(altAdminUsername, altAdminPassword) resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) } func TestWebUserTOTP(t *testing.T) { u := getTestUser() // TOTPConfig will be ignored on add u.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: "", Secret: kms.NewEmptySecret(), Protocols: []string{common.ProtocolSSH}, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) assert.False(t, user.Filters.TOTPConfig.Enabled) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, userTOTPConfigsPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var configs []mfa.TOTPConfig err = json.Unmarshal(rr.Body.Bytes(), &configs) assert.NoError(t, err, rr.Body.String()) assert.Len(t, configs, len(mfa.GetAvailableTOTPConfigs())) totpConfig := configs[0] totpReq := generateTOTPRequest{ ConfigName: totpConfig.Name, } asJSON, err := json.Marshal(totpReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userTOTPGeneratePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var totpGenResp generateTOTPResponse err = json.Unmarshal(rr.Body.Bytes(), &totpGenResp) assert.NoError(t, err) assert.NotEmpty(t, totpGenResp.Secret) assert.NotEmpty(t, totpGenResp.QRCode) passcode, err := generateTOTPPasscode(totpGenResp.Secret) assert.NoError(t, err) validateReq := validateTOTPRequest{ ConfigName: totpGenResp.ConfigName, Passcode: passcode, Secret: totpGenResp.Secret, } asJSON, err = json.Marshal(validateReq) assert.NoError(t, err) req, err = 
http.NewRequest(http.MethodPost, userTOTPValidatePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // the same passcode cannot be reused req, err = http.NewRequest(http.MethodPost, userTOTPValidatePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "this passcode was already used") userTOTPConfig := dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: totpGenResp.ConfigName, Secret: kms.NewPlainSecret(totpGenResp.Secret), Protocols: []string{common.ProtocolSSH}, } asJSON, err = json.Marshal(userTOTPConfig) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) totpCfg := user.Filters.TOTPConfig assert.True(t, totpCfg.Enabled) secretPayload := totpCfg.Secret.GetPayload() assert.Equal(t, totpGenResp.ConfigName, totpCfg.ConfigName) assert.Empty(t, totpCfg.Secret.GetKey()) assert.Empty(t, totpCfg.Secret.GetAdditionalData()) assert.NotEmpty(t, secretPayload) assert.Equal(t, sdkkms.SecretStatusSecretBox, totpCfg.Secret.GetStatus()) assert.Len(t, totpCfg.Protocols, 1) assert.Contains(t, totpCfg.Protocols, common.ProtocolSSH) // update protocols only userTOTPConfig = dataprovider.UserTOTPConfig{ Protocols: []string{common.ProtocolSSH, common.ProtocolFTP}, Secret: kms.NewEmptySecret(), } asJSON, err = json.Marshal(userTOTPConfig) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // update the user, TOTP 
should not be affected user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: false, Secret: kms.NewEmptySecret(), } _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.True(t, user.Filters.TOTPConfig.Enabled) assert.Equal(t, totpCfg.ConfigName, user.Filters.TOTPConfig.ConfigName) assert.Empty(t, user.Filters.TOTPConfig.Secret.GetKey()) assert.Empty(t, user.Filters.TOTPConfig.Secret.GetAdditionalData()) assert.Equal(t, secretPayload, user.Filters.TOTPConfig.Secret.GetPayload()) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) assert.Len(t, user.Filters.TOTPConfig.Protocols, 2) assert.Contains(t, user.Filters.TOTPConfig.Protocols, common.ProtocolSSH) assert.Contains(t, user.Filters.TOTPConfig.Protocols, common.ProtocolFTP) req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var recCodes []recoveryCode err = json.Unmarshal(rr.Body.Bytes(), &recCodes) assert.NoError(t, err) assert.Len(t, recCodes, 12) // regenerate recovery codes req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // check that recovery codes are different req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var newRecCodes []recoveryCode err = json.Unmarshal(rr.Body.Bytes(), &newRecCodes) assert.NoError(t, err) assert.Len(t, newRecCodes, 12) assert.NotEqual(t, recCodes, newRecCodes) // disable 2FA, the update user API should not work adminToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, 
defaultTokenAuthPass) assert.NoError(t, err) user.Filters.TOTPConfig.Enabled = false user.Filters.RecoveryCodes = nil user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Equal(t, defaultUsername, user.Username) assert.True(t, user.Filters.TOTPConfig.Enabled) assert.Len(t, user.Filters.RecoveryCodes, 12) // use the dedicated API req, err = http.NewRequest(http.MethodPut, userPath+"/"+defaultUsername+"/2fa/disable", nil) assert.NoError(t, err) setBearerForReq(req, adminToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.False(t, user.Filters.TOTPConfig.Enabled) assert.Len(t, user.Filters.RecoveryCodes, 0) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userPath+"/"+defaultUsername+"/2fa/disable", nil) assert.NoError(t, err) setBearerForReq(req, adminToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) } func TestWebAPIChangeUserProfileMock(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) assert.False(t, user.Filters.AllowAPIKeyAuth) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, 
defaultPassword) assert.NoError(t, err) // invalid json req, err := http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer([]byte("{"))) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) email := "userapi@example.com" additionalEmails := []string{"userapi1@example.com"} description := "user API description" profileReq := make(map[string]any) profileReq["allow_api_key_auth"] = true profileReq["email"] = email profileReq["description"] = description profileReq["public_keys"] = []string{testPubKey, testPubKey1} profileReq["tls_certs"] = []string{httpsCert} profileReq["additional_emails"] = additionalEmails asJSON, err := json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = json.Unmarshal(rr.Body.Bytes(), &profileReq) assert.NoError(t, err) assert.Equal(t, email, profileReq["email"].(string)) assert.Len(t, profileReq["additional_emails"].([]interface{}), 1) assert.Equal(t, description, profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) val, ok := profileReq["public_keys"].([]any) if assert.True(t, ok, profileReq) { assert.Len(t, val, 2) } val, ok = profileReq["tls_certs"].([]any) if assert.True(t, ok, profileReq) { assert.Len(t, val, 1) } // set an invalid email profileReq = make(map[string]any) profileReq["email"] = "notavalidemail" asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = 
executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Validation error: email") // set an invalid additional email profileReq = make(map[string]any) profileReq["additional_emails"] = []string{"not an email"} asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Validation error: email") // set an invalid public key profileReq = make(map[string]any) profileReq["public_keys"] = []string{"not a public key"} asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Validation error: error parsing public key") // set an invalid TLS certificate profileReq = make(map[string]any) profileReq["tls_certs"] = []string{"not a TLS cert"} asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Validation error: invalid TLS certificate") user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled} user.Email = email user.Description = description user.Filters.AllowAPIKeyAuth = true _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) token, err = 
getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) profileReq = make(map[string]any) profileReq["allow_api_key_auth"] = false profileReq["email"] = email profileReq["description"] = description + "_mod" //nolint:goconst profileReq["public_keys"] = []string{testPubKey} profileReq["tls_certs"] = []string{} asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), "Profile updated") // check that api key auth and public keys were not changed profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = json.Unmarshal(rr.Body.Bytes(), &profileReq) assert.NoError(t, err) assert.Equal(t, email, profileReq["email"].(string)) assert.Equal(t, description+"_mod", profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) val, ok = profileReq["public_keys"].([]any) if assert.True(t, ok, profileReq) { assert.Len(t, val, 2) } val, ok = profileReq["tls_certs"].([]any) if assert.True(t, ok, profileReq) { assert.Len(t, val, 1) } user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled} user.Description = description + "_mod" _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) profileReq = make(map[string]any) profileReq["allow_api_key_auth"] = false profileReq["email"] = "newemail@apiuser.com" profileReq["description"] = description 
profileReq["public_keys"] = []string{testPubKey} asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = json.Unmarshal(rr.Body.Bytes(), &profileReq) assert.NoError(t, err) assert.Equal(t, email, profileReq["email"].(string)) assert.Equal(t, description+"_mod", profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) assert.Len(t, profileReq["public_keys"].([]any), 1) // finally disable all profile permissions user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled} _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "You are not allowed to change anything") _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) } func TestPermGroupOverride(t *testing.T) { g 
:= getTestGroup() g.UserSettings.Filters.WebClient = []string{sdk.WebClientPasswordChangeDisabled} group, _, err := httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err) u := getTestUser() u.Groups = []sdk.GroupMapping{ { Name: group.Name, Type: sdk.GroupTypePrimary, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) pwd := make(map[string]string) pwd["current_password"] = defaultPassword pwd["new_password"] = altAdminPassword asJSON, err := json.Marshal(pwd) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) group.UserSettings.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP} group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) assert.NoError(t, err) token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols") req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) assert.NoError(t, err) req.RequestURI = webClientFilesPath setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nError2FARequired) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, 
err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) } func TestWebAPIChangeUserPwdMock(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // invalid json req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer([]byte("{"))) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) pwd := make(map[string]string) pwd["current_password"] = defaultPassword pwd["new_password"] = defaultPassword asJSON, err := json.Marshal(pwd) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "the new password must be different from the current one") pwd["new_password"] = altAdminPassword asJSON, err = json.Marshal(pwd) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, altAdminPassword) assert.NoError(t, err) assert.NotEmpty(t, token) // remove the change password permission user.Filters.WebClient = 
[]string{sdk.WebClientPasswordChangeDisabled} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Len(t, user.Filters.WebClient, 1) assert.Contains(t, user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, altAdminPassword) assert.NoError(t, err) assert.NotEmpty(t, token) pwd["current_password"] = altAdminPassword pwd["new_password"] = defaultPassword asJSON, err = json.Marshal(pwd) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestLoginInvalidPasswordMock(t *testing.T) { _, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass+"1") assert.Error(t, err) // now a login with no credentials req, _ := http.NewRequest(http.MethodGet, "/api/v2/token", nil) rr := executeRequest(req) assert.Equal(t, http.StatusUnauthorized, rr.Code) } func TestWebAPIChangeAdminProfileMock(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) assert.False(t, admin.Filters.AllowAPIKeyAuth) token, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) // invalid json req, err := http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer([]byte("{"))) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) email := "adminapi@example.com" description := "admin API description" profileReq := make(map[string]any) profileReq["allow_api_key_auth"] = true profileReq["email"] = email 
profileReq["description"] = description asJSON, err := json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), "Profile updated") profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, adminProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = json.Unmarshal(rr.Body.Bytes(), &profileReq) assert.NoError(t, err) assert.Equal(t, email, profileReq["email"].(string)) assert.Equal(t, description, profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) // set an invalid email profileReq["email"] = "admin_invalid_email" asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Validation error: email") _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, adminProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) } func TestChangeAdminPwdMock(t *testing.T) { token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword admin.Permissions = 
[]string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminDeleteUsers} asJSON, err := json.Marshal(admin) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) pwd := make(map[string]string) pwd["current_password"] = altAdminPassword pwd["new_password"] = defaultTokenAuthPass asJSON, err = json.Marshal(pwd) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer(asJSON)) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // try using the old token req, err = http.NewRequest(http.MethodGet, versionPath, nil) assert.NoError(t, err) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) _, err = getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) assert.Error(t, err) altToken, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer(asJSON)) setBearerForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) // current password does not match req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) setBearerForReq(req, altToken) rr = executeRequest(req) 
checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestUpdateAdminMock exercises PUT on the admins endpoint: 404 for an unknown
// admin, 400 for a non-JSON body or nil permissions, and the self-protection
// rules (an admin cannot disable itself, change its own permissions or add/change
// its own role). It also checks that an empty password in the update payload
// leaves the stored password usable for login.
func TestUpdateAdminMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	_, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.Error(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Permissions = []string{dataprovider.PermAdminAny}
	asJSON, err := json.Marshal(admin)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)

	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, "abc"), bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer([]byte("no json")))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	admin.Permissions = nil
	asJSON, err = json.Marshal(admin)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)

	admin = getTestAdmin()
	admin.Status = 0
	asJSON, err = json.Marshal(admin)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "you cannot disable yourself")
	admin.Status = 1
	admin.Permissions = []string{dataprovider.PermAdminAddUsers}
	asJSON, err = json.Marshal(admin)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "you cannot change your permissions")
	admin.Permissions = []string{dataprovider.PermAdminAny}
	admin.Role = "missing role"
	asJSON, err = json.Marshal(admin)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "you cannot add/change your role")
	admin.Role = ""

	altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin.Password = "" // it must remain unchanged
	admin.Permissions = []string{dataprovider.PermAdminAny}
	asJSON, err = json.Marshal(admin)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer(asJSON))
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the old password must still work after an update with an empty password
	_, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)

	req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestAdminLastLoginWithAPIKey checks that a LastUseAt value supplied when creating
// an API key is not stored, and that an API-key authenticated request updates the
// impersonated admin's LastLogin.
func TestAdminLastLoginWithAPIKey(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Filters.AllowAPIKeyAuth = true
	admin, resp, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	assert.Equal(t, int64(0), admin.LastLogin)

	apiKey := dataprovider.APIKey{
		Name:      "admin API key",
		Scope:     dataprovider.APIKeyScopeAdmin,
		Admin:     altAdminUsername,
		LastUseAt: 123,
	}

	apiKey, resp, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	assert.Equal(t, int64(0), apiKey.LastUseAt)

	req, err := http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, admin.Username)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, admin.LastLogin, int64(0))

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserLastLoginWithAPIKey checks that a user-scoped API-key request updates
// the user's LastLogin.
func TestUserLastLoginWithAPIKey(t *testing.T) {
	user := getTestUser()
	user.Filters.AllowAPIKeyAuth = true
	user, resp, err := httpdtest.AddUser(user, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	assert.Equal(t, int64(0), user.LastLogin)

	apiKey := dataprovider.APIKey{
		Name:  "user API key",
		Scope: dataprovider.APIKeyScopeUser,
		User:  user.Username,
	}

	apiKey, resp, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err, string(resp))

	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, user.LastLogin, int64(0))

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestAdminHandlingWithAPIKeys checks admin management through an admin-scoped API
// key. Other admins can be created, updated, read and deleted. The impersonated
// admin cannot be updated, deleted or have its password changed. The test finally
// restores AllowAPIKeyAuth=false on the default admin.
func TestAdminHandlingWithAPIKeys(t *testing.T) {
	sysAdmin, _, err := httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK)
	assert.NoError(t, err)
	sysAdmin.Filters.AllowAPIKeyAuth = true
	sysAdmin, _, err = httpdtest.UpdateAdmin(sysAdmin, http.StatusOK)
	assert.NoError(t, err)

	apiKey := dataprovider.APIKey{
		Name:  "test admin API key",
		Scope: dataprovider.APIKeyScopeAdmin,
		Admin: defaultTokenAuthUser,
	}

	apiKey, resp, err := httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err, string(resp))

	admin := getTestAdmin()
	admin.Username = altAdminUsername
	asJSON, err := json.Marshal(admin)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)

	admin.Filters.AllowAPIKeyAuth = true
	asJSON, err = json.Marshal(admin)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, err = http.NewRequest(http.MethodGet, path.Join(adminPath, altAdminUsername), nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var adminGet dataprovider.Admin
	err = json.Unmarshal(rr.Body.Bytes(), &adminGet)
	assert.NoError(t, err)
	assert.True(t, adminGet.Filters.AllowAPIKeyAuth)

	req, err = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "updating the admin impersonated with an API key is not allowed")
	// changing the password for the impersonated admin is not allowed
	pwd := make(map[string]string)
	pwd["current_password"] = defaultTokenAuthPass
	pwd["new_password"] = altAdminPassword
	asJSON, err = json.Marshal(pwd)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer(asJSON))
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "API key authentication is not allowed")

	req, err = http.NewRequest(http.MethodDelete, path.Join(adminPath, defaultTokenAuthUser), nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "you cannot delete yourself")

	req, err = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	_, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)

	// restore the default admin directly in the provider
	dbAdmin, err := dataprovider.AdminExists(defaultTokenAuthUser)
	assert.NoError(t, err)
	dbAdmin.Filters.AllowAPIKeyAuth = false
	err = dataprovider.UpdateAdmin(&dbAdmin, "", "", "")
	assert.NoError(t, err)
	sysAdmin, _, err = httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, sysAdmin.Filters.AllowAPIKeyAuth)
}

// TestUserHandlingWithAPIKey checks user create/update/read/delete through an
// admin-scoped API key. It also checks that removing the owning admin makes the
// API key unreachable (GetAPIKeyByID returns 404).
func TestUserHandlingWithAPIKey(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Filters.AllowAPIKeyAuth = true
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	apiKey := dataprovider.APIKey{
		Name:  "test admin API key",
		Scope: dataprovider.APIKeyScopeAdmin,
		Admin: admin.Username,
	}

	apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err)

	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)

	user.Filters.DisableFsChecks = true
	user.Description = "desc"
	userAsJSON = getUserAsJSON(t, user)
	req, err = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, err = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var updatedUser dataprovider.User
	err = json.Unmarshal(rr.Body.Bytes(), &updatedUser)
	assert.NoError(t, err)
	assert.True(t, updatedUser.Filters.DisableFsChecks)
	assert.Equal(t, user.Description, updatedUser.Description)

	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusNotFound)
	assert.NoError(t, err)
}

// TestUpdateUserQuotaUsageMock checks the user quota-usage endpoint.
//   - A plain PUT sets the used quota.
//   - "?mode=add" increments size and files independently.
//   - A non-JSON body is rejected with 400.
//   - Updates are refused with 409 while a quota scan is registered for the user.
func TestUpdateUserQuotaUsageMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	var user dataprovider.User
	u := getTestUser()
	usedQuotaFiles := 1
	usedQuotaSize := int64(65535)
	u.UsedQuotaFiles = usedQuotaFiles
	u.UsedQuotaSize = usedQuotaSize
	u.QuotaFiles = 100
	userAsJSON := getUserAsJSON(t, u)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize, user.UsedQuotaSize)
	// now update only quota size
	u.UsedQuotaFiles = 0
	userAsJSON = getUserAsJSON(t, u)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage")+"?mode=add", bytes.NewBuffer(userAsJSON)) //nolint:goconst
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize*2, user.UsedQuotaSize)
	// only quota files
	u.UsedQuotaFiles = usedQuotaFiles
	u.UsedQuotaSize = 0
	userAsJSON = getUserAsJSON(t, u)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage")+"?mode=add", bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles*2, user.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize*2, user.UsedQuotaSize)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer([]byte("string")))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// a running quota scan must block manual usage updates
	assert.True(t, common.QuotaScans.AddUserQuotaScan(user.Username, ""))
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)
	assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username))
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestUserPermissionsMock checks validation of the per-directory permission map.
// The server rejects (400) each of the following:
//   - a map containing only "/somedir";
//   - a ".." key;
//   - an invalid permission name;
//   - "/somedir/..";
//   - a non-absolute path.
//
// The key "/somedir/../otherdir/" is accepted and stored as "/otherdir".
func TestUserPermissionsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	user.Permissions = make(map[string][]string)
	user.Permissions["/somedir"] = []string{dataprovider.PermAny}
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	user.Permissions[".."] = []string{dataprovider.PermAny}
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user.Permissions["/somedir"] = []string{"invalid"}
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	delete(user.Permissions, "/somedir")
	user.Permissions["/somedir/.."] = []string{dataprovider.PermAny}
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	delete(user.Permissions, "/somedir/..")
	user.Permissions["not_abs_path"] = []string{dataprovider.PermAny}
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	delete(user.Permissions, "not_abs_path")
	user.Permissions["/somedir/../otherdir/"] = []string{dataprovider.PermListItems}
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var updatedUser dataprovider.User
	err = render.DecodeJSON(rr.Body, &updatedUser)
	assert.NoError(t, err)
	if val, ok := updatedUser.Permissions["/otherdir"]; ok {
		assert.True(t, slices.Contains(val, dataprovider.PermListItems))
		assert.Equal(t, 1, len(val))
	} else {
		assert.Fail(t, "expected dir not found in permissions")
	}
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestUpdateUserInvalidJsonMock checks that a user update with a non-JSON body
// is rejected with 400.
func TestUpdateUserInvalidJsonMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer([]byte("Invalid json")))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestUpdateUserInvalidParamsMock checks user updates with invalid data.
// An empty home dir and a zeroed ID/CreatedAt are both rejected with 400, and a
// PUT to "/0" returns 404.
func TestUpdateUserInvalidParamsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user.HomeDir = ""
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	userID := user.ID
	user.ID = 0
	user.CreatedAt = 0
	userAsJSON = getUserAsJSON(t, user)
	req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	user.ID = userID
	req, _ = http.NewRequest(http.MethodPut, userPath+"/0", bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestGetAdminsMock checks admin listing with limit/offset/order. ASC and DESC must
// yield a different first admin, and so must offset=1. Malformed limit, offset or
// order values are rejected with 400.
func TestGetAdminsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	asJSON, err := json.Marshal(admin)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)

	req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=510&offset=0&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var admins []dataprovider.Admin
	err = render.DecodeJSON(rr.Body, &admins)
	assert.NoError(t, err)
	// guard the index: a failed (non-fatal) assert must not turn into a panic that
	// also skips the cleanup below
	var firstAdmin string
	if assert.GreaterOrEqual(t, len(admins), 1) {
		firstAdmin = admins[0].Username
	}

	req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=510&offset=0&order=DESC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admins = nil
	err = render.DecodeJSON(rr.Body, &admins)
	assert.NoError(t, err)
	if assert.GreaterOrEqual(t, len(admins), 1) {
		assert.NotEqual(t, firstAdmin, admins[0].Username)
	}

	req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=510&offset=1&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admins = nil
	err = render.DecodeJSON(rr.Body, &admins)
	assert.NoError(t, err)
	if assert.GreaterOrEqual(t, len(admins), 1) {
		assert.NotEqual(t, firstAdmin, admins[0].Username)
	}

	req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=a&offset=0&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=1&offset=aa&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=1&offset=0&order=ASCa", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)

	req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, admin.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestGetUsersMock checks user listing and the 400 responses for malformed
// limit, offset or order parameters.
func TestGetUsersMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=510&offset=0&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var users []dataprovider.User
	err = render.DecodeJSON(rr.Body, &users)
	assert.NoError(t, err)
	assert.GreaterOrEqual(t, len(users), 1)
	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=aa&offset=0&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=a&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=0&order=ASCc", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestDeleteUserInvalidParamsMock checks that deleting the non-existent user "0"
// returns 404.
func TestDeleteUserInvalidParamsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodDelete, userPath+"/0", nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestGetQuotaScansMock checks that the quota scans listing returns 200.
func TestGetQuotaScansMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, quotaScanPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestStartQuotaScanMock checks user quota scans. A duplicate scan is rejected
// with 409; scans are accepted (202) both when the home dir is missing and when
// it exists.
func TestStartQuotaScanMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	_, err = os.Stat(user.HomeDir)
	if err == nil {
		err = os.Remove(user.HomeDir)
		assert.NoError(t, err)
	}
	// simulate a duplicate quota scan
	assert.True(t, common.QuotaScans.AddUserQuotaScan(user.Username, ""))
	req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)
	assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username))

	req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)

	waitForUsersQuotaScan(t, token)

	_, err = os.Stat(user.HomeDir)
	if errors.Is(err, fs.ErrNotExist) {
		err = os.MkdirAll(user.HomeDir, os.ModePerm)
		assert.NoError(t, err)
	}
	req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)

	waitForUsersQuotaScan(t, token)

	req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)

	waitForUsersQuotaScan(t, token)

	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUpdateFolderQuotaUsageMock checks the folder quota-usage endpoint.
//   - A plain PUT sets the used quota.
//   - "?mode=add" increments size and files independently.
//   - A non-JSON body is rejected with 400.
//   - Updates are refused with 409 while a folder quota scan is registered.
func TestUpdateFolderQuotaUsageMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "vfolder")
	folderName := filepath.Base(mappedPath)
	f := vfs.BaseVirtualFolder{
		MappedPath: mappedPath,
		Name:       folderName,
	}
	usedQuotaFiles := 1
	usedQuotaSize := int64(65535)
	f.UsedQuotaFiles = usedQuotaFiles
	f.UsedQuotaSize = usedQuotaSize
	var folder vfs.BaseVirtualFolder
	folderAsJSON, err := json.Marshal(f)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &folder)
	assert.NoError(t, err)

	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var folderGet vfs.BaseVirtualFolder
	req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &folderGet)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles, folderGet.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize, folderGet.UsedQuotaSize)
	// now update only quota size
	f.UsedQuotaFiles = 0
	folderAsJSON, err = json.Marshal(f)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage")+"?mode=add", bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	folderGet = vfs.BaseVirtualFolder{}
	req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &folderGet)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles, folderGet.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize*2, folderGet.UsedQuotaSize)
	// now update only quota files; reuse usedQuotaFiles so the *2 assertion below
	// stays tied to the value actually sent
	f.UsedQuotaSize = 0
	f.UsedQuotaFiles = usedQuotaFiles
	folderAsJSON, err = json.Marshal(f)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage")+"?mode=add", bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	folderGet = vfs.BaseVirtualFolder{}
	req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &folderGet)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles*2, folderGet.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize*2, folderGet.UsedQuotaSize)
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), bytes.NewBuffer([]byte("not a json")))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)

	// a running quota scan must block manual usage updates
	assert.True(t, common.QuotaScans.AddVFolderQuotaScan(folderName))
	req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)
	assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName))

	req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestStartFolderQuotaScanMock checks folder quota scans: a duplicate scan is
// rejected with 409 and a real scan is accepted with 202.
func TestStartFolderQuotaScanMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "vfolder")
	folderName := filepath.Base(mappedPath)
	folder := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	folderAsJSON, err := json.Marshal(folder)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = os.Stat(mappedPath)
	if err == nil {
		err = os.Remove(mappedPath)
		assert.NoError(t, err)
	}
	// simulate a duplicate quota scan
	assert.True(t, common.QuotaScans.AddVFolderQuotaScan(folderName))
	req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)
	assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName))
	// and now a real quota scan
	_, err = os.Stat(mappedPath)
	if errors.Is(err, fs.ErrNotExist) {
		err = os.MkdirAll(mappedPath, os.ModePerm)
		assert.NoError(t, err)
	}
	req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	waitForFoldersQuotaScanPath(t, token)
	// cleanup
	req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// only the on-disk mapped path is removed: folderPath is the REST route, not a
	// filesystem location, and must never be passed to os.RemoveAll
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestStartQuotaScanNonExistentUserMock checks that scanning a user that was never
// created returns 404.
func TestStartQuotaScanNonExistentUserMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()

	req, _ := http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestStartQuotaScanNonExistentFolderMock checks that scanning a folder that was
// never created returns 404.
func TestStartQuotaScanNonExistentFolderMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	folder := vfs.BaseVirtualFolder{
		Name: "afolder",
	}
	req, _ := http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestGetFoldersMock checks folder listing and the 400 responses for malformed
// limit, offset or order parameters.
func TestGetFoldersMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "vfolder")
	folderName := filepath.Base(mappedPath)
	folder := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	folderAsJSON, err := json.Marshal(folder)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &folder)
	assert.NoError(t, err)

	var folders []vfs.BaseVirtualFolder
	// named listURL so the net/url package is not shadowed
	listURL, err := url.Parse(folderPath + "?limit=510&offset=0&order=DESC")
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, listURL.String(), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &folders)
	assert.NoError(t, err)
	assert.GreaterOrEqual(t, len(folders), 1)
	req, _ = http.NewRequest(http.MethodGet, folderPath+"?limit=a&offset=0&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodGet, folderPath+"?limit=1&offset=a&order=ASC", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, _ = http.NewRequest(http.MethodGet, folderPath+"?limit=1&offset=0&order=ASCV", nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)

	req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestGetVersionMock checks that the version endpoint returns 401 without a token
// or with a malformed one, and 200 with a valid JWT.
func TestGetVersionMock(t *testing.T) {
	req, _ := http.NewRequest(http.MethodGet, versionPath, nil)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, versionPath, nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, _ = http.NewRequest(http.MethodGet, versionPath, nil)
	setBearerForReq(req, "abcde")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
}

// TestGetConnectionsMock checks that listing active connections returns 200.
func TestGetConnectionsMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestGetStatusMock checks that the server status endpoint returns 200.
func TestGetStatusMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestDeleteActiveConnectionMock checks connection deletion. An unknown connection
// returns 404, an invalid node token returns 401, and a request for an unknown
// node returns 404.
func TestDeleteActiveConnectionMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req.Header.Set(dataprovider.NodeTokenHeader, "Bearer abc")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "the provided token cannot be authenticated")
	req, err = http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID?node=node1", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestNotFoundMock checks that an unknown API path returns 404.
func TestNotFoundMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, "/non/existing/path", nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestMethodNotAllowedMock checks that POST on the connections path returns 405.
func TestMethodNotAllowedMock(t *testing.T) {
	req, _ := http.NewRequest(http.MethodPost, activeConnectionsPath, nil)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusMethodNotAllowed, rr)
}

// TestHealthCheck checks that the health endpoint returns 200 with body "ok".
func TestHealthCheck(t *testing.T) {
	req, _ := http.NewRequest(http.MethodGet, healthzPath, nil)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "ok", rr.Body.String())
}

// TestGetWebRootMock checks the 302 redirects from the web roots. "/", the web base
// path and the client base path go to the client login page; the admin base path
// goes to the admin login page.
func TestGetWebRootMock(t *testing.T) {
	req, _ := http.NewRequest(http.MethodGet, "/", nil)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webBasePath, nil)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webBasePathAdmin, nil)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webBasePathClient, nil)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
}
// TestWebNotFoundURI checks, through a real HTTP client against httpBaseURL,
// that unknown paths below webBasePath and webBasePathClient return
// 404 Not Found, both without a JWT cookie and with an invalid one.
func TestWebNotFoundURI(t *testing.T) {
	urlString := httpBaseURL + webBasePath + "/a"
	req, err := http.NewRequest(http.MethodGet, urlString, nil)
	assert.NoError(t, err)
	resp, err := httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	// the body captured by each defer is closed when the test returns
	defer resp.Body.Close()
	assert.Equal(t, http.StatusNotFound, resp.StatusCode)

	// an invalid JWT cookie must not change the outcome
	req, err = http.NewRequest(http.MethodGet, urlString, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, "invalid token")
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusNotFound, resp.StatusCode)

	// same two checks below the WebClient base path
	urlString = httpBaseURL + webBasePathClient + "/a"
	req, err = http.NewRequest(http.MethodGet, urlString, nil)
	assert.NoError(t, err)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
	req, err = http.NewRequest(http.MethodGet, urlString, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, "invalid client token")
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}

// TestLogout checks that an admin API token accepted by serverStatusPath is
// rejected with 401 ("Your token is no longer valid") after logoutPath has
// been called with it.
func TestLogout(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, logoutPath, nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the same token must now be refused
	req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "Your token is no longer valid")
}

// TestDefenderAPIInvalidIDMock checks that a defender host id that is not
// hex-encoded is rejected with 400 Bad Request ("invalid host id").
func TestDefenderAPIInvalidIDMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, path.Join(defenderHosts, "abc"), nil) // not hex id
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "invalid host id")
}

// TestTokenHeaderCookie checks where each admin token kind is read from:
//   - an API token is accepted as a bearer header, but not as a JWT cookie
//     (401, "no token found");
//   - a web admin token is accepted as a JWT cookie on webStatusPath, but as
//     a bearer header it leads to a redirect to webLoginPath.
func TestTokenHeaderCookie(t *testing.T) {
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setJWTCookieForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "no token found")
	req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil)
	setBearerForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestTokenAudience checks that admin API and web tokens cannot be swapped:
//   - a web token sent to the REST API gets 401 ("Your token audience is not
//     valid");
//   - an API token sent as the WebAdmin cookie is redirected to webLoginPath.
func TestTokenAudience(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "Your token audience is not valid")
	req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil)
	setJWTCookieForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
}

// TestWebAPILoginMock checks user API token issuance and the separation
// between user API tokens and WebClient tokens.
//   - Token requests fail before the user exists, and with a wrong username
//     or password.
//   - A WebClient token on the user REST API gets 401 (audience).
//   - A user API token on a WebClient page is redirected to
//     webClientLoginPath.
// The test user and its home directory are removed at the end.
func TestWebAPILoginMock(t *testing.T) {
	_, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername+"1", defaultPassword)
	assert.Error(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword+"1")
	assert.Error(t, err)
	apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// a web token is not valid for API usage
	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "Your token audience is not valid")
	req, err = http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// API token is not valid for web usage
	req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	setJWTCookieForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebClientLoginMock walks through the WebClient token life cycle:
//  1. token issuance;
//  2. rejection on the REST API and on WebAdmin pages;
//  3. rejection when sent as a bearer header;
//  4. rendering of client pages when sent as a cookie;
//  5. invalidation after logout;
//  6. the error responses returned once the owning user has been deleted
//     while the WebClient and user API tokens are still held.
func TestWebClientLoginMock(t *testing.T) {
	_, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// a web token is not valid for API or WebAdmin usage
	req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	setBearerForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "Your token audience is not valid")
	req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	// bearer should not work
	req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	setBearerForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webClientPingPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setBearerForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// now try to render client pages
	req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webClientPingPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now logout
	req, _ = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// the logged-out token is redirected to the login page from now on
	req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodGet, webClientPingPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// get a new token and use it after removing the user
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	apiUserToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// the CSRF token is fetched while the user still exists, so the later
	// POSTs get past the CSRF check and reach the missing-user handling
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	setJWTCookieForReq(req, webToken)
	req.RemoteAddr = defaultRemoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser)
	req, _ = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorDirListUser)
	form := make(url.Values)
	form.Set("files", `[]`)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser)
	// the user REST API answers 404 for the deleted user as well
	req, _ = http.NewRequest(http.MethodGet, userDirsPath, nil)
	setBearerForReq(req, apiUserToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to retrieve your user")
	req, _ = http.NewRequest(http.MethodGet, userFilesPath, nil)
	setBearerForReq(req, apiUserToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to retrieve your user")
	req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer([]byte(`{}`)))
	setBearerForReq(req, apiUserToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to retrieve your user")
	form = make(url.Values)
	form.Set("public_keys", testPubKey)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
}

// TestWebClientLoginErrorsMock checks the WebClient login form errors. Both
// cases re-render the page with status 200:
//   - empty credentials show the invalid-credentials message;
//   - valid credentials with an empty third getLoginForm argument (presumably
//     the CSRF token; confirm against getLoginForm) show the invalid-CSRF
//     message.
func TestWebClientLoginErrorsMock(t *testing.T) {
	form := getLoginForm("", "", "")
	req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr := executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	form = getLoginForm(defaultUsername, defaultPassword, "")
	req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
}

// TestWebClientMaxConnections sets common.Config.MaxTotalConnections to 1 and
// registers one fake HTTP connection. It then checks that:
//   - a new WebClient login fails;
//   - an existing token gets 403 with common.ErrConnectionDenied on the
//     files page.
// The fake connection, the user and the original limit are restored at the
// end. The restore is not deferred, so a panic would skip it.
func TestWebClientMaxConnections(t *testing.T) {
	oldValue := common.Config.MaxTotalConnections
	common.Config.MaxTotalConnections = 1
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now add a fake connection
	fs := vfs.NewOsFs("id", os.TempDir(), "", nil)
	connection := &httpd.Connection{
		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
	}
	err = common.Connections.Add(connection)
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error())
	common.Connections.Remove(connection.GetID())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
	common.Config.MaxTotalConnections = oldValue
}

// TestTokenInvalidIPAddress checks that tokens presented from a RemoteAddr
// other than the one used when fetching them are refused:
//   - the WebClient token is redirected (302);
//   - the user API token gets 401 ("Your token is not valid").
// NOTE(review): assumes the token helpers issue tokens from an address other
// than 1.1.1.2 / 2.2.2.2 (presumably defaultRemoteAddr); confirm.
func TestTokenInvalidIPAddress(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	req.RemoteAddr = "1.1.1.2"
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil)
	assert.NoError(t, err)
	req.RemoteAddr = "2.2.2.2"
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "Your token is not valid")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestDefender re-initializes the common config with the defender enabled
// (Threshold 3, ScoreLimitExceeded 2). Repeated failed WebClient logins from
// remoteAddr then lead to:
//   - refused logins, including with the right password;
//   - 403 on WebClient and WebAdmin pages, even with an X-Real-IP header set;
//   - static files still served, and the ACME challenge path answering 404
//     with a no-cache Cache-Control header.
// The original common config is re-applied at the end.
func TestDefender(t *testing.T) {
	oldConfig := config.GetCommonConfig()
	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 3
	cfg.DefenderConfig.ScoreLimitExceeded = 2
	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	remoteAddr := "172.16.5.6:9876"
	// both tokens are obtained before the host gets banned
	webAdminToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	req.RemoteAddr = remoteAddr
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// failed logins from remoteAddr raise its defender score
	for i := 0; i < 3; i++ {
		_, err = getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, "wrong pwd", remoteAddr)
		assert.Error(t, err)
	}
	// the host is now banned: even the correct password is refused
	_, err = getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr)
	assert.Error(t, err)
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	req.RemoteAddr = remoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorIPForbidden)
	req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil)
	req.RequestURI = webUsersPath
	setJWTCookieForReq(req, webAdminToken)
	req.RemoteAddr = remoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorIPForbidden)
	// setting X-Real-IP does not bypass the ban
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.Header.Set("X-Real-IP", "127.0.0.1:2345")
	setJWTCookieForReq(req, webToken)
	req.RemoteAddr = remoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "your IP address is blocked")
	// requests for static files should be always allowed
	req, err = http.NewRequest(http.MethodGet, "/static/favicon.png", nil)
	assert.NoError(t, err)
	req.RemoteAddr = remoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Empty(t, rr.Header().Get("Cache-Control"))
	req, err = http.NewRequest(http.MethodGet, "/.well-known/acme-challenge/foo", nil)
	assert.NoError(t, err)
	req.RemoteAddr = remoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Equal(t, "no-cache, no-store, max-age=0, must-revalidate, private", rr.Header().Get("Cache-Control"))
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestPostConnectHook configures postConnectPath as the post-connect hook and
// writes it twice via getExitCodeScriptContent:
//   - with exit code 0, WebClient/user API logins and an API-key request
//     succeed;
//   - with exit code 1, the logins fail and the API-key request gets 401.
// Skipped on Windows. The hook setting is cleared at the end.
func TestPostConnectHook(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	common.Config.PostConnectHook = postConnectPath
	u := getTestUser()
	u.Filters.AllowAPIKeyAuth = true
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// NOTE(review): this API key is never removed explicitly; presumably it is
	// dropped together with the user below — confirm.
	apiKey, _, err := httpdtest.AddAPIKey(dataprovider.APIKey{
		Name:  "name",
		Scope: dataprovider.APIKeyScopeUser,
		User:  user.Username,
	}, http.StatusCreated)
	assert.NoError(t, err)
	// hook exits with 0: connections are allowed
	err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm)
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// hook exits with 1: connections are refused
	err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm)
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	common.Config.PostConnectHook = ""
}

// TestMaxSessions creates a user with MaxSessions = 1 and occupies that
// session with a fake HTTP connection. It then checks that:
//   - new WebClient/user API logins fail;
//   - the user REST API and several WebClient pages answer 429 Too Many
//     Requests;
//   - in the password-reset flow, the forgot-password step still works (302
//     plus a reset code), while the reset step re-renders (200) with the 429
//     message.
// A temporary SMTP config is used for the reset flow and replaced with an
// empty one afterwards.
func TestMaxSessions(t *testing.T) {
	u := getTestUser()
	u.MaxSessions = 1
	u.Email = "user@session.com"
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// now add a fake connection
	fs := vfs.NewOsFs("id", os.TempDir(), "", nil)
	connection := &httpd.Connection{
		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
	}
	err = common.Connections.Add(connection)
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)
	// try an user API call
	req, err := http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	// web client requests
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("files", `[]`)
	req, err = http.NewRequest(http.MethodPost, webClientDownloadZipPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorDirList429)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=p", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=file", nil) //nolint:goconst
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=file", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	// test reset password
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err = smtpCfg.Initialize(configDir, true)
	assert.NoError(t, err)
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	// reset lastResetCode so the length check below cannot pass on a value
	// left over from an earlier test
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// NOTE(review): the CSRF token for the WebClient reset form is taken from
	// webLoginPath (WebAdmin login) rather than webClientLoginPath; confirm
	// this is intentional.
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("password", defaultPassword)
	form.Set("confirm_password", defaultPassword)
	form.Set("code", lastResetCode)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	// replace the temporary SMTP config with an empty one
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	common.Connections.Remove(connection.GetID())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestMaxTransfers sets common.Config.MaxPerHostConnections to 2, creates a
// read/write share and uploads a file via the user REST API. While two SFTP
// files (file1, file2) are kept open, it checks that:
//   - a multipart upload to userFilesPath gets 409;
//   - a direct upload to userUploadFilePath gets 403;
//   - the WebClient view of the file shows the 403 message (status 200);
//   - an upload through the share gets 409.
// The open files, user and connection limit are cleaned up at the end.
func TestMaxTransfers(t *testing.T) {
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 2
	// wait for connections left over from earlier tests to go away
	assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:     "test share",
		Scope:    dataprovider.ShareScopeReadWrite,
		Paths:    []string{"/"},
		Password: defaultPassword,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	fileName := "testfile.txt"
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+fileName, bytes.NewBuffer([]byte(" ")))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// open two SFTP uploads and keep them open during the checks below
	conn, sftpClient, err := getSftpClient(user)
	assert.NoError(t, err)
	defer conn.Close()
	defer sftpClient.Close()
	f1, err := sftpClient.Create("file1")
	assert.NoError(t, err)
	f2, err := sftpClient.Create("file2")
	assert.NoError(t, err)
	_, err = f1.Write([]byte(" "))
	assert.NoError(t, err)
	_, err = f2.Write([]byte(" "))
	assert.NoError(t, err)
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "filepre")
	assert.NoError(t, err)
	_, err = part.Write([]byte("file content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	// a new bytes.Reader already starts at offset 0, so this Seek is a no-op
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+fileName, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError403Message)
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+fileName, bytes.NewBuffer([]byte(" ")))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// upload two files through the share, authenticated with basic auth
	body = new(bytes.Buffer)
	writer = multipart.NewWriter(body)
	part1, err := writer.CreateFormFile("filenames", "file11.txt")
	assert.NoError(t, err)
	_, err = part1.Write([]byte("file11 content"))
	assert.NoError(t, err)
	part2, err := writer.CreateFormFile("filenames", "file22.txt")
	assert.NoError(t, err)
	_, err = part2.Write([]byte("file22 content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader = bytes.NewReader(body.Bytes())
	req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)
	err = f1.Close()
	assert.NoError(t, err)
	err = f2.Close()
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1000*time.Millisecond, 50*time.Millisecond)
	assert.Eventually(t, func() bool { return common.Connections.GetTotalTransfers() == 0 }, 1000*time.Millisecond, 50*time.Millisecond)
	common.Config.MaxPerHostConnections = oldValue
}

// TestWebConfigsMock exercises the WebAdmin configs page (webConfigsPath):
// rendering, the missing-CSRF (403) and form-parse (400) errors, and then
// saving SFTP, SMTP and ACME settings and reading them back with
// dataprovider.GetConfigs.
func TestWebConfigsMock(t *testing.T) {
	acmeConfig := config.GetACMEConfig()
	acmeConfig.CertsPath = filepath.Clean(os.TempDir())
	err := acme.Initialize(acmeConfig, configDir, true)
	require.NoError(t, err)
	// start from empty stored configs
	err = dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webConfigsPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// a POST without a CSRF token is refused
	form := make(url.Values)
	b, contentType, err := getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// parse form error
	csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	// the query string contains invalid percent-encoding
	req, err = http.NewRequest(http.MethodPost, webConfigsPath+"?p=p%C3%AO%GH", bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// save SFTP configs
	form.Set("sftp_host_key_algos", ssh.KeyAlgoRSA)
form.Add("sftp_host_key_algos", ssh.InsecureCertAlgoDSAv01) //nolint:staticcheck form.Set("sftp_pub_key_algos", ssh.InsecureKeyAlgoDSA) //nolint:staticcheck form.Set("form_action", "sftp_submit") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) // invalid algo form.Set("sftp_host_key_algos", ssh.KeyAlgoRSA) form.Add("sftp_host_key_algos", ssh.CertAlgoRSAv01) form.Set("sftp_pub_key_algos", ssh.InsecureKeyAlgoDSA) //nolint:staticcheck form.Set("sftp_kex_algos", "diffie-hellman-group18-sha512") form.Add("sftp_kex_algos", ssh.KeyExchangeDH16SHA512) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) // check SFTP configs configs, err := dataprovider.GetConfigs() assert.NoError(t, err) assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) assert.Contains(t, configs.SFTPD.HostKeyAlgos, ssh.KeyAlgoRSA) assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) assert.Contains(t, configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck assert.Len(t, configs.SFTPD.KexAlgorithms, 1) assert.Contains(t, configs.SFTPD.KexAlgorithms, ssh.KeyExchangeDH16SHA512) // invalid form action form.Set("form_action", "") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = 
executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nError400Message) // test SMTP configs form.Set("form_action", "smtp_submit") form.Set("smtp_host", "mail.example.net") form.Set("smtp_from", "Example ") form.Set("smtp_username", defaultUsername) form.Set("smtp_password", defaultPassword) form.Set("smtp_domain", "localdomain") form.Set("smtp_auth", "100") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) // invalid smtp_auth // set valid parameters form.Set("smtp_port", "a") // converted to 587 form.Set("smtp_auth", "1") form.Set("smtp_encryption", "2") form.Set("smtp_debug", "checked") form.Set("smtp_oauth2_provider", "1") form.Set("smtp_oauth2_client_id", "123") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) // check configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) assert.Contains(t, configs.SFTPD.HostKeyAlgos, ssh.KeyAlgoRSA) assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) assert.Contains(t, configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck assert.Equal(t, "mail.example.net", configs.SMTP.Host) assert.Equal(t, 587, configs.SMTP.Port) assert.Equal(t, "Example ", configs.SMTP.From) assert.Equal(t, defaultUsername, configs.SMTP.User) assert.Equal(t, 1, configs.SMTP.Debug) assert.Equal(t, 
"", configs.SMTP.OAuth2.ClientID) err = configs.SMTP.Password.Decrypt() assert.NoError(t, err) assert.Equal(t, defaultPassword, configs.SMTP.Password.GetPayload()) assert.Equal(t, 1, configs.SMTP.AuthType) assert.Equal(t, 2, configs.SMTP.Encryption) assert.Equal(t, "localdomain", configs.SMTP.Domain) // set a redacted password, the current password must be preserved form.Set("smtp_password", redactedSecret) form.Set("smtp_auth", "") configs.SMTP.AuthType = 0 // empty will be converted to 0 b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) updatedConfigs, err := dataprovider.GetConfigs() assert.NoError(t, err) encryptedPayload := updatedConfigs.SMTP.Password.GetPayload() secretKey := updatedConfigs.SMTP.Password.GetKey() err = updatedConfigs.SMTP.Password.Decrypt() assert.NoError(t, err) assert.Equal(t, configs.SFTPD, updatedConfigs.SFTPD) assert.Equal(t, configs.SMTP, updatedConfigs.SMTP) // now set an undecryptable password updatedConfigs.SMTP.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, encryptedPayload, secretKey, "") err = dataprovider.UpdateConfigs(&updatedConfigs, "", "", "") assert.NoError(t, err) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) form.Set("form_action", "acme_submit") form.Set("acme_port", "") // on error will be set to 80 form.Set("acme_protocols", "1") form.Add("acme_protocols", "2") 
form.Add("acme_protocols", "3") form.Set("acme_domain", "example.com") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) // no email set, validation will fail req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidEmail) form.Set("acme_domain", "") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) // check configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) assert.Contains(t, configs.SFTPD.HostKeyAlgos, ssh.KeyAlgoRSA) assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) assert.Contains(t, configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck assert.Equal(t, 80, configs.ACME.HTTP01Challenge.Port) assert.Equal(t, 7, configs.ACME.Protocols) assert.Empty(t, configs.ACME.Domain) assert.Empty(t, configs.ACME.Email) assert.True(t, configs.ACME.HasProtocol(common.ProtocolFTP)) assert.True(t, configs.ACME.HasProtocol(common.ProtocolWebDAV)) assert.True(t, configs.ACME.HasProtocol(common.ProtocolHTTP)) // create certificate files, so no real ACME call is done domain := "acme.example.com" crtPath := filepath.Join(acmeConfig.CertsPath, util.SanitizeDomain(domain)+".crt") keyPath := filepath.Join(acmeConfig.CertsPath, util.SanitizeDomain(domain)+".key") err = os.WriteFile(crtPath, nil, 0666) assert.NoError(t, err) err = os.WriteFile(keyPath, nil, 0666) assert.NoError(t, err) form.Set("acme_port", "402") form.Set("acme_protocols", "1") 
form.Add("acme_protocols", "1000") form.Set("acme_domain", domain) form.Set("acme_email", "email@example.com") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) assert.Equal(t, 402, configs.ACME.HTTP01Challenge.Port) assert.Equal(t, 1, configs.ACME.Protocols) assert.Equal(t, domain, configs.ACME.Domain) assert.Equal(t, "email@example.com", configs.ACME.Email) assert.False(t, configs.ACME.HasProtocol(common.ProtocolFTP)) assert.False(t, configs.ACME.HasProtocol(common.ProtocolWebDAV)) assert.True(t, configs.ACME.HasProtocol(common.ProtocolHTTP)) err = os.Remove(crtPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) err = dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) } func TestBrandingConfigMock(t *testing.T) { err := dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) webClientLogoPath := "/static/branding/webclient/logo.png" webClientFaviconPath := "/static/branding/webclient/favicon.png" webAdminLogoPath := "/static/branding/webadmin/logo.png" webAdminFaviconPath := "/static/branding/webadmin/favicon.png" // no custom log or favicon was set for _, p := range []string{webClientLogoPath, webClientFaviconPath, webAdminLogoPath, webAdminFaviconPath} { req, err := http.NewRequest(http.MethodGet, p, nil) assert.NoError(t, err) rr := executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) } webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := 
getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("form_action", "branding_submit") form.Set("branding_webadmin_name", "Custom WebAdmin") form.Set("branding_webadmin_short_name", "WebAdmin") form.Set("branding_webadmin_disclaimer_name", "Admin disclaimer") form.Set("branding_webadmin_disclaimer_url", "invalid, not a URL") form.Set("branding_webclient_name", "Custom WebClient") form.Set("branding_webclient_short_name", "WebClient") form.Set("branding_webclient_disclaimer_name", "Client disclaimer") form.Set("branding_webclient_disclaimer_url", "https://example.com") b, contentType, err := getMultipartFormData(form, "", "") assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidDisclaimerURL) form.Set("branding_webadmin_disclaimer_url", "https://example.net") tmpFile := filepath.Join(os.TempDir(), util.GenerateUniqueID()+".png") err = createTestPNG(tmpFile, 512, 512, color.RGBA{100, 200, 200, 0xff}) assert.NoError(t, err) b, contentType, err = getMultipartFormData(form, "branding_webadmin_logo", tmpFile) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) // check configs, err := dataprovider.GetConfigs() assert.NoError(t, err) assert.Equal(t, "Custom WebAdmin", configs.Branding.WebAdmin.Name) assert.Equal(t, "WebAdmin", configs.Branding.WebAdmin.ShortName) assert.Equal(t, "Admin disclaimer", configs.Branding.WebAdmin.DisclaimerName) assert.Equal(t, 
"https://example.net", configs.Branding.WebAdmin.DisclaimerURL) assert.Equal(t, "Custom WebClient", configs.Branding.WebClient.Name) assert.Equal(t, "WebClient", configs.Branding.WebClient.ShortName) assert.Equal(t, "Client disclaimer", configs.Branding.WebClient.DisclaimerName) assert.Equal(t, "https://example.com", configs.Branding.WebClient.DisclaimerURL) assert.Greater(t, len(configs.Branding.WebAdmin.Logo), 0) assert.Len(t, configs.Branding.WebAdmin.Favicon, 0) assert.Len(t, configs.Branding.WebClient.Logo, 0) assert.Len(t, configs.Branding.WebClient.Favicon, 0) err = createTestPNG(tmpFile, 256, 256, color.RGBA{120, 220, 220, 0xff}) assert.NoError(t, err) form.Set("branding_webadmin_logo_remove", "0") // 0 preserves WebAdmin logo b, contentType, err = getMultipartFormData(form, "branding_webadmin_favicon", tmpFile) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Equal(t, "Custom WebAdmin", configs.Branding.WebAdmin.Name) assert.Equal(t, "WebAdmin", configs.Branding.WebAdmin.ShortName) assert.Equal(t, "Admin disclaimer", configs.Branding.WebAdmin.DisclaimerName) assert.Equal(t, "https://example.net", configs.Branding.WebAdmin.DisclaimerURL) assert.Equal(t, "Custom WebClient", configs.Branding.WebClient.Name) assert.Equal(t, "WebClient", configs.Branding.WebClient.ShortName) assert.Equal(t, "Client disclaimer", configs.Branding.WebClient.DisclaimerName) assert.Equal(t, "https://example.com", configs.Branding.WebClient.DisclaimerURL) assert.Greater(t, len(configs.Branding.WebAdmin.Logo), 0) assert.Greater(t, len(configs.Branding.WebAdmin.Favicon), 0) assert.Len(t, configs.Branding.WebClient.Logo, 0) assert.Len(t, 
configs.Branding.WebClient.Favicon, 0) err = createTestPNG(tmpFile, 256, 256, color.RGBA{80, 90, 110, 0xff}) assert.NoError(t, err) b, contentType, err = getMultipartFormData(form, "branding_webclient_logo", tmpFile) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Greater(t, len(configs.Branding.WebClient.Logo), 0) err = createTestPNG(tmpFile, 256, 256, color.RGBA{120, 50, 120, 0xff}) assert.NoError(t, err) b, contentType, err = getMultipartFormData(form, "branding_webclient_favicon", tmpFile) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Greater(t, len(configs.Branding.WebClient.Favicon), 0) for _, p := range []string{webClientLogoPath, webClientFaviconPath, webAdminLogoPath, webAdminFaviconPath} { req, err := http.NewRequest(http.MethodGet, p, nil) assert.NoError(t, err) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) } // remove images form.Set("branding_webadmin_logo_remove", "1") form.Set("branding_webclient_logo_remove", "1") form.Set("branding_webadmin_favicon_remove", "1") form.Set("branding_webclient_favicon_remove", "1") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = 
executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) configs, err = dataprovider.GetConfigs() assert.NoError(t, err) assert.Len(t, configs.Branding.WebAdmin.Logo, 0) assert.Len(t, configs.Branding.WebAdmin.Favicon, 0) assert.Len(t, configs.Branding.WebClient.Logo, 0) assert.Len(t, configs.Branding.WebClient.Favicon, 0) for _, p := range []string{webClientLogoPath, webClientFaviconPath, webAdminLogoPath, webAdminFaviconPath} { req, err := http.NewRequest(http.MethodGet, p, nil) assert.NoError(t, err) rr := executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) } form.Del("branding_webadmin_logo_remove") form.Del("branding_webclient_logo_remove") form.Del("branding_webadmin_favicon_remove") form.Del("branding_webclient_favicon_remove") // image too large err = createTestPNG(tmpFile, 768, 512, color.RGBA{120, 50, 120, 0xff}) assert.NoError(t, err) b, contentType, err = getMultipartFormData(form, "branding_webclient_logo", tmpFile) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidPNGSize) // not a png image err = createTestFile(tmpFile, 128) assert.NoError(t, err) b, contentType, err = getMultipartFormData(form, "branding_webclient_logo", tmpFile) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidPNG) err = os.Remove(tmpFile) assert.NoError(t, err) err = dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) } func TestSFTPLoopError(t *testing.T) { user1 := getTestUser() 
user2 := getTestUser() user1.Username += "1" user1.Email = "user1@test.com" user2.Username += "2" user1.FsConfig = vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: user2.Username, }, Password: kms.NewPlainSecret(defaultPassword), }, } user2.FsConfig.Provider = sdk.SFTPFilesystemProvider user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: user1.Username, }, Password: kms.NewPlainSecret(defaultPassword), } user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated) assert.NoError(t, err, string(resp)) user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated) assert.NoError(t, err, string(resp)) // test reset password smtpCfg := smtp.Config{ Host: "127.0.0.1", Port: 3525, From: "notification@example.com", TemplatesPath: "templates", } err = smtpCfg.Initialize(configDir, true) assert.NoError(t, err) loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user1.Username) lastResetCode = "" req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.GreaterOrEqual(t, len(lastResetCode), 20) loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("password", defaultPassword) form.Set("confirm_password", defaultPassword) form.Set("code", lastResetCode) req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) 
assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorLoginAfterReset) smtpCfg = smtp.Config{} err = smtpCfg.Initialize(configDir, true) require.NoError(t, err) _, err = httpdtest.RemoveUser(user1, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user1.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user2, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user2.GetHomeDir()) assert.NoError(t, err) } func TestLoginInvalidFs(t *testing.T) { u := getTestUser() u.Filters.AllowAPIKeyAuth = true u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.UploadPartSize = 1 u.FsConfig.GCSConfig.UploadPartMaxTime = 10 u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) apiKey, _, err := httpdtest.AddAPIKey(dataprovider.APIKey{ Name: "testk", Scope: dataprovider.APIKeyScopeUser, User: u.Username, }, http.StatusCreated) assert.NoError(t, err) _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) req, err := http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKey.Key, "") rr := executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestWebClientChangePwd(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) webToken, err := 
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webChangeClientPwdPath, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form := make(url.Values) form.Set("current_password", defaultPassword) form.Set("new_password1", defaultPassword) form.Set("new_password2", defaultPassword) // no csrf token req, err = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, webToken) assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoDifferent) form.Set("current_password", defaultPassword+"2") form.Set("new_password1", defaultPassword+"1") form.Set("new_password2", defaultPassword+"1") req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdCurrentNoMatch) form.Set("current_password", defaultPassword) 
form.Set("new_password1", defaultPassword+"1") form.Set("new_password2", defaultPassword+"1") req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) req, err = http.NewRequest(http.MethodGet, webClientPingPath, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) _, err = getJWTWebClientTokenFromTestServer(defaultUsername+"1", defaultPassword+"1") assert.Error(t, err) _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword+"1") assert.NoError(t, err) // remove the change password permission user.Filters.WebClient = []string{sdk.WebClientPasswordChangeDisabled} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) assert.Len(t, user.Filters.WebClient, 1) assert.Contains(t, user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword+"1") assert.NoError(t, err) csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("current_password", defaultPassword+"1") form.Set("new_password1", defaultPassword) form.Set("new_password2", defaultPassword) req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) 
checkResponseCode(t, http.StatusForbidden, rr) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPreDownloadHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } oldExecuteOn := common.Config.Actions.ExecuteOn oldHook := common.Config.Actions.Hook common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} common.Config.Actions.Hook = preActionPath u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(preActionPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) testFileName := "testfile" testFileContents := []byte("file contents") err = os.MkdirAll(filepath.Join(user.GetHomeDir()), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), testFileContents, os.ModePerm) assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) //nolint:goconst assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, testFileContents, rr.Body.Bytes()) req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) assert.NoError(t, err) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, testFileContents, rr.Body.Bytes()) err = os.WriteFile(preActionPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = 
executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError403Message) req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) assert.NoError(t, err) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "permission denied") _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook } func TestPreUploadHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } oldExecuteOn := common.Config.Actions.ExecuteOn oldHook := common.Config.Actions.Hook common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} common.Config.Actions.Hook = preActionPath u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(preActionPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) body := new(bytes.Buffer) writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("filenames", "filepre") assert.NoError(t, err) _, err = part.Write([]byte("file content")) assert.NoError(t, err) err = writer.Close() assert.NoError(t, err) reader := bytes.NewReader(body.Bytes()) _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) setBearerForReq(req, webAPIToken) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=filepre", bytes.NewBuffer([]byte("single upload content"))) 
assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) err = os.WriteFile(preActionPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=filepre", bytes.NewBuffer([]byte("single upload content"))) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook } func TestShareUsage(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) testFileName := "testfile.dat" testFileSize := int64(65536) testFilePath := filepath.Join(user.GetHomeDir(), testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) share := dataprovider.Share{ Name: "test share", Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, Password: defaultPassword, MaxTokens: 2, ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Second)), } asJSON, err := json.Marshal(share) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := 
executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID := rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodGet, sharesPath+"/unknownid", nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) req.SetBasicAuth(defaultUsername, "wrong password") rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) time.Sleep(2 * time.Second) req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"_mod", nil) assert.NoError(t, err) req.RequestURI = webClientPubSharesPath + "/" + objectID + "_mod" req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) share.ExpiresAt = 0 jsonReq := make(map[string]any) jsonReq["name"] = share.Name jsonReq["scope"] = share.Scope jsonReq["paths"] = share.Paths jsonReq["password"] = share.Password jsonReq["max_tokens"] = share.MaxTokens jsonReq["expires_at"] = share.ExpiresAt asJSON, err = json.Marshal(jsonReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userSharesPath+"/"+objectID, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) 
checkResponseCode(t, http.StatusOK, rr) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "invalid share scope") share.MaxTokens = 3 share.Scope = dataprovider.ShareScopeWrite asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, userSharesPath+"/"+objectID, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) body := new(bytes.Buffer) writer := multipart.NewWriter(body) part1, err := writer.CreateFormFile("filenames", "file1.txt") assert.NoError(t, err) _, err = part1.Write([]byte("file1 content")) assert.NoError(t, err) part2, err := writer.CreateFormFile("filenames", "file2.txt") assert.NoError(t, err) _, err = part2.Write([]byte("file2 content")) assert.NoError(t, err) err = writer.Close() assert.NoError(t, err) reader := bytes.NewReader(body.Bytes()) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Unable to parse multipart form") _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) // set the proper content type req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Allowed usage exceeded") share.MaxTokens = 6 share.Scope = dataprovider.ShareScopeWrite asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = 
http.NewRequest(http.MethodPut, userSharesPath+"/"+objectID, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) share, err = dataprovider.ShareExists(objectID, user.Username) assert.NoError(t, err) assert.Equal(t, 6, share.UsedTokens) _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) share.MaxTokens = 0 err = dataprovider.UpdateShare(&share, user.Username, "", "") assert.NoError(t, err) user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) _, err = reader.Seek(0, io.SeekStart) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) 
assert.Contains(t, rr.Body.String(), "permission denied") user.Permissions["/"] = []string{dataprovider.PermAny} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) body = new(bytes.Buffer) writer = multipart.NewWriter(body) part, err := writer.CreateFormFile("filename", "file1.txt") assert.NoError(t, err) _, err = part.Write([]byte("file content")) assert.NoError(t, err) err = writer.Close() assert.NoError(t, err) reader = bytes.NewReader(body.Bytes()) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "No files uploaded!") user.Filters.WebClient = []string{sdk.WebClientSharesDisabled} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) user.Filters.WebClient = []string{sdk.WebClientShareNoPasswordDisabled} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) share.Password = "" err = dataprovider.UpdateShare(&share, user.Username, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "sharing without a password was disabled") user.Filters.WebClient = []string{sdk.WebClientInfoChangeDisabled} user, _, err = httpdtest.UpdateUser(user, 
http.StatusOK, "")
	assert.NoError(t, err)
	share.Scope = dataprovider.ShareScopeRead
	share.Paths = []string{"/missing1", "/missing2"}
	err = dataprovider.UpdateShare(&share, user.Username, "", "")
	assert.NoError(t, err)
	// The final GET targets a read share whose paths do not exist. The deferred
	// func expects executeRequest to panic with http.ErrAbortHandler, then checks
	// that UsedTokens is still 6 and removes the test user and its home dir.
	defer func() {
		rcv := recover()
		assert.Equal(t, http.ErrAbortHandler, rcv)
		share, err = dataprovider.ShareExists(objectID, user.Username)
		assert.NoError(t, err)
		assert.Equal(t, 6, share.UsedTokens)
		_, err = httpdtest.RemoveUser(user, http.StatusOK)
		assert.NoError(t, err)
		err = os.RemoveAll(user.GetHomeDir())
		assert.NoError(t, err)
	}()
	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	executeRequest(req)
}

// TestSharePasswordPolicy checks that a share password is validated against the
// password strength configured in the settings of the user's primary group:
// with PasswordStrength set to 70, creating a share protected by
// defaultPassword is rejected with 400 "insecure password".
func TestSharePasswordPolicy(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.PasswordStrength = 70
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: g.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	// Random password for the account itself; presumably so that the user
	// satisfies the group's strength requirement (TODO confirm).
	u.Password = rand.Text()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, u.Password)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:     util.GenerateUniqueID(),
		Scope:    dataprovider.ShareScopeRead,
		Paths:    []string{"/"},
		Password: defaultPassword,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "insecure password")

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// Remove the home dir too, like the other share tests, so nothing leaks
	// into later tests; os.RemoveAll returns nil if it was never created.
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestShareMaxExpiration checks the MaxSharesExpiration user filter (expressed
// in days, judging by the 24h multiplier used below). It covers: rejecting a
// DefaultSharesExpiration above the maximum at user creation, rejecting
// too-distant or missing expirations on REST share create/update, the web
// client add/edit form errors, and the web form error once the user is gone.
func TestShareMaxExpiration(t *testing.T) {
	u := getTestUser()
	u.Filters.MaxSharesExpiration = 5
	u.Filters.DefaultSharesExpiration = 10
	_, resp, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "must be less than or equal to max shares expiration")
	u.Filters.DefaultSharesExpiration = 0
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webClientToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	s := dataprovider.Share{
		Name:      "test share",
		Scope:     dataprovider.ShareScopeRead,
		Password:  defaultPassword,
		Paths:     []string{"/"},
		ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(u.Filters.MaxSharesExpiration+2))),
	}
	asJSON, err := json.Marshal(s)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "share must expire before")
	// updating a share ID that was never created yields 404
	req, err = http.NewRequest(http.MethodPut, path.Join(userSharesPath, "shareID"), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// expiresAt is mandatory
	s.ExpiresAt = 0
	asJSON, err = json.Marshal(s)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "share must expire before")
	s.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(2 * time.Hour))
	asJSON, err = json.Marshal(s)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	shareID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, shareID)
	s.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(u.Filters.MaxSharesExpiration+2)))
	asJSON, err = json.Marshal(s)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(userSharesPath, shareID), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "share must expire before")
	// web client: both the add and the edit form answer 200 with the
	// out-of-range error in the rendered page
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientSharePath, webClientToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("name", s.Name)
	form.Set("scope", strconv.Itoa(int(s.Scope)))
	form.Set("max_tokens", "0")
	form.Set("paths[0][path]", "/")
	form.Set("expiration_date", time.Now().Add(24*time.Hour*time.Duration(u.Filters.MaxSharesExpiration+2)).Format("2006-01-02 15:04:05"))
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webClientToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpirationOutOfRange)
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientSharePath, shareID), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webClientToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpirationOutOfRange)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// with the user removed, the still-valid web token yields the get-user error
	req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webClientToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser)
}

// TestWebClientShareCredentials exercises the web client login flow for
// password-protected shares: redirect to the share login page, CSRF checks,
// share-scoped JWT cookies (rejected for another share or another client IP),
// logout, invalid credentials, the "next" redirect and logins to missing shares.
func TestWebClientShareCredentials(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	shareRead := dataprovider.Share{
		Name:     "test share read",
		Scope:    dataprovider.ShareScopeRead,
		Password: defaultPassword,
		Paths:    []string{"/"},
	}
	shareWrite := dataprovider.Share{
		Name:     "test share write",
		Scope:    dataprovider.ShareScopeReadWrite,
		Password: defaultPassword,
		Paths:    []string{"/"},
	}
	asJSON, err := json.Marshal(shareRead)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	shareReadID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, shareReadID)
	asJSON, err = json.Marshal(shareWrite)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	shareWriteID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, shareWriteID)
	// without credentials, browsing redirects to a location that carries the
	// query-escaped original URI
	uri := path.Join(webClientPubSharesPath, shareReadID, "browse")
	req, err = http.NewRequest(http.MethodGet, uri, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	location := rr.Header().Get("Location")
	assert.Contains(t, location, url.QueryEscape(uri))
	// get the login form
	req, err = http.NewRequest(http.MethodGet, location, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now set the user token, it is not valid for the share
	req, err = http.NewRequest(http.MethodGet, uri, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	// get a share token
	form := make(url.Values)
	form.Set("share_password", defaultPassword)
	loginURI := path.Join(webClientPubSharesPath, shareReadID, "login")
	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	// set the CSRF token
	loginCookie, csrfToken, err := getCSRFTokenMock(loginURI, defaultRemoteAddr)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nShareLoginOK)
	cookie := rr.Header().Get("Set-Cookie")
	cookie = strings.TrimPrefix(cookie, "jwt=")
	assert.NotEmpty(t, cookie)
	req, err = http.NewRequest(http.MethodGet, uri, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// get the download page
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareReadID, "download?a=b"), nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// get the download page for a missing share
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, "invalidshareid", "download"), nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// the same cookie will not work for the other share
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareWriteID, "browse"), nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	// IP address does not match
	req, err = http.NewRequest(http.MethodGet, uri, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	req.RemoteAddr = "1.2.3.4:1234"
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	// logout to a different share, the cookie is not valid.
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareWriteID, "logout"), nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)
	// logout
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareReadID, "logout"), nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	// the cookie is no longer valid
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareReadID, "download?b=c"), nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	setJWTCookieForReq(req, cookie)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Contains(t, rr.Header().Get("Location"), "/login")
	// try to login with invalid credentials
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	form.Set("share_password", "")
	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
	// login with the next param set
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	form.Set("share_password", defaultPassword)
	nextURI := path.Join(webClientPubSharesPath, shareReadID, "browse")
	loginURI = path.Join(webClientPubSharesPath, shareReadID, fmt.Sprintf("login?next=%s", url.QueryEscape(nextURI)))
	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, nextURI, rr.Header().Get("Location"))
	// try to login to a missing share
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	loginURI = path.Join(webClientPubSharesPath, "missing", "login")
	req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestShareMaxSessions checks that share endpoints honor the owner's
// MaxSessions limit: with MaxSessions set to 1 and a fake connection
// registered for the user, REST and web client share requests report
// "too many open sessions" or the 429 message (the browse page renders the
// message with status 200).
func
TestShareMaxSessions(t *testing.T) {
	u := getTestUser()
	u.MaxSessions = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:  "test share max sessions read",
		Scope: dataprovider.ShareScopeRead,
		Paths: []string{"/"},
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	// before any other connection exists the share is accessible
	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil) //nolint:goconst
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// add a fake connection
	// It occupies the single session allowed by MaxSessions = 1 until it is
	// removed explicitly near the end of the test.
	fs := vfs.NewOsFs("id", os.TempDir(), "", nil)
	connection := &httpd.Connection{
		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
	}
	err = common.Connections.Add(connection)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/dirs", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=file.pdf"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	// the browse page answers 200 and renders the 429 message in its body
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/browse", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	// browse/exist on a read-only share fails on the scope check (403)
	req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID+"/browse/exist", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "invalid share scope")
	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/files?path=afile", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	form := make(url.Values)
	form.Set("files", `[]`)
	req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID+"/partial", bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError429Message)
	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	req, err = http.NewRequest(http.MethodDelete, userSharesPath+"/"+objectID, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now test a write share
	share = dataprovider.Share{
		Name:  "test share max sessions write",
		Scope: dataprovider.ShareScopeWrite,
		Paths: []string{"/"},
	}
	asJSON, err = json.Marshal(share)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID = rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	// single-file upload (request body is the file content)
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer([]byte("content")))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	// multipart upload
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part1, err := writer.CreateFormFile("filenames", "file1.txt")
	assert.NoError(t, err)
	_, err = part1.Write([]byte("file1 content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")
	// a read&write share passes the scope check on browse/exist and then hits
	// the session limit
	share = dataprovider.Share{
		Name:  "test share max sessions read&write",
		Scope: dataprovider.ShareScopeReadWrite,
		Paths: []string{"/"},
	}
	asJSON, err = json.Marshal(share)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID = rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID+"/browse/exist", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusTooManyRequests, rr)
	assert.Contains(t, rr.Body.String(), "too many open sessions")

	common.Connections.Remove(connection.GetID())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// no connections or transfers may be left behind
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestShareUploadSingle covers single-file uploads (request body used as the
// file content) to a write share: the X-SFTPGO-MTIME header, the web client
// upload page and endpoint, rejected targets, token accounting and uploads
// after the owner was removed.
func TestShareUploadSingle(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:      "test share",
		Scope:     dataprovider.ShareScopeWrite,
		Paths:     []string{"/"},
		Password:  defaultPassword,
		MaxTokens: 0,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	content := []byte("shared file content")
	modTime := time.Now().Add(-12 * time.Hour)
	// X-SFTPGO-MTIME (milliseconds since epoch) sets the uploaded file's mtime
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	req.Header.Set("X-SFTPGO-MTIME", strconv.FormatInt(util.GetTimeAsMsSinceEpoch(modTime), 10))
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	info, err := os.Stat(filepath.Join(user.GetHomeDir(), "file.txt"))
	if assert.NoError(t, err) {
		assert.InDelta(t, util.GetTimeAsMsSinceEpoch(modTime), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(1000))
	}
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "upload"), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// without the header the mtime is close to the current time
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "file.txt"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	info, err = os.Stat(filepath.Join(user.GetHomeDir(), "file.txt"))
	if assert.NoError(t, err) {
		assert.InDelta(t, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(3000))
	}
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "dir", "file.dat"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// an escaped "/" as the target name is rejected
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "%2F"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// uploading over an existing directory is rejected as well
	err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "dir"), os.ModePerm)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "dir"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// only uploads to the share root dir are allowed
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "dir", "file.dat"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// the two successful (201) uploads above account for the used tokens
	share, err = dataprovider.ShareExists(objectID, user.Username)
	assert.NoError(t, err)
	assert.Equal(t, 2, share.UsedTokens)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// with the owner removed the share is no longer found
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file1.txt"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestShareReadWrite covers a read&write share rooted at the user's start
// directory: browse/exist, uploads with and without mkdir_parents, a
// subdirectory where the user only has list permission, browsing, partial
// (zip) downloads and attempts to upload outside the share.
func TestShareReadWrite(t *testing.T) {
	u := getTestUser()
	u.Filters.StartDirectory = path.Join("/start", "dir")
	u.Permissions["/start/dir/limited"] = []string{dataprovider.PermListItems}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	testFileName := "test.txt"
	testSubDirs := "/sub/dir"
	share := dataprovider.Share{
		Name:      "test share rw",
		Scope:     dataprovider.ShareScopeReadWrite,
		Paths:     []string{user.Filters.StartDirectory},
		Password:  defaultPassword,
		MaxTokens: 0,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	// before any upload, browse/exist reports none of the listed files
	filesToCheck := make(map[string]any)
	filesToCheck["files"] = []string{testFileName}
	asJSON, err = json.Marshal(filesToCheck)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F"), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var fileList []any
	err = json.Unmarshal(rr.Body.Bytes(), &fileList)
	assert.NoError(t, err)
	assert.Len(t, fileList, 0)
	content := []byte("shared rw content")
	// uploads land inside the start directory
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, testFileName), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	assert.FileExists(t, filepath.Join(user.GetHomeDir(), user.Filters.StartDirectory, testFileName))
	// missing parent dirs: 404 unless mkdir_parents=true is set
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape(path.Join(testSubDirs, testFileName)), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape(path.Join(testSubDirs, testFileName))+"?mkdir_parents=true", bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	assert.FileExists(t, filepath.Join(user.GetHomeDir(), user.Filters.StartDirectory, testSubDirs, testFileName))
	// "limited" only grants PermListItems, so this upload is forbidden
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape(path.Join("limited", "sub", testFileName))+"?mkdir_parents=true", bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// after the upload, browse/exist finds the file
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F"), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	fileList = nil
	err = json.Unmarshal(rr.Body.Bytes(), &fileList)
	assert.NoError(t, err)
	assert.Len(t, fileList, 1)
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// browsing a file path returns a Content-Disposition header
	req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileName), nil) //nolint:goconst
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contentDisposition := rr.Header().Get("Content-Disposition")
	assert.NotEmpty(t, contentDisposition)
	// partial download of a JSON list of files is served as a zip archive
	form := make(url.Values)
	form.Set("files", fmt.Sprintf(`["%s"]`, testFileName))
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial"), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contentDisposition = rr.Header().Get("Content-Disposition")
	assert.NotEmpty(t, contentDisposition)
	assert.Equal(t, "application/zip", rr.Header().Get("Content-Type"))
	// parse form error
	// ("p%C3%AO%GK" contains percent-escapes that are not valid hex)
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=p%C3%AO%GK"), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// invalid files list
	// (the file name is not quoted, so the value is not valid JSON)
	form.Set("files", fmt.Sprintf(`[%s]`, testFileName))
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial"), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError400Message)
	// missing directory
	// NOTE(review): form still carries the invalid "files" value set above, so
	// this 400 may be caused by the JSON error rather than by the missing path;
	// consider restoring a valid list here to isolate the case.
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=missing"), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError400Message)
	// uploads whose path escapes the share root are forbidden
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape("../"+testFileName), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Uploading outside the share is not allowed")
	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape("/../../"+testFileName), bytes.NewBuffer(content))
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Uploading outside the share is not allowed")

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestShareUncompressed checks the compress=false query parameter: a
// directory share is served as application/zip either way, while a
// single-file share is served as application/octet-stream. It also checks
// token accounting and the errors for a missing download permission and for
// a removed shared file.
func TestShareUncompressed(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	testFileName := "testfile.dat"
	testFileSize := int64(65536)
	testFilePath := filepath.Join(user.GetHomeDir(), testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:      "test share",
		Scope:     dataprovider.ShareScopeRead,
		Paths:     []string{"/"},
		Password:  defaultPassword,
		MaxTokens: 0,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	// no expiration was set on the share, so ExpiresAt stays 0
	s, err := dataprovider.ShareExists(objectID, defaultUsername)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), s.ExpiresAt)
	// directory share: zip with and without compress=false
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "application/zip", rr.Header().Get("Content-Type"))
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) //nolint:goconst
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "application/zip", rr.Header().Get("Content-Type"))
	// single-file share: compress=false serves the raw file
	share = dataprovider.Share{
		Name:      "test share1",
		Scope:     dataprovider.ShareScopeRead,
		Paths:     []string{testFileName},
		Password:  defaultPassword,
		MaxTokens: 0,
	}
	asJSON, err = json.Marshal(share)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID = rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "application/zip", rr.Header().Get("Content-Type"))
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type"))
	share, err = dataprovider.ShareExists(objectID, user.Username)
	assert.NoError(t, err)
	assert.Equal(t, 2, share.UsedTokens)
	// without download permission the request is forbidden and the token
	// count is unchanged
	user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	share, err = dataprovider.ShareExists(objectID, user.Username)
	assert.NoError(t, err)
	assert.Equal(t, 2, share.UsedTokens)
	// once the shared file is removed the download fails with 400
	user.Permissions["/"] = []string{dataprovider.PermAny}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestDownloadFromShareError(t *testing.T) {
	u := getTestUser()
	u.DownloadDataTransfer = 1
	u.Filters.DefaultSharesExpiration = 10
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.UsedDownloadDataTransfer = 1024*1024 - 32768
	_, err = httpdtest.UpdateTransferQuotaUsage(user, "add", http.StatusOK)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(1024*1024-32768), user.UsedDownloadDataTransfer)
	testFileName := "test_share_file.dat"
	testFileSize := int64(524288)
	testFilePath := filepath.Join(user.GetHomeDir(), testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:      "test share root browse",
		Scope:     dataprovider.ShareScopeRead,
		Paths:     []string{"/"},
		MaxTokens: 2,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath,
bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID := rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) s, err := dataprovider.ShareExists(objectID, defaultUsername) assert.NoError(t, err) assert.Greater(t, s.ExpiresAt, int64(0)) defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) share, err = dataprovider.ShareExists(objectID, user.Username) assert.NoError(t, err) assert.Equal(t, 0, share.UsedTokens) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) }() req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileName), nil) assert.NoError(t, err) executeRequest(req) } func TestBrowseShares(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) testFileName := "testsharefile.dat" testFileNameLink := "testsharefile.link" shareDir := "share" subDir := "sub" testFileSize := int64(65536) testFilePath := filepath.Join(user.GetHomeDir(), shareDir, testFileName) testLinkPath := filepath.Join(user.GetHomeDir(), shareDir, testFileNameLink) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(filepath.Join(user.GetHomeDir(), shareDir, subDir, testFileName), testFileSize) assert.NoError(t, err) err = os.Symlink(testFilePath, testLinkPath) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) share := dataprovider.Share{ Name: "test share browse", Scope: dataprovider.ShareScopeRead, Paths: []string{shareDir}, MaxTokens: 0, } asJSON, err := json.Marshal(share) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) 
checkResponseCode(t, http.StatusCreated, rr) objectID := rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "upload"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "invalid share scope") req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Please set the path to a valid file") req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) contents := make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 2) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) contents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 2) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"+subDir), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) contents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 1) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F.."), nil) assert.NoError(t, err) rr = executeRequest(req) 
checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPathInvalid) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F.."), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F.."), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path=%2F..%2F"+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) contentDisposition := rr.Header().Get("Content-Disposition") assert.NotEmpty(t, contentDisposition) form := make(url.Values) form.Set("files", `[]`) req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=%2F.."), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPathInvalid) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileName), nil) //nolint:goconst assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) contentDisposition = rr.Header().Get("Content-Disposition") assert.NotEmpty(t, contentDisposition) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+subDir+"%2F"+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) 
checkResponseCode(t, http.StatusOK, rr) contentDisposition = rr.Header().Get("Content-Disposition") assert.NotEmpty(t, contentDisposition) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=missing"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path=missing"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=missing"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=missing"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileNameLink), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileNameLink), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "non regular files are not supported for shares") req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path="+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=missing"), nil) assert.NoError(t, err) 
rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=%2F.."), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPathInvalid) fakePDF := []byte(`%PDF-1.6`) for i := 0; i < 128; i++ { fakePDF = append(fakePDF, []byte(fmt.Sprintf("%d", i))...) } pdfPath := filepath.Join(user.GetHomeDir(), shareDir, "test.pdf") pdfLinkPath := filepath.Join(user.GetHomeDir(), shareDir, "link.pdf") err = os.WriteFile(pdfPath, fakePDF, 0666) assert.NoError(t, err) err = os.Symlink(pdfPath, pdfLinkPath) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "viewpdf?path=test.pdf"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=test.pdf"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) s, err := dataprovider.ShareExists(objectID, defaultUsername) assert.NoError(t, err) usedTokens := s.UsedTokens assert.Greater(t, usedTokens, 0) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=link.pdf"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // downloading a symlink will fail, usage should not change s, err = dataprovider.ShareExists(objectID, defaultUsername) assert.NoError(t, err) assert.Equal(t, usedTokens, 
s.UsedTokens) // share a symlink share = dataprovider.Share{ Name: "test share browse", Scope: dataprovider.ShareScopeRead, Paths: []string{path.Join(shareDir, testFileNameLink)}, MaxTokens: 0, } asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) // uncompressed download should not work req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, "application/zip", rr.Header().Get("Content-Type")) // this share is not browsable, it does not contains a directory req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) form = make(url.Values) form.Set("files", `[]`) req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial"), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowseNoDir) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path="+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowseNoDir) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) 
checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "the shared object is not a directory and so it is not browsable") req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowseNoDir) // now test a missing shareID objectID = "123456" req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileName), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) form = make(url.Values) form.Set("files", `[]`) req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=%2F"), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "viewpdf?path=p"), nil) 
assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=p"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) // share a missing base path share = dataprovider.Share{ Name: "test share", Scope: dataprovider.ShareScopeRead, Paths: []string{path.Join(shareDir, "missingdir")}, MaxTokens: 0, } asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "unable to check the share directory") // share multiple paths share = dataprovider.Share{ Name: "test share", Scope: dataprovider.ShareScopeRead, Paths: []string{shareDir, "/anotherdir"}, MaxTokens: 0, } asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowsePaths) share = dataprovider.Share{ Name: "test share rw", Scope: dataprovider.ShareScopeReadWrite, Paths: []string{"/missingdir"}, MaxTokens: 0, } asJSON, 
err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F"), nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "unable to check the share directory") share = dataprovider.Share{ Name: "test share rw", Scope: dataprovider.ShareScopeReadWrite, Paths: []string{shareDir}, MaxTokens: 0, } asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F.."), nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Invalid path") // share the root path share = dataprovider.Share{ Name: "test share root", Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, MaxTokens: 0, } asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodGet, 
path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) contents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 1) // if we require two-factor auth for HTTP protocol the share should not work anymore user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH} _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH, common.ProtocolHTTP} _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "two-factor authentication requirements not met") user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH} _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) // share read/write share.Scope = dataprovider.ShareScopeReadWrite asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) objectID = rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, 
"browse?path=%2F"), nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // on upload we should be redirected req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "upload"), nil) assert.NoError(t, err) req.SetBasicAuth(defaultUsername, defaultPassword) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) location := rr.Header().Get("Location") assert.Equal(t, path.Join(webClientPubSharesPath, objectID, "browse"), location) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestUserAPIShareErrors(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) share := dataprovider.Share{ Scope: 1000, } asJSON, err := json.Marshal(share) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "invalid scope") // invalid json req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer([]byte("{"))) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) share.Scope = dataprovider.ShareScopeWrite asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "at least a shared path is required") share.Paths = []string{"path1", "../path1", "/path2"} asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = 
http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "the write share scope requires exactly one path") share.Paths = []string{"", ""} asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "at least a shared path is required") share.Paths = []string{"path1", "../path1", "/path1"} share.Password = redactedSecret asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "cannot save a share with a redacted password") share.Password = "newpass" share.AllowFrom = []string{"not valid"} asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "could not parse allow from entry") share.AllowFrom = []string{"127.0.0.1/8"} share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour)) asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "expiration must be in the future") share.ExpiresAt = 
util.GetTimeAsMsSinceEpoch(time.Now().Add(12 * time.Hour)) asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) location := rr.Header().Get("Location") asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "name is mandatory") // invalid json req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer([]byte("}"))) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodGet, userSharesPath+"?limit=a", nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestUserAPIShares(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) u := getTestUser() u.Username = altAdminUsername user1, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) token1, err := getJWTAPIUserTokenFromTestServer(user1.Username, defaultPassword) assert.NoError(t, err) // the share username will be set from the claims share := dataprovider.Share{ Name: "share1", Description: "description1", Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, CreatedAt: 1, UpdatedAt: 2, LastUseAt: 3, ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(2 * 
time.Hour)), Password: defaultPassword, MaxTokens: 10, UsedTokens: 2, AllowFrom: []string{"192.168.1.0/24"}, } asJSON, err := json.Marshal(share) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) location := rr.Header().Get("Location") assert.NotEmpty(t, location) objectID := rr.Header().Get("X-Object-ID") assert.NotEmpty(t, objectID) assert.Equal(t, fmt.Sprintf("%v/%v", userSharesPath, objectID), location) req, err = http.NewRequest(http.MethodGet, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var shareGet dataprovider.Share err = json.Unmarshal(rr.Body.Bytes(), &shareGet) assert.NoError(t, err) assert.Equal(t, objectID, shareGet.ShareID) assert.Equal(t, share.Name, shareGet.Name) assert.Equal(t, share.Description, shareGet.Description) assert.Equal(t, share.Scope, shareGet.Scope) assert.Equal(t, share.Paths, shareGet.Paths) assert.Equal(t, int64(0), shareGet.LastUseAt) assert.Greater(t, shareGet.CreatedAt, share.CreatedAt) assert.Greater(t, shareGet.UpdatedAt, share.UpdatedAt) assert.Equal(t, share.ExpiresAt, shareGet.ExpiresAt) assert.Equal(t, share.MaxTokens, shareGet.MaxTokens) assert.Equal(t, 0, shareGet.UsedTokens) assert.Equal(t, share.Paths, shareGet.Paths) assert.Equal(t, redactedSecret, shareGet.Password) req, err = http.NewRequest(http.MethodGet, location, nil) assert.NoError(t, err) setBearerForReq(req, token1) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) s, err := dataprovider.ShareExists(objectID, defaultUsername) assert.NoError(t, err) match, err := s.CheckCredentials(defaultPassword) assert.True(t, match) assert.NoError(t, err) match, err = s.CheckCredentials(defaultPassword + "mod") assert.False(t, match) assert.Error(t, err) shareGet.ExpiresAt = 
util.GetTimeAsMsSinceEpoch(time.Now().Add(3 * time.Hour)) asJSON, err = json.Marshal(shareGet) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) s, err = dataprovider.ShareExists(objectID, defaultUsername) assert.NoError(t, err) match, err = s.CheckCredentials(defaultPassword) assert.True(t, match) assert.NoError(t, err) match, err = s.CheckCredentials(defaultPassword + "mod") assert.False(t, match) assert.Error(t, err) req, err = http.NewRequest(http.MethodGet, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var shareGetNew dataprovider.Share err = json.Unmarshal(rr.Body.Bytes(), &shareGetNew) assert.NoError(t, err) assert.NotEqual(t, shareGet.UpdatedAt, shareGetNew.UpdatedAt) shareGet.UpdatedAt = shareGetNew.UpdatedAt assert.Equal(t, shareGet, shareGetNew) req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var shares []dataprovider.Share err = json.Unmarshal(rr.Body.Bytes(), &shares) assert.NoError(t, err) if assert.Len(t, shares, 1) { assert.Equal(t, shareGetNew, shares[0]) } err = dataprovider.UpdateShareLastUse(&shareGetNew, 2) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) shareGetNew = dataprovider.Share{} err = json.Unmarshal(rr.Body.Bytes(), &shareGetNew) assert.NoError(t, err) assert.Equal(t, 2, shareGetNew.UsedTokens, "share: %v", shareGetNew) assert.Greater(t, shareGetNew.LastUseAt, int64(0), "share: %v", shareGetNew) req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, token1) rr 
= executeRequest(req) checkResponseCode(t, http.StatusOK, rr) shares = nil err = json.Unmarshal(rr.Body.Bytes(), &shares) assert.NoError(t, err) assert.Len(t, shares, 0) // set an empty password shareGet.Password = "" asJSON, err = json.Marshal(shareGet) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) shareGetNew = dataprovider.Share{} err = json.Unmarshal(rr.Body.Bytes(), &shareGetNew) assert.NoError(t, err) assert.Empty(t, shareGetNew.Password) req, err = http.NewRequest(http.MethodDelete, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) share.Name = "" asJSON, err = json.Marshal(share) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) location = rr.Header().Get("Location") _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) // the share should be deleted with the associated user req, err = http.NewRequest(http.MethodGet, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodDelete, location, nil) assert.NoError(t, err) setBearerForReq(req, token) rr = executeRequest(req) 
checkResponseCode(t, http.StatusNotFound, rr)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
}

// TestUsersAPISharesNoPasswordDisabled checks the user REST API share endpoints
// for a user whose web client filters contain sdk.WebClientShareNoPasswordDisabled:
// creating a share without a password, or updating an existing share to clear
// its password, must fail with 403. It also checks that share passwords are
// validated on creation (400 for defaultPassword, 201 for a stronger one).
func TestUsersAPISharesNoPasswordDisabled(t *testing.T) {
	u := getTestUser()
	u.Filters.WebClient = []string{sdk.WebClientShareNoPasswordDisabled}
	// NOTE(review): presumably this minimum strength is what makes
	// defaultPassword unacceptable as a share password below — confirm.
	u.Filters.PasswordStrength = 70
	u.Password = "ahpoo8baa6EeZieshies"
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, u.Password)
	assert.NoError(t, err)
	share := dataprovider.Share{
		Name:  "s",
		Scope: dataprovider.ShareScopeRead,
		Paths: []string{"/"},
	}
	// Without a password the share cannot be created.
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "You are not authorized to share files/folders without a password")
	// defaultPassword is rejected as invalid input.
	share.Password = defaultPassword
	asJSON, err = json.Marshal(share)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// A stronger password is accepted and the share is created.
	share.Password = "vi5eiJoovee5ya9yahpi"
	asJSON, err = json.Marshal(share)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	location := rr.Header().Get("Location")
	assert.NotEmpty(t, location)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	assert.Equal(t, fmt.Sprintf("%v/%v", userSharesPath, objectID), location)
	// Clearing the password through an update is forbidden as well.
	share.Password = ""
	asJSON, err = json.Marshal(share)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "You are not authorized to share files/folders without a password")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUserAPIKey checks API key authentication against the user REST API.
// A user-scoped key bound to a user that was never created is rejected (400).
// A key bound to the test user can upload and list files. An admin key presented
// with a username gets 403 on the user endpoints. Requests are refused with 401
// in these cases:
//   - the user is disabled or denies the HTTP protocol;
//   - a user-scoped key with no associated user is sent with an empty or
//     unknown username;
//   - the key has expired.
func TestUserAPIKey(t *testing.T) {
	u := getTestUser()
	u.Filters.AllowAPIKeyAuth = true
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// The referenced user (username + "1") is not created by this test.
	apiKey := dataprovider.APIKey{
		Name:  "testkey",
		User:  user.Username + "1",
		Scope: dataprovider.APIKeyScopeUser,
	}
	_, _, err = httpdtest.AddAPIKey(apiKey, http.StatusBadRequest)
	assert.NoError(t, err)
	apiKey.User = user.Username
	apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err)
	adminAPIKey := dataprovider.APIKey{
		Name:  "testadminkey",
		Scope: dataprovider.APIKeyScopeAdmin,
	}
	adminAPIKey, _, err = httpdtest.AddAPIKey(adminAPIKey, http.StatusCreated)
	assert.NoError(t, err)
	// Build a multipart body with a single file to upload using the user key.
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "filenametest")
	assert.NoError(t, err)
	_, err = part.Write([]byte("test file content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setAPIKeyForReq(req, apiKey.Key, "")
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var dirEntries []map[string]any
	err = json.Unmarshal(rr.Body.Bytes(), &dirEntries)
	assert.NoError(t, err)
assert.Len(t, dirEntries, 1) req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, adminAPIKey.Key, user.Username) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) user.Status = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKey.Key, "") rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) user.Status = 1 user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKey.Key, "") rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) user.Filters.DeniedProtocols = []string{common.ProtocolFTP} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKey.Key, "") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) apiKeyNew := dataprovider.APIKey{ Name: apiKey.Name, Scope: dataprovider.APIKeyScopeUser, } apiKeyNew, _, err = httpdtest.AddAPIKey(apiKeyNew, http.StatusCreated) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKeyNew.Key, "") rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) // now associate a user req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKeyNew.Key, user.Username) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // now with a missing user req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) assert.NoError(t, err) setAPIKeyForReq(req, apiKeyNew.Key, 
user.Username+"1")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	// empty user and key not associated to any user
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKeyNew.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	// Move the key expiration 24 hours into the past.
	apiKeyNew.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
	_, _, err = httpdtest.UpdateAPIKey(apiKeyNew, http.StatusOK)
	assert.NoError(t, err)
	// expired API key
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKeyNew.Key, user.Username)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	// NOTE(review): apiKey (bound to the user) is not removed explicitly;
	// presumably it is deleted together with the user — confirm.
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKeyNew, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(adminAPIKey, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebClientExistenceCheck exercises the web client endpoint at
// webClientExistPath. The endpoint receives a JSON object with a "files" list
// and a "path" query parameter. It answers with a JSON array whose length
// equals the number of listed names that exist in that directory. The test also
// covers CSRF enforcement, invalid bodies, a missing directory, and a user
// that denies the HTTP protocol.
func TestWebClientExistenceCheck(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, webClientExistPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr) // no CSRF header
	// With the CSRF header, a JSON array body is rejected with 400.
	// NOTE(review): presumably because the handler expects a JSON object.
	req, err = http.NewRequest(http.MethodPost, webClientExistPath, bytes.NewBuffer([]byte(`[]`)))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// A null "files" list is rejected too.
	filesToCheck := make(map[string]any)
	filesToCheck["files"] = nil
	asJSON, err := json.Marshal(filesToCheck)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientExistPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "files to be checked are mandatory")
	testFileName := "file.dat"
	testDirName := "adirname"
	filesToCheck["files"] = []string{testFileName}
	asJSON, err = json.Marshal(filesToCheck)
	assert.NoError(t, err)
	// The target directory does not exist.
	req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2Fmissingdir", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// Nothing has been created yet, so there are no matches in "/".
	req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var fileList []any
	err = json.Unmarshal(rr.Body.Bytes(), &fileList)
	assert.NoError(t, err)
	assert.Len(t, fileList, 0)
	err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName), 100)
	assert.NoError(t, err)
	err = os.Mkdir(filepath.Join(user.GetHomeDir(), testDirName), 0755)
	assert.NoError(t, err)
	// The file now exists in "/".
	req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	fileList = nil
	err = json.Unmarshal(rr.Body.Bytes(), &fileList)
	assert.NoError(t, err)
	assert.Len(t, fileList, 1)
	// Both the file and the directory are reported.
	filesToCheck["files"] = []string{testFileName, testDirName}
	asJSON, err = json.Marshal(filesToCheck)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	fileList = nil
	err = json.Unmarshal(rr.Body.Bytes(), &fileList)
	assert.NoError(t, err)
	assert.Len(t, fileList, 2)
	// Neither name exists inside the newly created, still empty, subdirectory.
	req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F"+testDirName, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	fileList = nil
	err = json.Unmarshal(rr.Body.Bytes(), &fileList)
	assert.NoError(t, err)
	assert.Len(t, fileList, 0)
	// Denying the HTTP protocol makes the endpoint answer 403.
	user.Filters.DeniedProtocols = []string{common.ProtocolHTTP}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientExistPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebClientViewPDF checks webClientViewPDFPath and webClientGetPDFPath. The
// cases are:
//   - a missing path parameter;
//   - a missing file and a directory;
//   - content that is not a PDF;
//   - content starting with a "%PDF-1.6" header, which is served with 200;
//   - a denied "*.pdf" pattern;
//   - a denied HTTP protocol;
//   - a request made after the user is removed.
func TestWebClientViewPDF(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// Without a path query parameter both endpoints answer 400.
	req, err := http.NewRequest(http.MethodGet, webClientViewPDFPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, webClientViewPDFPath+"?path=test.pdf", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err =
http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=test.pdf", nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2F", nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "test.pdf"), []byte("some text data"), 0666) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) //nolint:goconst assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) err = createTestFile(filepath.Join(user.GetHomeDir(), "test.pdf"), 1024) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) fakePDF := []byte(`%PDF-1.6`) for i := 0; i < 128; i++ { fakePDF = append(fakePDF, []byte(fmt.Sprintf("%d", i))...) 
}
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), "test.pdf"), fakePDF, 0666)
	assert.NoError(t, err)
	// Content starting with the "%PDF-1.6" header is served successfully.
	req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// A denied "*.pdf" pattern forbids access to the file.
	user.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.pdf"},
		},
	}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError403Message)
	// The pattern no longer matches PDFs, but the HTTP protocol is now denied.
	user.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.txt"},
		},
	}
	user.Filters.DeniedProtocols = []string{common.ProtocolHTTP}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// The previously issued token is still sent after the user is removed.
	req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestWebEditFile checks the web client edit page at webClientEditFilePath.
//   - A 64 KiB file can be opened.
//   - A 5 MiB file, a missing path and a directory are each rejected with 400
//     (I18nErrorEditSize, I18nErrorFsGeneric and I18nErrorEditDir respectively).
//   - Access is forbidden (403) when the HTTP protocol is denied or a denied
//     pattern matches the file.
//   - Once the user is removed, the page answers 404.
func TestWebEditFile(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	testFile1 := "testfile1.txt"
	testFile2 := "testfile2"
	file1Size := int64(65536)
	// 5 MiB: expected to be rejected with I18nErrorEditSize below.
	file2Size := int64(1048576 * 5)
	err = createTestFile(filepath.Join(user.GetHomeDir(), testFile1), file1Size)
	assert.NoError(t, err)
	err = createTestFile(filepath.Join(user.GetHomeDir(), testFile2), file2Size)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile2, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorEditSize)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=missing", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=%2F", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorEditDir)
	user.Filters.DeniedProtocols = []string{common.ProtocolHTTP}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// Re-allow HTTP but deny "*.txt", which matches testFile1.
	user.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	user.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.txt"},
		},
	}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError403Message)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestWebGetFiles exercises file and directory retrieval through the web client
// (JWT cookie) and the user REST API (bearer token). It covers:
//   - directory listings;
//   - zip downloads;
//   - plain, ranged and conditional file downloads;
//   - the 403 responses returned once the user denies the HTTP protocol or the
//     password login method.
func TestWebGetFiles(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	testFileName := "testfile"
	testDir := "testdir"
	testFileContents := []byte("file contents")
	err = os.MkdirAll(filepath.Join(user.GetHomeDir(), testDir), os.ModePerm)
	assert.NoError(t, err)
	// Create one file per extension (including no extension) in the home directory.
	extensions := []string{"", ".doc", ".ppt", ".xls", ".pdf", ".mkv", ".png", ".go", ".zip", ".txt"}
	for _, ext := range extensions {
		err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName+ext), testFileContents, os.ModePerm)
		assert.NoError(t, err)
	}
	// testDir contains only a symlink pointing to testfile.doc.
	err = os.Symlink(filepath.Join(user.GetHomeDir(), testFileName+".doc"), filepath.Join(user.GetHomeDir(), testDir, testFileName+".link"))
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testDir, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path="+testDir, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var dirContents []map[string]any
	err = json.Unmarshal(rr.Body.Bytes(),
&dirContents) assert.NoError(t, err) assert.Len(t, dirContents, 1) req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?dirtree=1&path="+testDir, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) dirContents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &dirContents) assert.NoError(t, err) assert.Len(t, dirContents, 0) req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var dirEntries []map[string]any err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) assert.NoError(t, err) assert.Len(t, dirEntries, 1) csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) form := make(url.Values) form.Set("files", fmt.Sprintf(`["%s","%s","%s"]`, testFileName, testDir, testFileName+extensions[2])) req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) // add csrf token form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // parse form error req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path=p%C3%AO%GK", bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) 
checkResponseCode(t, http.StatusBadRequest, rr) filesList := []string{testFileName, testDir, testFileName + extensions[2]} asJSON, err := json.Marshal(filesList) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer(asJSON)) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer([]byte(`file`))) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("files", fmt.Sprintf(`["%v"]`, testDir)) req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("files", "notalist") req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nError400Message) req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path=/", nil) //nolint:goconst setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) dirContents = nil err = json.Unmarshal(rr.Body.Bytes(), &dirContents) assert.NoError(t, err) assert.Len(t, dirContents, len(extensions)+1) req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path=/", nil) setBearerForReq(req, webAPIToken) rr = 
executeRequest(req) checkResponseCode(t, http.StatusOK, rr) dirEntries = nil err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) assert.NoError(t, err) assert.Len(t, dirEntries, len(extensions)+1) req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path=/missing", nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorDirListGeneric) req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path=missing", nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) assert.Contains(t, rr.Body.String(), "Unable to get directory lister") req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, testFileContents, rr.Body.Bytes()) req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, testFileContents, rr.Body.Bytes()) req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path=", nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Please set the path to a valid file") req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testDir, nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "is a directory") req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path=notafile", nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) assert.Contains(t, rr.Body.String(), "Unable to stat the requested file") req, _ = http.NewRequest(http.MethodGet, 
webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=2-") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPartialContent, rr) assert.Equal(t, testFileContents[2:], rr.Body.Bytes()) lastModified, err := http.ParseTime(rr.Header().Get("Last-Modified")) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=2-") setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPartialContent, rr) assert.Equal(t, testFileContents[2:], rr.Body.Bytes()) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=-2") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPartialContent, rr) assert.Equal(t, testFileContents[11:], rr.Body.Bytes()) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=-2,") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusRequestedRangeNotSatisfiable, rr) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=1a-") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusRequestedRangeNotSatisfiable, rr) req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=2b-") setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusRequestedRangeNotSatisfiable, rr) req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=2-") req.Header.Set("If-Range", lastModified.UTC().Format(http.TimeFormat)) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPartialContent, rr) req, _ = 
http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=2-") req.Header.Set("If-Range", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("If-Modified-Since", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("If-Modified-Since", lastModified.UTC().Add(120*time.Second).Format(http.TimeFormat)) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotModified, rr) req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("If-Unmodified-Since", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPreconditionFailed, rr) req, _ = http.NewRequest(http.MethodHead, userFilesPath+"?path="+testFileName, nil) req.Header.Set("If-Unmodified-Since", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPreconditionFailed, rr) req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("If-Unmodified-Since", lastModified.UTC().Add(120*time.Second).Format(http.TimeFormat)) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} _, resp, err := httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(resp)) req, _ = http.NewRequest(http.MethodGet, 
webClientFilesPath, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path=/", nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) filesList = []string{testDir} asJSON, err = json.Marshal(filesList) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer(asJSON)) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) user.Filters.DeniedProtocols = []string{common.ProtocolFTP} user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} _, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(resp)) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) form = make(url.Values) form.Set("files", `[]`) form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) _, err = httpdtest.RemoveUser(user, 
http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRenameDifferentResource moves a file from the user's home directory into
// a virtual folder backed by a crypted filesystem. The expected error messages
// ("Cannot perform copy step", "Cannot perform remove step") show that such a
// move is performed as a copy followed by a remove. The test checks a missing
// source, a successful move, and failures when the copy or the delete
// permission is missing, through both the REST API and the asynchronous web
// client move endpoint.
func TestRenameDifferentResource(t *testing.T) {
	folderName := "foldercryptfs"
	// fixed directory backing the crypted folder; removed at the end of the test
	mappedPath := filepath.Join(os.TempDir(), "folderName")
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret("super secret"),
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: "/folderPath",
	})
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFileName := "file.txt"
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken)
	assert.NoError(t, err)
	// getStatusResponse reads the web client task identified by taskID and
	// returns the HTTP status stored in it, or -1 if the task cannot be read
	// or decoded. testify's assert.Eventually evaluates its condition in a
	// separate goroutine, so this closure must only use local variables and
	// never write to the test's shared err.
	getStatusResponse := func(taskID string) int {
		req, _ := http.NewRequest(http.MethodGet, webClientTasksPath+"/"+url.PathEscape(taskID), nil)
		req.RemoteAddr = defaultRemoteAddr
		req.Header.Set("X-CSRF-TOKEN", csrfToken)
		setJWTCookieForReq(req, webToken)
		rr := executeRequest(req)
		if rr.Code != http.StatusOK {
			return -1
		}
		resp := make(map[string]any)
		if jsonErr := json.Unmarshal(rr.Body.Bytes(), &resp); jsonErr != nil {
			return -1
		}
		// JSON numbers decode to float64; a missing or non-numeric field is not a valid status
		status, ok := resp["status"].(float64)
		if !ok {
			return -1
		}
		return int(status)
	}
	// the source file does not exist yet, so the copy step fails
	req, err := http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Cannot perform copy step")
	// the asynchronous web client move of the missing file ends with a 404 task status
	req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	taskResp := make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &taskResp)
	assert.NoError(t, err)
	taskID := taskResp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusNotFound
	}, 1000*time.Millisecond, 100*time.Millisecond)
	// create the source file: the move now succeeds
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), []byte("just a test"), os.ModePerm)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// recreate the file and drop both the delete and the copy permissions
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), []byte("just another test"), os.ModePerm)
	assert.NoError(t, err)
	u.Permissions = map[string][]string{
		"/": {dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermCreateDirs,
			dataprovider.PermDownload, dataprovider.PermOverwrite},
	}
	_, resp, err := httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err, string(resp))
	// refresh the API token after updating the user
	webAPIToken, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Cannot perform copy step")
	// add the copy permission: the copy step passes and the remove step fails
	u.Permissions = map[string][]string{
		"/": {dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermCreateDirs,
			dataprovider.PermDownload, dataprovider.PermOverwrite, dataprovider.PermCopy},
	}
	_, resp, err = httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err, string(resp))
	webAPIToken, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Cannot perform remove step")
	// the asynchronous web client move fails the same way, with a 403 task status
	req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	taskResp = make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &taskResp)
	assert.NoError(t, err)
	taskID = taskResp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusForbidden
	}, 1000*time.Millisecond, 100*time.Millisecond)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	// the moved file lives in the crypted folder's directory: don't leave it behind
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestWebDirsAPI exercises directory listing, creation, copy, move and delete
// through the user REST API. It covers unsupported same-name operations,
// missing parents, a reduced permission set and requests made after the user
// has been deleted.
func TestWebDirsAPI(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t,
err)
	testDir := "testdir"
	// the home directory is empty right after the user is created
	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var contents []map[string]any
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 0)
	// rename a missing folder
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir+"new", nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// copy a missing folder
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"%2F&target="+testDir+"new%2F", nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// delete a missing folder
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// create a dir
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// check the dir was created
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	if assert.Len(t, contents, 1) {
		assert.Equal(t, testDir, contents[0]["name"])
	}
	// rename a dir with the same source and target name
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// same check with the target written as an absolute path with a trailing slash
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target=%2F"+testDir+"%2F", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// copy a dir with the same source and target name
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"&target="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// create a dir with missing parents
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+url.QueryEscape(path.Join("/sub/dir", testDir)), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// setting the mkdir_parents param will work
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?mkdir_parents=true&path="+url.QueryEscape(path.Join("/sub/dir", testDir)), nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// copy the dir
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"&target="+testDir+"copy", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// rename the dir
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir+"new", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// delete the renamed dir
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir+"new", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// delete the copy as well
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir+"copy", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the root dir cannot be created
	req, err = http.NewRequest(http.MethodPost, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	// restrict the user to listing only
	user.Permissions["/"] = []string{dataprovider.PermListItems}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// the user no longer has the permission to create the directory
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the user is deleted, any API call should fail
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir+"new", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"&target="+testDir+"new", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir+"new", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestWebUploadSingleFile exercises the single-file upload endpoint
// (userUploadFilePath). It covers the X-SFTPGO-MTIME header, the mkdir_parents
// parameter and the PATCH metadata endpoint (userFilesDirsMetadataPath),
// including validation errors and requests made after the user is removed.
func TestWebUploadSingleFile(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	content := []byte("test content")
	// the path query parameter is required
	req, err := http.NewRequest(http.MethodPost, userUploadFilePath, bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "please set a file path")
	// X-SFTPGO-MTIME, in milliseconds since the epoch, sets the uploaded file's modification time
	modTime := time.Now().Add(-24 * time.Hour)
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", bytes.NewBuffer(content)) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	req.Header.Set("X-SFTPGO-MTIME", strconv.FormatInt(util.GetTimeAsMsSinceEpoch(modTime), 10))
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	info, err := os.Stat(filepath.Join(user.GetHomeDir(), "file.txt"))
	if assert.NoError(t, err) {
		assert.InDelta(t, util.GetTimeAsMsSinceEpoch(modTime), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(1000))
	}
	// invalid modification time will be ignored
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	req.Header.Set("X-SFTPGO-MTIME", "123abc")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	info, err = os.Stat(filepath.Join(user.GetHomeDir(), "file.txt"))
	if assert.NoError(t, err) {
		assert.InDelta(t, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(3000))
	}
	// upload to a missing dir will fail without the mkdir_parents param
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+url.QueryEscape("/subdir/file.txt"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?mkdir_parents=true&path="+url.QueryEscape("/subdir/file.txt"), bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// set the modification time through the metadata endpoint
	metadataReq := make(map[string]int64)
	metadataReq["modification_time"] = util.GetTimeAsMsSinceEpoch(modTime)
	asJSON, err := json.Marshal(metadataReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	info, err = os.Stat(filepath.Join(user.GetHomeDir(), "file.txt"))
	if assert.NoError(t, err) {
		assert.InDelta(t, util.GetTimeAsMsSinceEpoch(modTime), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(1000))
	}
	// missing file
	req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file2.txt", bytes.NewBuffer(asJSON)) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to set metadata for path")
	// invalid JSON
	req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// missing mandatory parameter
	req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "please set a modification_time and a path")
	// an empty JSON object carries no modification_time
	metadataReq = make(map[string]int64)
	asJSON, err = json.Marshal(metadataReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "please set a modification_time and a path")
	// the parent dir "/dir" was never created and mkdir_parents is not set
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=%2Fdir%2Ffile.txt", bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to write file")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// requests made after the user is removed fail
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", bytes.NewBuffer(content))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to retrieve your user")
	metadataReq["modification_time"] = util.GetTimeAsMsSinceEpoch(modTime)
	asJSON, err = json.Marshal(metadataReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to retrieve your user")
}

// TestWebFilesAPI exercises multipart uploads to userFilesPath, directory
// listing, download, and copy/move/delete of files. It also deletes a symlink
// pointing outside the home directory, checks a reduced permission set, and
// makes requests after the user has been removed.
func TestWebFilesAPI(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// multipart body with two files in the "filenames" field, reused by several uploads below
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
part1, err := writer.CreateFormFile("filenames", "file1.txt")
	assert.NoError(t, err)
	_, err = part1.Write([]byte("file1 content"))
	assert.NoError(t, err)
	part2, err := writer.CreateFormFile("filenames", "file2.txt")
	assert.NoError(t, err)
	_, err = part2.Write([]byte("file2 content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	// without the multipart Content-Type header the form cannot be parsed
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "Unable to parse multipart form")
	// the same reader is reused, so it is rewound before each new upload
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	// the rejected upload left the first upload/download timestamps unset
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), user.FirstUpload)
	assert.Equal(t, int64(0), user.FirstDownload)
	// set the proper content type
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// a successful upload records only the first upload time
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, user.FirstUpload, int64(0))
	assert.Equal(t, int64(0), user.FirstDownload)
	// check we have 2 files
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var contents []map[string]any
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 2)
	// download a file
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=file1.txt", nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "file1 content", rr.Body.String())
	// the download records the first download time
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, user.FirstUpload, int64(0))
	assert.Greater(t, user.FirstDownload, int64(0))
	// overwrite the existing files
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 2)
	// now create a dir and upload to that dir
	testDir := "tdir"
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path="+testDir, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// upload to a missing subdir will fail without the mkdir_parents param
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path="+url.QueryEscape("/sub/"+testDir), reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?mkdir_parents=true&path="+url.QueryEscape("/sub/"+testDir), reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// the root now holds file1.txt, file2.txt, tdir and sub
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 4)
	req, err = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 2)
	// copy a file
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path=file1.txt&target=%2Ftdir%2Ffile_copy.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// rename a file
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?target=%2Ftdir%2Ffile3.txt&path=file1.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// rename a missing file
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=file1.txt&target=%2Ftdir%2Ffile3.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// rename a file with target name equal to source name
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=file1.txt&target=file1.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// delete a file
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file2.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// delete a missing file
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file2.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// delete a directory: the files endpoint rejects it
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=tdir", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// make a symlink outside the home dir and then try to delete it
	extPath := filepath.Join(os.TempDir(), "file")
	err = os.WriteFile(extPath, []byte("contents"), os.ModePerm)
	assert.NoError(t, err)
	err = os.Symlink(extPath, filepath.Join(user.GetHomeDir(), "file"))
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	err = os.Remove(extPath)
	assert.NoError(t, err)
	// remove delete and overwrite permissions
	user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// tdir already contains these files, so re-uploading them without overwrite permission is forbidden
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=tdir", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// deleting is forbidden too
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=%2Ftdir%2Ffile1.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the user is deleted, any API call should fail
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=file1.txt&target=%2Ftdir%2Ffile3.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file2.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestBufferedWebFilesAPI uploads a file and downloads it, in full and with
// Range headers, with small read/write buffer sizes configured. It covers both
// the user's home directory and a crypted virtual folder that has its own
// buffer sizes.
func TestBufferedWebFilesAPI(t *testing.T) {
	u := getTestUser()
	u.FsConfig.OSConfig = sdk.OSFsConfig{
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	vdirPath := "/crypted"
	mappedPath := filepath.Join(os.TempDir(), util.GenerateUniqueID())
	folderName := filepath.Base(mappedPath)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				OSFsConfig: sdk.OSFsConfig{
					WriteBufferSize: 3,
					ReadBufferSize:  2,
				},
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err :=
getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// single-file multipart body, uploaded to the home dir and to the crypted folder
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part1, err := writer.CreateFormFile("filenames", "file1.txt")
	assert.NoError(t, err)
	_, err = part1.Write([]byte("file1 content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path="+url.QueryEscape(vdirPath), reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// full downloads
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=file1.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "file1 content", rr.Body.String())
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+url.QueryEscape(vdirPath+"/file1.txt"), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, "file1 content", rr.Body.String())
	// ranged downloads
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=file1.txt", nil)
	assert.NoError(t, err)
	req.Header.Set("Range", "bytes=2-")
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusPartialContent, rr)
	assert.Equal(t, "le1 content", rr.Body.String())
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+url.QueryEscape(vdirPath+"/file1.txt"), nil)
	assert.NoError(t, err)
	req.Header.Set("Range", "bytes=3-6")
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusPartialContent, rr)
	assert.Equal(t, "e1 c", rr.Body.String())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestWebClientTasksAPI starts asynchronous copy, move and delete operations
// from the web client and polls their status through webClientTasksPath. It
// also checks that a different user cannot read the task, and that the
// endpoints return 404 once the user has been removed.
func TestWebClientTasksAPI(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	u1 := getTestUser()
	u1.Username = xid.New().String()
	user1, _, err := httpdtest.AddUser(u1, http.StatusCreated)
	assert.NoError(t, err)
	testDir := "subdir"
	testFileData := []byte("data")
	testFilePath := filepath.Join(user.GetHomeDir(), testDir, "file.txt")
	testFileName := filepath.Base(testFilePath)
	err = os.MkdirAll(filepath.Dir(testFilePath), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(testFilePath, testFileData, 0666)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken)
	assert.NoError(t, err)
	webToken1, err := getJWTWebClientTokenFromTestServer(user1.Username, defaultPassword)
	assert.NoError(t, err)
	// getStatusResponse reads the web client task identified by taskID and
	// returns the HTTP status stored in it, or -1 if the task cannot be read
	// or decoded. testify's assert.Eventually evaluates its condition in a
	// separate goroutine, so this closure must only use local variables and
	// never write to the test's shared err.
	getStatusResponse := func(taskID string) int {
		req, _ := http.NewRequest(http.MethodGet, webClientTasksPath+"/"+url.PathEscape(taskID), nil)
		req.RemoteAddr = defaultRemoteAddr
		req.Header.Set("X-CSRF-TOKEN", csrfToken)
		setJWTCookieForReq(req, webToken)
		rr := executeRequest(req)
		if rr.Code != http.StatusOK {
			return -1
		}
		resp := make(map[string]any)
		if jsonErr := json.Unmarshal(rr.Body.Bytes(), &resp); jsonErr != nil {
			return -1
		}
		// JSON numbers decode to float64; a missing or non-numeric field is not a valid status
		status, ok := resp["status"].(float64)
		if !ok {
			return -1
		}
		return int(status)
	}
	// missing task
	assert.Equal(t, -1, getStatusResponse("missing"))
	// copy subdir/file.txt to /file.txt
	req, err := http.NewRequest(http.MethodPost, webClientFileCopyPath+"?path="+
		url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	resp := make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &resp)
	assert.NoError(t, err)
	taskID := resp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusOK
	}, 1000*time.Millisecond, 100*time.Millisecond)
	// cannot get the task with a different user
	req, err = http.NewRequest(http.MethodGet, webClientTasksPath+"/"+url.PathEscape(taskID), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken1)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// move subdir/file.txt over the /file.txt copy
	req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+
		url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	resp = make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &resp)
	assert.NoError(t, err)
	taskID = resp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusOK
	}, 1000*time.Millisecond, 100*time.Millisecond)
	// delete subdir
	req, err = http.NewRequest(http.MethodDelete, webClientDirsPath+"?path="+
		url.QueryEscape(testDir), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	resp = make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &resp)
	assert.NoError(t, err)
	taskID = resp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusOK
	}, 1000*time.Millisecond, 100*time.Millisecond)
	// deleting subdir again: the task ends with 404
	req, err = http.NewRequest(http.MethodDelete, webClientDirsPath+"?path="+
		url.QueryEscape(testDir), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	resp = make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &resp)
	assert.NoError(t, err)
	taskID = resp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusNotFound
	}, 1000*time.Millisecond, 100*time.Millisecond)
	// the move source no longer exists: the task ends with 404
	req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+
		url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	resp = make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &resp)
	assert.NoError(t, err)
	taskID = resp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusNotFound
	}, 1000*time.Millisecond, 100*time.Millisecond)
	// copying from the deleted directory also ends with 404
	req, err = http.NewRequest(http.MethodPost, webClientFileCopyPath+"?path="+
		url.QueryEscape(path.Join(testDir, testFileName)+"/")+"&target="+url.QueryEscape(testFileName+"/"), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusAccepted, rr)
	resp = make(map[string]any)
	err = json.Unmarshal(rr.Body.Bytes(), &resp)
	assert.NoError(t, err)
	taskID = resp["message"].(string)
	assert.NotEmpty(t, taskID)
	assert.Eventually(t, func() bool {
		status := getStatusResponse(taskID)
		return status == http.StatusNotFound
	}, 1000*time.Millisecond, 100*time.Millisecond)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	// user deleted
	req, err = http.NewRequest(http.MethodDelete, webClientDirsPath+"?path="+
		url.QueryEscape(testDir), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+
		url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPost, webClientFileCopyPath+"?path="+
		url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

func TestStartDirectory(t *testing.T) {
	u := getTestUser()
	u.Filters.StartDirectory = "/start/dir"
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err :=
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	filename := "file1.txt"
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part1, err := writer.CreateFormFile("filenames", filename)
	assert.NoError(t, err)
	_, err = part1.Write([]byte("test content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// check we have exactly 1 file (the one just uploaded) in the defined start dir
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var contents []map[string]any
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	if assert.Len(t, contents, 1) {
		assert.Equal(t, filename, contents[0]["name"].(string))
	}
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file2.txt", bytes.NewBuffer([]byte("single upload content")))
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=testdir", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=testdir&target=testdir1", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// "%2F" makes the path absolute: this directory is created at the real root,
	// outside the start directory
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=%2Ftestdirroot", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// listing the start dir by absolute path: file1.txt, file2.txt and testdir1
	req, err = http.NewRequest(http.MethodGet, userDirsPath+"?path="+url.QueryEscape(u.Filters.StartDirectory), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 3)
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+filename, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the file lives in the start dir, not at the root
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=%2F"+filename, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPatch, userFilesPath+"?path="+filename+"&target="+filename+"_rename", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path=testdir1", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 2)
	// the web client must see the same start-dir listing
	req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 2)
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path="+filename+"_rename", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, userDirsPath+"?path="+url.QueryEscape(u.Filters.StartDirectory), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	contents = nil
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 1)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebFilesTransferQuotaLimits sets UploadDataTransfer and
// DownloadDataTransfer to 1 and checks that the limits are enforced for REST
// API transfers and for transfers through public shares.
// NOTE(review): the unit of the limit is not visible here. Presumably it is MB,
// which would explain why the first 550000-byte upload/download succeed and
// the second ones fail. Confirm against the dataprovider docs.
func TestWebFilesTransferQuotaLimits(t *testing.T) {
	u := getTestUser()
	u.UploadDataTransfer = 1
	u.DownloadDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	testFileName := "file.data"
	testFileSize := 550000
	testFileContents := make([]byte, testFileSize)
	n, err := io.ReadFull(rand.Reader, testFileContents)
	assert.NoError(t, err)
	assert.Equal(t, testFileSize, n)
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", testFileName)
	assert.NoError(t, err)
	_, err = part.Write(testFileContents)
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, testFileContents, rr.Body.Bytes())
	// error while download is active
	// The deferred recover expects the handler to abort mid-response with
	// http.ErrAbortHandler. When that panic happens, the checkResponseCode call
	// at the end of the closure is never reached.
	downloadFunc := func() {
		defer func() {
			rcv := recover()
			assert.Equal(t, http.ErrAbortHandler, rcv)
		}()

		req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil)
		assert.NoError(t, err)
		setBearerForReq(req, webAPIToken)
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
	}
	downloadFunc()
	// error before starting the download
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// error while upload is active
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusRequestEntityTooLarge, rr)
	// error before starting the upload
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusRequestEntityTooLarge, rr)
	// now test upload/download to/from shares
	share1 := dataprovider.Share{
		Name:  "share1",
		Scope: dataprovider.ShareScopeRead,
		Paths: []string{"/"},
	}
	asJSON, err := json.Marshal(share1)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	// share requests carry no bearer token: they are anonymous public-share accesses
	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	form := make(url.Values)
	form.Set("files", `[]`)
	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/partial"), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorQuotaRead)
	share2 := dataprovider.Share{
		Name:  "share2",
		Scope: dataprovider.ShareScopeWrite,
		Paths: []string{"/"},
	}
	asJSON, err = json.Marshal(share2)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID = rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusRequestEntityTooLarge, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebUploadErrors covers REST API upload failures:
//   - a denied file pattern, and missing upload permission
//   - a parent directory that cannot be created
//   - overwriting a directory with a file
//   - a missing parent directory, and an unusable temp path
//   - a read-only home directory (non-Windows only)
//   - a multipart form without files
func TestWebUploadErrors(t *testing.T) {
	u := getTestUser()
	u.QuotaSize = 65535
	subDir1 := "sub1"
	subDir2 := "sub2"
	u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems}
	u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload,
		dataprovider.PermDelete}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/sub2",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{"*.zip"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "file.zip")
	assert.NoError(t, err)
	_, err =
part.Write([]byte("file content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	// zip files are not allowed within sub2
	req, err := http.NewRequest(http.MethodPost, userFilesPath+"?path=sub2", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	// we have no upload permissions within sub1
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=sub1", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// we cannot create dirs in sub2
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?mkdir_parents=true&path="+url.QueryEscape("/sub2/dir"), reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "unable to check/create missing parent dir")
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?mkdir_parents=true&path="+url.QueryEscape("/sub2/dir/test"), nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Error checking parent directories")
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?mkdir_parents=true&path="+url.QueryEscape("/sub2/dir1/file.txt"),
		bytes.NewBuffer([]byte("")))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Error checking parent directories")
	// create a dir and try to overwrite it with a file
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=file.zip", nil) //nolint:goconst
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "operation unsupported")
	// try to upload to a missing parent directory
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=missingdir", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path=file.zip", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// upload will work now
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// overwrite the file
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// point the temp path at a directory that does not exist: the upload must fail with 404
	vfs.SetTempPath(filepath.Join(os.TempDir(), "missingpath"))
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	if runtime.GOOS != osWindows {
		req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.zip", nil)
		assert.NoError(t, err)
		setBearerForReq(req, webAPIToken)
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
		vfs.SetTempPath(filepath.Clean(os.TempDir()))
		// a read-only home dir makes finalizing the upload fail ("Error closing file");
		// permissions are restored at the end of this branch
		err = os.Chmod(user.GetHomeDir(), 0555)
		assert.NoError(t, err)
		_, err = reader.Seek(0, io.SeekStart)
		assert.NoError(t, err)
		req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
		assert.NoError(t, err)
		req.Header.Add("Content-Type", writer.FormDataContentType())
		setBearerForReq(req, webAPIToken)
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusForbidden, rr)
		assert.Contains(t, rr.Body.String(), "Error closing file")
		req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.zip", bytes.NewBuffer(nil))
		assert.NoError(t, err)
		setBearerForReq(req, webAPIToken)
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusForbidden, rr)
		assert.Contains(t, rr.Body.String(), "Error closing file")
		err = os.Chmod(user.GetHomeDir(), os.ModePerm)
		assert.NoError(t, err)
	}
	vfs.SetTempPath("")
	// upload a multipart form with no files
	body = new(bytes.Buffer)
	writer = multipart.NewWriter(body)
	err = writer.Close()
	assert.NoError(t, err)
	reader = bytes.NewReader(body.Bytes())
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=sub2", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "No files uploaded!")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebAPIVFolder uploads the same file twice through the REST API into a
// virtual folder mapped with QuotaSize/QuotaFiles set to -1. Each time it checks
// that the file size is charged to the user's used quota (unchanged after the
// overwrite) while the folder's own UsedQuotaSize stays 0.
func TestWebAPIVFolder(t *testing.T) {
	u := getTestUser()
	u.QuotaSize = 65535
	vdir := "/vdir"
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	folderName := filepath.Base(mappedPath)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdir,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(user.Username, defaultPassword)
	assert.NoError(t, err)
	fileContents := []byte("test contents")
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "file.txt")
	assert.NoError(t, err)
	_, err = part.Write(fileContents)
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath+"?path=vdir", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(len(fileContents)), user.UsedQuotaSize)
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), folder.UsedQuotaSize)
	// overwrite: the used quota must not grow
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=vdir", reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(len(fileContents)), user.UsedQuotaSize)
	folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), folder.UsedQuotaSize)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestWebAPIWritePermission checks that for a user with the
// sdk.WebClientWriteDisabled filter, write operations through the user REST API
// (upload, move, delete, mkdir) return 403. Reads are still served: 404 for a
// missing file and 200 for a directory listing.
func TestWebAPIWritePermission(t *testing.T) {
	u := getTestUser()
	u.Filters.WebClient = append(u.Filters.WebClient, sdk.WebClientWriteDisabled)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "file.txt")
	assert.NoError(t, err)
	_, err = part.Write([]byte(""))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=a&target=b", nil)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=a", nil)
	assert.NoError(t, err)
req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// read operations are still allowed
	req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=a.txt", nil)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=dir", nil)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=dir&target=dir1", nil)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path=dir", nil)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebAPICryptFs checks that a file can be uploaded, and then overwritten,
// through the REST API for a user on sdk.CryptedFilesystemProvider.
func TestWebAPICryptFs(t *testing.T) {
	u := getTestUser()
	u.QuotaSize = 65535
	u.FsConfig.Provider = sdk.CryptedFilesystemProvider
	u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername,
		defaultPassword)
	assert.NoError(t, err)
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "file.txt")
	assert.NoError(t, err)
	_, err = part.Write([]byte("content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// overwrite the same file
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebUploadSFTP uploads through the REST API for a user backed by an SFTP
// filesystem. The SFTP backend connects to localUser, created first.
// A successful upload updates UsedQuotaFiles/UsedQuotaSize. Once QuotaSize is
// lowered to 10 bytes, below the 17-byte file, multipart and single-file
// uploads fail with 413 and "denying write due to space limit". Deleting the
// file resets the used quota to zero.
func TestWebUploadSFTP(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 100
	u.FsConfig.SFTPConfig.BufferSize = 2
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(sftpUser.Username, defaultPassword)
	assert.NoError(t, err)
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "file.txt")
	assert.NoError(t, err)
	_, err = part.Write([]byte("test file content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	req, err := http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// 17 == len("test file content")
	expectedQuotaSize := int64(17)
	expectedQuotaFiles := 1
	user, _, err := httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
	user.QuotaSize = 10
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// we are now overquota on overwrite
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusRequestEntityTooLarge, rr)
	assert.Contains(t, rr.Body.String(), "denying write due to space limit")
	assert.Contains(t, rr.Body.String(), "Unable to write file")
	// delete the file
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	user, _, err = httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 0, user.UsedQuotaFiles)
	assert.Equal(t, int64(0), user.UsedQuotaSize)
	// a new single-file upload larger than QuotaSize is also rejected
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt",
		bytes.NewBuffer([]byte("test upload single file content")))
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusRequestEntityTooLarge, rr)
	assert.Contains(t, rr.Body.String(), "denying write due to space limit")
	assert.Contains(t, rr.Body.String(), "Error saving file")
	// delete the file
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// ...and so is a new multipart upload
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusRequestEntityTooLarge, rr)
	assert.Contains(t, rr.Body.String(), "denying write due to space limit")
	assert.Contains(t, rr.Body.String(), "Error saving file")
	// delete the file
	req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.txt", nil)
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	user, _, err = httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 0, user.UsedQuotaFiles)
	assert.Equal(t, int64(0), user.UsedQuotaSize)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(sftpUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebAPISFTPPasswordProtectedPrivateKey uses an SFTP-backed user that
// authenticates to the local user with a passphrase-protected private key.
// Updating the user without changes keeps the stored passphrase working, while
// a wrong or empty passphrase makes web client login fail.
func TestWebAPISFTPPasswordProtectedPrivateKey(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	u.PublicKeys = []string{testPubKeyPwd}
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret()
	u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKeyPwd)
	u.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret(privateKeyPwd)
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t,
err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// update the user, the key must be preserved
	assert.Equal(t, sdkkms.SecretStatusSecretBox, sftpUser.FsConfig.SFTPConfig.KeyPassphrase.GetStatus())
	_, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// using a wrong passphrase or no passphrase should fail
	sftpUser.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret("wrong")
	_, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword)
	assert.Error(t, err)
	sftpUser.FsConfig.SFTPConfig.KeyPassphrase = kms.NewEmptySecret()
	_, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(sftpUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestWebUploadMultipartFormReadError attaches a pre-built multipart.Form whose
// only file header ("missing") has no content behind it. The upload handler
// must fail to open it and answer 404 with "Unable to read uploaded file".
func TestWebUploadMultipartFormReadError(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userFilesPath, nil)
	assert.NoError(t, err)
	mpartForm := &multipart.Form{
		File: make(map[string][]*multipart.FileHeader),
	}
	mpartForm.File["filenames"] = append(mpartForm.File["filenames"], &multipart.FileHeader{Filename: "missing"})
	// setting MultipartForm directly means the body is never parsed
	req.MultipartForm = mpartForm
	req.Header.Add("Content-Type",
		"multipart/form-data")
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to read uploaded file")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestCompressionErrorMock requests a zip download of a file that does not
// exist. The handler is expected to abort with http.ErrAbortHandler. The
// deferred function recovers that panic, asserts its value and then cleans up
// the user, so cleanup also runs when the panic happens.
func TestCompressionErrorMock(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	defer func() {
		rcv := recover()
		assert.Equal(t, http.ErrAbortHandler, rcv)
		_, err := httpdtest.RemoveUser(user, http.StatusOK)
		assert.NoError(t, err)
		err = os.RemoveAll(user.GetHomeDir())
		assert.NoError(t, err)
	}()
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("files", `["missing"]`)
	req, _ := http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path=%2F",
		bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	// no response check: this call is expected to panic (see the deferred recover above)
	executeRequest(req)
}

func TestGetFilesSFTPBackend(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestSFTPUser()
	u.HomeDir = filepath.Clean(os.TempDir())
	u.FsConfig.SFTPConfig.BufferSize = 2
	u.Permissions["/adir"] = nil
	u.Permissions["/adir1"] = []string{dataprovider.PermListItems}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/adir2",
			DeniedPatterns: []string{"*.txt"},
		},
	}
	sftpUserBuffered, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u.Username += "_unbuffered"
	u.FsConfig.SFTPConfig.BufferSize = 0
	sftpUserUnbuffered, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFileName := "testsftpfile"
testDir := "testsftpdir" testFileContents := []byte("sftp file contents") err = os.MkdirAll(filepath.Join(user.GetHomeDir(), testDir, "sub"), os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "adir1"), os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "adir2"), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), testFileContents, os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "adir1", "afile"), testFileContents, os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "adir2", "afile.txt"), testFileContents, os.ModePerm) assert.NoError(t, err) for _, sftpUser := range []dataprovider.User{sftpUserBuffered, sftpUserUnbuffered} { webToken, err := getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+path.Join(testDir, "sub"), nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) htmlErrFrag := `div id="errorMsg" class="rounded border-warning border border-dashed bg-light-warning` req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+path.Join(testDir, "missing"), nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), htmlErrFrag) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=adir/sub", nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), htmlErrFrag) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=adir1/afile", nil) 
setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), htmlErrFrag) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=adir2/afile.txt", nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), htmlErrFrag) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Equal(t, testFileContents, rr.Body.Bytes()) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) req.Header.Set("Range", "bytes=2-") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusPartialContent, rr) assert.Equal(t, testFileContents[2:], rr.Body.Bytes()) _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestClientUserClose(t *testing.T) { u := getTestUser() u.UploadBandwidth = 32 u.DownloadBandwidth = 32 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileName := "file.dat" testFileSize := int64(524288) testFilePath := filepath.Join(user.GetHomeDir(), testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) uploadContent := make([]byte, testFileSize) _, err = rand.Read(uploadContent) assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, _ := 
http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) }() wg.Add(1) go func() { defer wg.Done() req, _ := http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFileName, nil) setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) }() wg.Add(1) go func() { defer wg.Done() body := new(bytes.Buffer) writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("filenames", "upload.dat") assert.NoError(t, err) n, err := part.Write(uploadContent) assert.NoError(t, err) assert.Equal(t, testFileSize, int64(n)) err = writer.Close() assert.NoError(t, err) reader := bytes.NewReader(body.Bytes()) req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) assert.NoError(t, err) req.Header.Add("Content-Type", writer.FormDataContentType()) setBearerForReq(req, webAPIToken) rr := executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "transfer aborted") }() // wait for the transfers assert.Eventually(t, func() bool { stats := common.Connections.GetStats("") if len(stats) == 3 { if len(stats[0].Transfers) > 0 && len(stats[1].Transfers) > 0 { return true } } return false }, 1*time.Second, 50*time.Millisecond) for _, stat := range common.Connections.GetStats("") { // close all the active transfers common.Connections.Close(stat.ConnectionID, "") } wg.Wait() assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestWebAdminSetupMock(t *testing.T) { req, err := 
http.NewRequest(http.MethodGet, webAdminSetupPath, nil) assert.NoError(t, err) rr := executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webLoginPath, rr.Header().Get("Location")) // now delete all the admins admins, err := dataprovider.GetAdmins(100, 0, dataprovider.OrderASC) assert.NoError(t, err) for _, admin := range admins { err = dataprovider.DeleteAdmin(admin.Username, "", "", "") assert.NoError(t, err) } // close the provider and initializes it without creating the default admin os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "0") err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) // now the setup page must be rendered req, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // check redirects to the setup page req, err = http.NewRequest(http.MethodGet, "/", nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) req, err = http.NewRequest(http.MethodGet, webBasePath, nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) req, err = http.NewRequest(http.MethodGet, webBasePathAdmin, nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) req, err = http.NewRequest(http.MethodGet, webLoginPath, nil) assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) assert.NoError(t, 
err) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) loginCookie, csrfToken, err := getCSRFTokenMock(webAdminSetupPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("username", defaultTokenAuthUser) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("password", defaultTokenAuthPass) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoMatch) form.Set("confirm_password", defaultTokenAuthPass) // test a parse form error req, err = 
http.NewRequest(http.MethodPost, webAdminSetupPath+"?param=p%C3%AO%GH", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // test a dataprovider error err = dataprovider.Close() assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // finally initialize the provider and create the default admin err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminMFAPath, rr.Header().Get("Location")) // if we resubmit the form we get a bad request, an admin already exists req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nError400Message) os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") } func TestAllowList(t *testing.T) { configCopy := common.Config entries := []dataprovider.IPListEntry{ { 
IPOrNet: "172.120.1.1/32", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 0, }, { IPOrNet: "172.120.1.2/32", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 0, }, { IPOrNet: "192.8.7.0/22", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 8, }, } for _, e := range entries { _, _, err := httpdtest.AddIPListEntry(e, http.StatusCreated) assert.NoError(t, err) } common.Config.MaxTotalConnections = 1 common.Config.AllowListStatus = 1 err := common.Initialize(common.Config, 0) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, webLoginPath, nil) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) req.RemoteAddr = "172.120.1.1" rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) testIP := "172.120.1.3" req.RemoteAddr = testIP rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) req.RemoteAddr = "192.8.7.1" rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) entry := dataprovider.IPListEntry{ IPOrNet: "172.120.1.3/32", Type: dataprovider.IPListTypeAllowList, Mode: dataprovider.ListModeAllow, Protocols: 8, } err = dataprovider.AddIPListEntry(&entry, "", "", "") assert.NoError(t, err) req.RemoteAddr = testIP rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = dataprovider.DeleteIPListEntry(entry.IPOrNet, entry.Type, "", "", "") assert.NoError(t, err) req.RemoteAddr = testIP rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) common.Config = configCopy err = common.Initialize(common.Config, 0) assert.NoError(t, err) for _, e := range entries { _, err := httpdtest.RemoveIPListEntry(e, http.StatusOK) assert.NoError(t, err) } } func 
TestWebAdminLoginMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) setBearerForReq(req, apiToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webStatusPath+"notfound", nil) req.RequestURI = webStatusPath + "notfound" setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webLogoutPath, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) cookie := rr.Header().Get("Cookie") assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) req, _ = http.NewRequest(http.MethodGet, logoutPath, nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) assert.Contains(t, rr.Body.String(), "Your token is no longer valid") req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) // now try using wrong password form := getLoginForm(defaultTokenAuthUser, 
"wrong pwd", csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) // wrong username form = getLoginForm("wrong username", defaultTokenAuthPass, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) // try from an ip not allowed a := getTestAdmin() a.Username = altAdminUsername a.Password = altAdminPassword a.Filters.AllowList = []string{"10.0.0.0/8"} _, _, err = httpdtest.AddAdmin(a, http.StatusCreated) assert.NoError(t, err) rAddr := "127.1.1.1:1234" loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) assert.NoError(t, err) assert.NotEmpty(t, loginCookie) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = rAddr rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) rAddr = "10.9.9.9:1234" loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) assert.NoError(t, err) assert.NotEmpty(t, loginCookie) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") req.RemoteAddr = rAddr setLoginCookie(req, loginCookie) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) rAddr = "127.0.1.1:4567" loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) assert.NoError(t, err) assert.NotEmpty(t, loginCookie) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = rAddr req.Header.Set("X-Forwarded-For", "10.9.9.9") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) // invalid csrf token form = getLoginForm(altAdminUsername, altAdminPassword, "invalid csrf") req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = "10.9.9.8:1234" rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) req, _ = http.NewRequest(http.MethodGet, webLoginPath, nil) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) _, err = httpdtest.RemoveAdmin(a, http.StatusOK) assert.NoError(t, err) } func TestAdminNoToken(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, webAdminProfilePath, nil) rr := executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webLoginPath, rr.Header().Get("Location")) req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webLoginPath, rr.Header().Get("Location")) req, _ = http.NewRequest(http.MethodGet, userPath, nil) rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) req, _ = http.NewRequest(http.MethodGet, 
activeConnectionsPath, nil) rr = executeRequest(req) checkResponseCode(t, http.StatusUnauthorized, rr) } func TestWebUserShare(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) assert.NoError(t, err) userAPItoken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) share := dataprovider.Share{ Name: "test share", Description: "test share desc", Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour)), MaxTokens: 100, AllowFrom: []string{"127.0.0.0/8", "172.16.0.0/16"}, Password: defaultPassword, } form := make(url.Values) form.Set("name", share.Name) form.Set("scope", strconv.Itoa(int(share.Scope))) form.Set("paths[0][path]", "/") form.Set("max_tokens", strconv.Itoa(share.MaxTokens)) form.Set("allowed_ip", strings.Join(share.AllowFrom, ",")) form.Set("description", share.Description) form.Set("password", share.Password) form.Set("expiration_date", "123") // invalid expiration date req, err := http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpiration) form.Set("expiration_date", util.GetTimeFromMsecSinceEpoch(share.ExpiresAt).UTC().Format("2006-01-02 15:04:05")) form.Set("scope", "") // invalid scope req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareScope) form.Set("scope", strconv.Itoa(int(share.Scope))) // invalid max tokens form.Set("max_tokens", "t") req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareMaxTokens) form.Set("max_tokens", strconv.Itoa(share.MaxTokens)) // no csrf token req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) form.Set("scope", "100") req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareScope) form.Set("scope", strconv.Itoa(int(share.Scope))) req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, err = 
http.NewRequest(http.MethodGet, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, userAPItoken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var shares []dataprovider.Share err = json.Unmarshal(rr.Body.Bytes(), &shares) assert.NoError(t, err) if assert.Len(t, shares, 1) { s := shares[0] assert.Equal(t, share.Name, s.Name) assert.Equal(t, share.Description, s.Description) assert.Equal(t, share.Scope, s.Scope) assert.Equal(t, share.Paths, s.Paths) assert.InDelta(t, share.ExpiresAt, s.ExpiresAt, 999) assert.Equal(t, share.MaxTokens, s.MaxTokens) assert.Equal(t, share.AllowFrom, s.AllowFrom) assert.Equal(t, redactedSecret, s.Password) share.ShareID = s.ShareID } form.Set("password", redactedSecret) form.Set("expiration_date", "123") req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/unknowid", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpiration) form.Set("expiration_date", "") form.Set(csrfFormToken, "") req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), 
util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) form.Set("allowed_ip", "1.1.1") req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidIPMask) form.Set("allowed_ip", "") req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, userAPItoken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) shares = nil err = json.Unmarshal(rr.Body.Bytes(), &shares) assert.NoError(t, err) if assert.Len(t, shares, 1) { s := shares[0] assert.Equal(t, share.Name, s.Name) assert.Equal(t, share.Description, s.Description) assert.Equal(t, share.Scope, s.Scope) assert.Equal(t, share.Paths, s.Paths) assert.Equal(t, int64(0), s.ExpiresAt) assert.Equal(t, share.MaxTokens, s.MaxTokens) assert.Empty(t, s.AllowFrom) } // check the password s, err := dataprovider.ShareExists(share.ShareID, user.Username) assert.NoError(t, err) match, err := s.CheckCredentials(defaultPassword) assert.NoError(t, err) assert.True(t, match) req, err = http.NewRequest(http.MethodGet, webClientSharePath+"?path=%2F&files=a", nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nError400Message) req, 
err = http.NewRequest(http.MethodGet, webClientSharePath+"?path=%2F&files=%5B\"adir\"%5D", nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webClientSharePath+"/unknown", nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, webClientSharePath+"/"+share.ShareID, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webClientSharesPath, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestWebUserShareNoPasswordDisabled(t *testing.T) { u := getTestUser() u.Filters.WebClient = []string{sdk.WebClientShareNoPasswordDisabled} u.Filters.DefaultSharesExpiration = 15 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user.Filters.DefaultSharesExpiration = 30 user, _, err = httpdtest.UpdateUser(u, http.StatusOK, "") assert.NoError(t, err) token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webClientSharePath, token) assert.NoError(t, err) userAPItoken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, 
defaultPassword) assert.NoError(t, err) share := dataprovider.Share{ Name: "s", Scope: dataprovider.ShareScopeRead, Paths: []string{"/"}, } form := make(url.Values) form.Set("name", share.Name) form.Set("scope", strconv.Itoa(int(share.Scope))) form.Set("paths[0][path]", "/") form.Set("max_tokens", "0") form.Set(csrfFormToken, csrfToken) req, err := http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorShareNoPwd) form.Set("password", defaultPassword) req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, err = http.NewRequest(http.MethodGet, webClientSharePath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) assert.NoError(t, err) setBearerForReq(req, userAPItoken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var shares []dataprovider.Share err = json.Unmarshal(rr.Body.Bytes(), &shares) assert.NoError(t, err) if assert.Len(t, shares, 1) { s := shares[0] assert.Equal(t, share.Name, s.Name) assert.Equal(t, share.Scope, s.Scope) assert.Equal(t, share.Paths, s.Paths) share.ShareID = s.ShareID } assert.NotEmpty(t, share.ShareID) form.Set("password", "") req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	// the update was posted with an empty password: expect status 200 and the
	// "share without password" error in the body
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorShareNoPwd)
	// configure a maximum shares expiration for the user and check that the
	// share page still renders
	user.Filters.DefaultSharesExpiration = 0
	user.Filters.MaxSharesExpiration = 30
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientSharePath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// cleanup
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestInvalidCSRF checks, for both the web client and the web admin login
// forms, that a login POST is refused when:
//   - the CSRF token does not belong to the login cookie sent with it;
//   - the "jwt" cookie is itself a CSRF token (invalid audience);
//   - the CSRF token/cookie pair was issued for a different remote address.
//
// Each refused attempt is answered with status 200 and a body containing the
// invalid CSRF error.
func TestInvalidCSRF(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	for _, loginURL := range []string{webClientLoginPath, webLoginPath} {
		// fetch three login cookie/CSRF token pairs: two for defaultRemoteAddr
		// and one for a different address (rAddr)
		loginCookie1, csrfToken1, err := getCSRFTokenMock(loginURL, defaultRemoteAddr)
		assert.NoError(t, err)
		assert.NotEmpty(t, loginCookie1)
		assert.NotEmpty(t, csrfToken1)
		loginCookie2, csrfToken2, err := getCSRFTokenMock(loginURL, defaultRemoteAddr)
		assert.NoError(t, err)
		assert.NotEmpty(t, loginCookie2)
		assert.NotEmpty(t, csrfToken2)
		rAddr := "1.2.3.4"
		loginCookie3, csrfToken3, err := getCSRFTokenMock(loginURL, rAddr)
		assert.NoError(t, err)
		assert.NotEmpty(t, loginCookie3)
		assert.NotEmpty(t, csrfToken3)
		// try using an invalid CSRF token: the token from the first pair is
		// sent together with the login cookie from the second pair
		form := getLoginForm(defaultUsername, defaultPassword, csrfToken1)
		req, err := http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode())))
		assert.NoError(t, err)
		req.RemoteAddr = defaultRemoteAddr
		req.RequestURI = loginURL
		setLoginCookie(req, loginCookie2)
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		rr := executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
		assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
		// use a CSRF token as login cookie (invalid audience)
		form = getLoginForm(defaultUsername, defaultPassword, csrfToken1)
		req, err = http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode())))
		assert.NoError(t, err)
		req.RemoteAddr = defaultRemoteAddr
		req.RequestURI = loginURL
		req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", csrfToken1))
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
		assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
		// invalid IP: the third pair was issued for rAddr, but the request
		// comes from defaultRemoteAddr
		form = getLoginForm(defaultUsername, defaultPassword, csrfToken3)
		req, err = http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode())))
		assert.NoError(t, err)
		req.RemoteAddr = defaultRemoteAddr
		req.RequestURI = loginURL
		setLoginCookie(req, loginCookie3)
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
		assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	}
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebUserProfile exercises the web client profile form. It covers:
//   - CSRF enforcement and a successful update of every field;
//   - validation errors for the email, the TLS certificate and the public key;
//   - the sdk.WebClient*ChangeDisabled filters: values submitted for a
//     disabled section are not stored, and once every section is disabled the
//     POST is forbidden;
//   - a final POST after the user has been removed.
func TestWebUserProfile(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token)
	assert.NoError(t, err)
	email := "user@user.com"
	description := "User"
	form := make(url.Values)
	form.Set("allow_api_key_auth", "1")
	form.Set("email", email)
	form.Set("description", description)
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Set("public_keys[1][public_key]", testPubKey1)
	form.Set("tls_certs[0][tls_cert]", httpsCert)
	form.Set("additional_emails[0][additional_email]", "email1@user.com")
	// no csrf token
	req, err := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	// the form was posted without the CSRF field: it must be forbidden
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	// with a valid CSRF token the update succeeds and every submitted field
	// is persisted
	form.Set(csrfFormToken, csrfToken)
	// NOTE(review): from here on the error returned by http.NewRequest is
	// discarded, unlike the first request above
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.AllowAPIKeyAuth)
	assert.Len(t, user.PublicKeys, 2)
	assert.Len(t, user.Filters.TLSCerts, 1)
	assert.Equal(t, email, user.Email)
	assert.Equal(t, description, user.Description)
	if assert.Len(t, user.Filters.AdditionalEmails, 1) {
		assert.Equal(t, "email1@user.com", user.Filters.AdditionalEmails[0])
	}
	// set an invalid email: expect status 200 and the invalid email error
	form.Set("email", "not an email")
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidEmail)
	// invalid tls cert (the valid email is restored first so that only the
	// certificate is wrong)
	form.Set("email", email)
	form.Set("tls_certs[0][tls_cert]", "not a TLS cert")
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidTLSCert)
	// invalid public key (the valid TLS cert is restored first)
	form.Set("tls_certs[0][tls_cert]", httpsCert)
	form.Set("public_keys[0][public_key]", "invalid")
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorPubKeyInvalid)
	// now remove permissions: the form carries a single valid public key and
	// the user may no longer change the API key authentication setting
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Del("public_keys[1][public_key]")
	user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token)
	assert.NoError(t, err)
	form.Set("allow_api_key_auth", "0")
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated)
	// allow_api_key_auth=0 is not applied because that change is disabled,
	// while the public keys are reduced to the single submitted one
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.AllowAPIKeyAuth)
	assert.Len(t, user.PublicKeys, 1)
	assert.Len(t, user.Filters.TLSCerts, 1)
	assert.Equal(t, email, user.Email)
	assert.Equal(t, description, user.Description)
	// also disable public key and TLS certificate changes
	user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
assert.NoError(t, err)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token)
	assert.NoError(t, err)
	// with public key and TLS certificate changes disabled, the two submitted
	// keys and the empty certificate must not be stored
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Set("public_keys[1][public_key]", testPubKey1)
	form.Set("tls_certs[0][tls_cert]", "")
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.AllowAPIKeyAuth)
	assert.Len(t, user.PublicKeys, 1)
	assert.Len(t, user.Filters.TLSCerts, 1)
	assert.Equal(t, email, user.Email)
	assert.Equal(t, description, user.Description)
	// keys and certificates may be changed again, but email/description
	// (info) changes are now disabled
	user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token)
	assert.NoError(t, err)
	form.Set("email", "newemail@user.com")
	form.Set("description", "new description")
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated)
	// email and description keep their previous values, while the two public
	// keys and the empty TLS certificate from the form are now applied
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.AllowAPIKeyAuth)
	assert.Len(t, user.PublicKeys, 2)
	assert.Len(t, user.Filters.TLSCerts, 0)
	assert.Equal(t, email, user.Email)
	assert.Equal(t, description, user.Description)
	// finally disable all profile permissions: the POST is forbidden
	user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	// the CSRF token is obtained from the change password page here
	csrfToken, err = getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, token)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// the user no longer exists: posting the profile form with the old token
	// fails with an internal server error
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
}

// TestWebAdminProfile checks the web admin profile page:
//   - the GET renders;
//   - a POST without CSRF token is forbidden;
//   - a valid POST stores API key auth, email and description;
//   - a form carrying only the CSRF token clears them;
//   - after the admin is removed, the same POST fails with an internal
//     server error.
func TestWebAdminProfile(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTWebTokenFromTestServer(admin.Username, altAdminPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminProfilePath, token)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form := make(url.Values)
	form.Set("allow_api_key_auth", "1")
	form.Set("email", "admin@example.com")
	form.Set("description", "admin desc")
	// no csrf token
	req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.AllowAPIKeyAuth)
	assert.Equal(t, "admin@example.com", admin.Email)
	assert.Equal(t, "admin desc", admin.Description)
	// a form with only the CSRF token clears the profile fields
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.AllowAPIKeyAuth)
	assert.Empty(t, admin.Email)
	assert.Empty(t, admin.Description)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// the admin no longer exists: the same POST fails
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
}

// TestWebAdminPwdChange checks the web admin change password form: the CSRF
// token is enforced, reusing the current password is rejected, and a real
// change answers with a 302 redirect to webLoginPath.
func TestWebAdminPwdChange(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Filters.Preferences.HideUserPageSections = 16 + 32
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTWebTokenFromTestServer(admin.Username, altAdminPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeAdminPwdPath, token)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webChangeAdminPwdPath, nil)
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form := make(url.Values)
	form.Set("current_password", altAdminPassword)
	form.Set("new_password1", altAdminPassword)
	form.Set("new_password2", altAdminPassword)
	// no csrf token
	req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	// valid CSRF token, but the new password equals the current one
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoDifferent)
	form.Set("new_password1", altAdminPassword+"1")
	form.Set("new_password2", altAdminPassword+"1")
	req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAPIKeysManagement checks the API key REST endpoints and authentication
// with an admin-scoped key:
//   - create, get, update and list;
//   - the key is refused until the admin enables AllowAPIKeyAuth;
//   - an unbound key needs the matching admin username, and once the key is
//     bound to the admin no username is needed;
//   - an API key cannot read or modify API keys;
//   - the admin allow list is enforced;
//   - a deleted key is rejected.
func TestAPIKeysManagement(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiKey := dataprovider.APIKey{
		Name:  "test key",
		Scope: dataprovider.APIKeyScopeAdmin,
	}
	asJSON, err := json.Marshal(apiKey)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	// the Location and X-Object-ID headers identify the new key, and the
	// plaintext key returned in the body is prefixed by "<KeyID>."
	location := rr.Header().Get("Location")
	assert.NotEmpty(t, location)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)
	assert.Equal(t, fmt.Sprintf("%v/%v", apiKeysPath, objectID), location)
	apiKey.KeyID = objectID
	response := make(map[string]string)
	err = json.Unmarshal(rr.Body.Bytes(), &response)
	assert.NoError(t, err)
	key := response["key"]
	assert.NotEmpty(t, key)
	assert.True(t, strings.HasPrefix(key, apiKey.KeyID+"."))
	// reading the key back never returns the key itself
	req, err = http.NewRequest(http.MethodGet, location, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var keyGet dataprovider.APIKey
	err = json.Unmarshal(rr.Body.Bytes(), &keyGet)
	assert.NoError(t, err)
	assert.Empty(t, keyGet.Key)
	assert.Equal(t, apiKey.KeyID, keyGet.KeyID)
	assert.Equal(t, apiKey.Scope, keyGet.Scope)
	assert.Equal(t, apiKey.Name, keyGet.Name)
	assert.Equal(t, int64(0), keyGet.ExpiresAt)
	assert.Equal(t, int64(0), keyGet.LastUseAt)
	assert.Greater(t, keyGet.CreatedAt, int64(0))
	assert.Greater(t, keyGet.UpdatedAt, int64(0))
	assert.Empty(t, keyGet.Description)
	assert.Empty(t, keyGet.User)
	assert.Empty(t, keyGet.Admin)
	// API key is not enabled for the admin user so this request should fail
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, admin.Username)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "the admin associated with the provided api key cannot be authenticated")
	admin.Filters.AllowAPIKeyAuth = true
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, admin.Username)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// an unknown admin username, or no username at all, is refused while the
	// key is not bound to an admin
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, admin.Username+"1")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	// now associate the key directly to the admin
	apiKey.Admin = admin.Username
	apiKey.Description = "test description"
	asJSON, err = json.Marshal(apiKey)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, apiKeysPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var keys []dataprovider.APIKey
	err = json.Unmarshal(rr.Body.Bytes(), &keys)
	assert.NoError(t, err)
	if assert.GreaterOrEqual(t, len(keys), 1) {
		found := false
		for _, k := range keys {
			if k.KeyID == apiKey.KeyID {
				found = true
				assert.Empty(t, k.Key)
				assert.Equal(t, apiKey.Scope, k.Scope)
				assert.Equal(t, apiKey.Name, k.Name)
				assert.Equal(t, int64(0), k.ExpiresAt)
				assert.Greater(t, k.LastUseAt, int64(0))
				assert.Equal(t, k.CreatedAt, keyGet.CreatedAt)
				assert.Greater(t, k.UpdatedAt, keyGet.UpdatedAt)
				assert.Equal(t, apiKey.Description, k.Description)
				assert.Empty(t, k.User)
				assert.Equal(t, admin.Username, k.Admin)
			}
		}
		assert.True(t, found)
	}
	// the bound key authenticates without an explicit admin username
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// invalid API keys
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key+"invalid", "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "the provided api key cannot be authenticated")
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, "invalid", "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// using an API key we cannot modify/get API keys
	req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	req, err = http.NewRequest(http.MethodGet, location, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// requests from outside the admin allow list are refused
	admin.Filters.AllowList = []string{"172.16.18.0/24"}
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	// a deleted key is reported as not valid
	req, err = http.NewRequest(http.MethodDelete, location, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, versionPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the provided api key is not valid")
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAPIKeySearch creates four API keys named "testapikey1".."testapikey4".
// It then checks the list endpoint's limit, order and offset parameters, that
// a non-numeric limit gives 400 and an unknown key ID gives 404, and finally
// deletes the four keys it created.
func TestAPIKeySearch(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiKey := dataprovider.APIKey{
		Scope: dataprovider.APIKeyScopeAdmin,
	}
	for i := 1; i < 5; i++ {
		apiKey.Name = fmt.Sprintf("testapikey%v", i)
		asJSON, err := json.Marshal(apiKey)
		assert.NoError(t, err)
		req, err := http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON))
		assert.NoError(t, err)
		setBearerForReq(req, token)
		rr := executeRequest(req)
		checkResponseCode(t, http.StatusCreated, rr)
	}
	req, err := http.NewRequest(http.MethodGet, apiKeysPath+"?limit=1&order=ASC", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var keys []dataprovider.APIKey
	err = json.Unmarshal(rr.Body.Bytes(), &keys)
	assert.NoError(t, err)
	// assert.Len only records a failure and lets the test go on: guard the
	// index so an empty result fails cleanly instead of panicking and
	// skipping the deletion of the keys created above.
	var firstKey dataprovider.APIKey
	if assert.Len(t, keys, 1) {
		firstKey = keys[0]
	}
	// with order=DESC the single returned key must differ from the ASC one
	req, err = http.NewRequest(http.MethodGet, apiKeysPath+"?limit=1&order=DESC", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	keys = nil
	err = json.Unmarshal(rr.Body.Bytes(), &keys)
	assert.NoError(t, err)
	if assert.Len(t, keys, 1) {
		assert.NotEqual(t, firstKey.KeyID, keys[0].KeyID)
	}
	req, err = http.NewRequest(http.MethodGet, apiKeysPath+"?limit=1&offset=100", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	keys = nil
	err = json.Unmarshal(rr.Body.Bytes(), &keys)
	assert.NoError(t, err)
	assert.Len(t, keys, 0)
	req, err = http.NewRequest(http.MethodGet, apiKeysPath+"?limit=f", nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v/%v", apiKeysPath, "missingid"), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// cleanup: delete the keys created by this test, identified by name
	req, err = http.NewRequest(http.MethodGet, apiKeysPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	keys = nil
	err = json.Unmarshal(rr.Body.Bytes(), &keys)
	assert.NoError(t, err)
	counter := 0
	for _, k := range keys {
		if strings.HasPrefix(k.Name, "testapikey") {
			req, err = http.NewRequest(http.MethodDelete, fmt.Sprintf("%v/%v", apiKeysPath, k.KeyID), nil)
			assert.NoError(t, err)
			setBearerForReq(req, token)
			rr = executeRequest(req)
			checkResponseCode(t, http.StatusOK, rr)
			counter++
		}
	}
	assert.Equal(t, 4, counter)
}

// TestAPIKeyErrors checks the error responses of the API key REST endpoints:
// an invalid scope and malformed JSON give 400, and operations on a deleted
// key give 404.
func TestAPIKeyErrors(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiKey := dataprovider.APIKey{
		Name:  "testkey",
		Scope: dataprovider.APIKeyScopeUser,
	}
	asJSON, err := json.Marshal(apiKey)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	location := rr.Header().Get("Location")
	assert.NotEmpty(t, location)
	// invalid API scope
	apiKey.Scope = 1000
	asJSON, err = json.Marshal(apiKey)
assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// invalid JSON
	req, err = http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer([]byte(`invalid JSON`)))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer([]byte(`invalid JSON`)))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// delete the key; deleting or updating it again gives 404
	req, err = http.NewRequest(http.MethodDelete, location, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodDelete, location, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestAPIKeyOnDeleteCascade checks that an API key bound to a user, or to an
// admin, disappears when that user or admin is removed. It also checks that a
// user-scoped key is refused (401) until AllowAPIKeyAuth is enabled for the
// user.
func TestAPIKeyOnDeleteCascade(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	apiKey := dataprovider.APIKey{
		Name:  "user api key",
		Scope: dataprovider.APIKeyScopeUser,
		User:  user.Username,
	}
	apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	user.Filters.AllowAPIKeyAuth = true
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setAPIKeyForReq(req, apiKey.Key, "")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var contents []map[string]any
	err = json.NewDecoder(rr.Body).Decode(&contents)
	assert.NoError(t, err)
	assert.Len(t, contents, 0)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the user-bound key was removed together with the user
	_, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusNotFound)
	assert.NoError(t, err)
	apiKey.User = ""
	apiKey.Admin = admin.Username
	apiKey.Scope = dataprovider.APIKeyScopeAdmin
	apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// the admin-bound key was removed together with the admin
	_, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusNotFound)
	assert.NoError(t, err)
}

// TestBasicWebUsersMock creates two users through the REST API and then
// exercises the web admin user pages:
//   - the user list (HTML and JSON);
//   - the add and edit forms, including not-found user paths;
//   - user deletion, which requires the CSRF header.
func TestBasicWebUsersMock(t *testing.T) {
	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user1 := getTestUser()
	user1.Username += "1"
	user1AsJSON := getUserAsJSON(t, user1)
	req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(user1AsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user1)
	assert.NoError(t, err)
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webUsersPath+jsonAPISuffix, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webUserPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(webUserPath, user.Username), nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webUserPath+"/0", nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	// a form with only the username: add and edit both answer 200
	form := make(url.Values)
	form.Set("username", user.Username)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// editing users that do not exist gives 404
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath+"/0", &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// Rebuild the multipart body: the previous request was created from the
	// same buffer, and http.NewRequest snapshots ContentLength from it while
	// reading the request body drains it. Reusing b would make this request's
	// payload depend on whether the previous handler consumed its body.
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath+"/aaa", &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// deleting without the CSRF header is forbidden
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webUserPath, user.Username), nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webUserPath, user.Username), nil)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webUserPath, user1.Username), nil)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestRenderDefenderPageMock checks that the defender page renders for an
// authenticated admin and contains the defender title.
func TestRenderDefenderPageMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webDefenderPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nDefenderTitle)
}

// TestWebAdminBasicMock exercises the web form used to add an admin.
func TestWebAdminBasicMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("username", admin.Username)
	form.Set("password", "")
	form.Set("status", "1")
	form.Set("permissions", "*")
	form.Set("description", admin.Description)
	form.Set("user_page_hidden_sections", "1")
	form.Add("user_page_hidden_sections", "2")
	form.Add("user_page_hidden_sections", "3")
	form.Add("user_page_hidden_sections", "4")
form.Add("user_page_hidden_sections", "5") form.Add("user_page_hidden_sections", "6") form.Add("user_page_hidden_sections", "7") form.Set("default_users_expiration", "10") req, _ := http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) form.Set("status", "a") req, _ = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("status", "1") form.Set("default_users_expiration", "a") req, _ = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("default_users_expiration", "10") form.Set("password", admin.Password) req, _ = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // add TOTP config configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], altAdminUsername) assert.NoError(t, err) altToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) adminTOTPConfig := 
dataprovider.AdminTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), } asJSON, err := json.Marshal(adminTOTPConfig) assert.NoError(t, err) // no CSRF token req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "Invalid token") req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setJWTCookieForReq(req, altToken) setCSRFHeaderForReq(req, csrfToken) // invalid CSRF token req.RemoteAddr = defaultRemoteAddr rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "the token is not valid") csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminPath, altToken) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setJWTCookieForReq(req, altToken) setCSRFHeaderForReq(req, csrfToken) req.RemoteAddr = defaultRemoteAddr rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) assert.NoError(t, err) assert.True(t, admin.Filters.TOTPConfig.Enabled) secretPayload := admin.Filters.TOTPConfig.Secret.GetPayload() assert.NotEmpty(t, secretPayload) assert.Equal(t, 1+2+4+8+16+32+64, admin.Filters.Preferences.HideUserPageSections) assert.Equal(t, 10, admin.Filters.Preferences.DefaultUsersExpiration) adminTOTPConfig = dataprovider.AdminTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewEmptySecret(), } asJSON, err = json.Marshal(adminTOTPConfig) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr 
setJWTCookieForReq(req, altToken) setCSRFHeaderForReq(req, csrfToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) assert.NoError(t, err) assert.True(t, admin.Filters.TOTPConfig.Enabled) assert.Equal(t, secretPayload, admin.Filters.TOTPConfig.Secret.GetPayload()) adminTOTPConfig = dataprovider.AdminTOTPConfig{ Enabled: true, ConfigName: configName, Secret: nil, } asJSON, err = json.Marshal(adminTOTPConfig) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, altToken) setCSRFHeaderForReq(req, csrfToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webAdminsPath+jsonAPISuffix, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webAdminsPath, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webAdminPath, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("password", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) form.Set(csrfFormToken, csrfToken) // associated to altToken req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminPath, token) assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("email", "not-an-email") req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("email", "") form.Set("status", "b") req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("email", "admin@example.com") form.Set("status", "0") req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) assert.NoError(t, err) assert.True(t, admin.Filters.TOTPConfig.Enabled) assert.Equal(t, "admin@example.com", admin.Email) assert.Equal(t, 0, admin.Status) req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername+"1"), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = 
executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(webAdminPath, altAdminUsername), nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(webAdminPath, altAdminUsername+"1"), nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodDelete, path.Join(webAdminPath, altAdminUsername), nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) setCSRFHeaderForReq(req, csrfToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) _, err = httpdtest.RemoveAdmin(admin, http.StatusNotFound) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodDelete, path.Join(webAdminPath, defaultTokenAuthUser), nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) setCSRFHeaderForReq(req, csrfToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "you cannot delete yourself") req, _ = http.NewRequest(http.MethodDelete, path.Join(webAdminPath, defaultTokenAuthUser), nil) req.RemoteAddr = defaultRemoteAddr setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "Invalid token") } func TestWebAdminGroupsMock(t *testing.T) { 
group1 := getTestGroup() group1.Name += "_1" group1, _, err := httpdtest.AddGroup(group1, http.StatusCreated) assert.NoError(t, err) group2 := getTestGroup() group2.Name += "_2" group2, _, err = httpdtest.AddGroup(group2, http.StatusCreated) assert.NoError(t, err) group3 := getTestGroup() group3.Name += "_3" group3, _, err = httpdtest.AddGroup(group3, http.StatusCreated) assert.NoError(t, err) token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", admin.Username) form.Set("password", "") form.Set("status", "1") form.Set("permissions", "*") form.Set("description", admin.Description) form.Set("password", admin.Password) form.Set("groups[0][group]", group1.Name) form.Set("groups[0][group_type]", "1") form.Set("groups[1][group]", group2.Name) form.Set("groups[1][group_type]", "2") form.Set("groups[2][group]", group3.Name) req, err := http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) if assert.Len(t, admin.Groups, 3) { for _, g := range admin.Groups { switch g.Name { case group1.Name: assert.Equal(t, dataprovider.GroupAddToUsersAsPrimary, g.Options.AddToUsersAs) case group2.Name: assert.Equal(t, dataprovider.GroupAddToUsersAsSecondary, g.Options.AddToUsersAs) case group3.Name: assert.Equal(t, dataprovider.GroupAddToUsersAsMembership, g.Options.AddToUsersAs) default: t.Errorf("unexpected group 
%q", g.Name) } } } adminToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webUserPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, adminToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webUserPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, adminToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) _, err = httpdtest.RemoveGroup(group1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group3, http.StatusOK) assert.NoError(t, err) } func TestWebAdminPermissions(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword admin.Permissions = []string{dataprovider.PermAdminAddUsers} _, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) token, err := getJWTWebToken(altAdminUsername, altAdminPassword) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, httpBaseURL+webUserPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err := httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) req, err = http.NewRequest(http.MethodGet, httpBaseURL+path.Join(webUserPath, "auser"), nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err = httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusForbidden, resp.StatusCode) req, err = http.NewRequest(http.MethodGet, httpBaseURL+webFolderPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err = httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, 
http.StatusForbidden, resp.StatusCode) req, err = http.NewRequest(http.MethodGet, httpBaseURL+webStatusPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err = httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusForbidden, resp.StatusCode) req, err = http.NewRequest(http.MethodGet, httpBaseURL+webConnectionsPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err = httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusForbidden, resp.StatusCode) req, err = http.NewRequest(http.MethodGet, httpBaseURL+webAdminPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err = httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusForbidden, resp.StatusCode) req, err = http.NewRequest(http.MethodGet, httpBaseURL+path.Join(webAdminPath, "a"), nil) assert.NoError(t, err) setJWTCookieForReq(req, token) resp, err = httpclient.GetHTTPClient().Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusForbidden, resp.StatusCode) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) } func TestAdminUpdateSelfMock(t *testing.T) { admin, _, err := httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK) assert.NoError(t, err) token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token) assert.NoError(t, err) form := make(url.Values) form.Set("username", admin.Username) form.Set("password", admin.Password) form.Set("status", "0") form.Set("permissions", dataprovider.PermAdminAddUsers) form.Set("permissions", dataprovider.PermAdminCloseConnections) form.Set(csrfFormToken, csrfToken) req, _ := http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), 
bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorAdminSelfPerms) form.Set("permissions", dataprovider.PermAdminAny) req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorAdminSelfDisable) form.Set("status", "1") form.Set("require_two_factor", "1") form.Set("require_password_change", "1") req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) admin, _, err = httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK) assert.NoError(t, err) assert.False(t, admin.Filters.RequirePasswordChange) assert.False(t, admin.Filters.RequireTwoFactor) form.Set("role", "my role") req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorAdminSelfRole) } func TestWebMaintenanceMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) req, _ := 
http.NewRequest(http.MethodGet, webMaintenancePath, nil) setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) csrfToken, err := getCSRFTokenFromInternalPageMock(webMaintenancePath, token) assert.NoError(t, err) form := make(url.Values) form.Set("mode", "a") b, contentType, _ := getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("mode", "0") form.Set("quota", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("quota", "0") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodPost, webRestorePath+"?a=%3", &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) backupFilePath := filepath.Join(os.TempDir(), "backup.json") err = createTestFile(backupFilePath, 0) assert.NoError(t, err) b, contentType, _ = getMultipartFormData(form, "backup_file", backupFilePath) req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) 
setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = createTestFile(backupFilePath, 10) assert.NoError(t, err) b, contentType, _ = getMultipartFormData(form, "backup_file", backupFilePath) req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) user := getTestUser() user.ID = 1 user.Username = "test_user_web_restore" admin := getTestAdmin() admin.ID = 1 admin.Username = "test_admin_web_restore" apiKey := dataprovider.APIKey{ Name: "key name", KeyID: util.GenerateUniqueID(), Key: fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()), Scope: dataprovider.APIKeyScopeAdmin, } backupData := dataprovider.BackupData{ Version: dataprovider.DumpVersion, } backupData.Users = append(backupData.Users, user) backupData.Admins = append(backupData.Admins, admin) backupData.APIKeys = append(backupData.APIKeys, apiKey) backupContent, err := json.Marshal(backupData) assert.NoError(t, err) err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) assert.NoError(t, err) b, contentType, _ = getMultipartFormData(form, "backup_file", backupFilePath) req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nBackupOK) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK) assert.NoError(t, err) _, 
err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK) assert.NoError(t, err) err = os.Remove(backupFilePath) assert.NoError(t, err) } func TestWebUserAddMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) group1 := getTestGroup() group1.Name += "_1" group1, _, err = httpdtest.AddGroup(group1, http.StatusCreated) assert.NoError(t, err) group2 := getTestGroup() group2.Name += "_2" group2, _, err = httpdtest.AddGroup(group2, http.StatusCreated) assert.NoError(t, err) group3 := getTestGroup() group3.Name += "_3" group3, _, err = httpdtest.AddGroup(group3, http.StatusCreated) assert.NoError(t, err) user := getTestUser() user.UploadBandwidth = 32 user.DownloadBandwidth = 64 user.UploadDataTransfer = 1000 user.DownloadDataTransfer = 2000 user.UID = 1000 user.AdditionalInfo = "info" user.Description = "user dsc" user.Email = "test@test.com" user.Filters.AdditionalEmails = []string{"example1@test.com", "example2@test.com"} mappedDir := filepath.Join(os.TempDir(), "mapped") folderName := filepath.Base(mappedDir) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedDir, } folderAsJSON, err := json.Marshal(f) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON)) setBearerForReq(req, apiToken) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) form.Set("email", user.Email) form.Set("additional_emails[0][additional_email]", user.Filters.AdditionalEmails[0]) form.Set("additional_emails[1][additional_email]", user.Filters.AdditionalEmails[1]) form.Set("home_dir", user.HomeDir) form.Set("osfs_read_buffer_size", "2") 
form.Set("osfs_write_buffer_size", "3") form.Set("password", user.Password) form.Set("primary_group", group1.Name) form.Set("secondary_groups", group2.Name) form.Set("membership_groups", group3.Name) form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "") form.Set("permissions", "*") form.Set("directory_permissions[0][sub_perm_path]", "/subdir") form.Set("directory_permissions[0][sub_perm_permissions][]", "list") form.Add("directory_permissions[0][sub_perm_permissions][]", "download") form.Set("virtual_folders[0][vfolder_path]", " /vdir") form.Set("virtual_folders[0][vfolder_name]", folderName) form.Set("virtual_folders[0][vfolder_quota_files]", "2") form.Set("virtual_folders[0][vfolder_quota_size]", "1024") form.Set("directory_patterns[0][pattern_path]", "/dir2") form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") form.Set("directory_patterns[0][pattern_type]", "allowed") form.Set("directory_patterns[0][pattern_policy]", "1") form.Set("directory_patterns[1][pattern_path]", "/dir1") form.Set("directory_patterns[1][patterns]", "*.png") form.Set("directory_patterns[1][pattern_type]", "allowed") form.Set("directory_patterns[2][pattern_path]", "/dir1") form.Set("directory_patterns[2][patterns]", "*.zip") form.Set("directory_patterns[2][pattern_type]", "denied") form.Set("directory_patterns[3][pattern_path]", "/dir3") form.Set("directory_patterns[3][patterns]", "*.rar") form.Set("directory_patterns[3][pattern_type]", "denied") form.Set("directory_patterns[4][pattern_path]", "/dir2") form.Set("directory_patterns[4][patterns]", "*.mkv") form.Set("directory_patterns[4][pattern_type]", "denied") form.Set("access_time_restrictions[0][access_time_day_of_week]", "2") form.Set("access_time_restrictions[0][access_time_start]", "10") // invalid and no end, ignored form.Set("access_time_restrictions[1][access_time_day_of_week]", "3") form.Set("access_time_restrictions[1][access_time_start]", "12:00") 
form.Set("access_time_restrictions[1][access_time_end]", "14:09") form.Set("additional_info", user.AdditionalInfo) form.Set("description", user.Description) form.Add("hooks", "external_auth_disabled") form.Set("disable_fs_checks", "checked") form.Set("total_data_transfer", "0") form.Set("external_auth_cache_time", "0") form.Set("start_directory", "start/dir") form.Set("require_password_change", "1") b, contentType, _ := getMultipartFormData(form, "", "") // test invalid url escape req, _ = http.NewRequest(http.MethodPost, webUserPath+"?a=%2", &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("public_keys", testPubKey) form.Add("public_keys", testPubKey1) form.Set("uid", strconv.FormatInt(int64(user.UID), 10)) form.Set("gid", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid gid req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("gid", "0") form.Set("max_sessions", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid max sessions req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("max_sessions", "0") form.Set("quota_size", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid quota size req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("quota_size", "0") form.Set("quota_files", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid quota files req, _ = http.NewRequest(http.MethodPost, webUserPath, 
&b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("quota_files", "0") form.Set("upload_bandwidth", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid upload bandwidth req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("upload_bandwidth", strconv.FormatInt(user.UploadBandwidth, 10)) form.Set("download_bandwidth", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid download bandwidth req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("download_bandwidth", strconv.FormatInt(user.DownloadBandwidth, 10)) form.Set("upload_data_transfer", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid upload data transfer req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("upload_data_transfer", strconv.FormatInt(user.UploadDataTransfer, 10)) form.Set("download_data_transfer", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid download data transfer req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("download_data_transfer", strconv.FormatInt(user.DownloadDataTransfer, 10)) form.Set("total_data_transfer", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid total data transfer req, _ = http.NewRequest(http.MethodPost, webUserPath, 
&b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("total_data_transfer", strconv.FormatInt(user.TotalDataTransfer, 10)) form.Set("status", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid status req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "123") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid expiration date req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("expiration_date", "") form.Set("allowed_ip", "invalid,ip") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid allowed_ip req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("allowed_ip", "") form.Set("denied_ip", "192.168.1.2") // it should be 192.168.1.2/32 b, contentType, _ = getMultipartFormData(form, "", "") // test invalid denied_ip req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("denied_ip", "") // test invalid max file upload size form.Set("max_upload_file_size", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) 
form.Set("max_upload_file_size", "1 KB") // test invalid default shares expiration form.Set("default_shares_expiration", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("default_shares_expiration", "10") // test invalid max shares expiration form.Set("max_shares_expiration", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("max_shares_expiration", "30") // test invalid password expiration form.Set("password_expiration", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("password_expiration", "90") // test invalid password strength form.Set("password_strength", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("password_strength", "60") // test invalid tls username form.Set("tls_username", "username") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) 
req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("tls_username", string(sdk.TLSUsernameNone)) // invalid upload bandwidth source form.Set("src_bandwidth_limits[0][bandwidth_limit_sources]", "192.168.1.0/24, 192.168.2.0/25") form.Set("src_bandwidth_limits[0][upload_bandwidth_source]", "a") form.Set("src_bandwidth_limits[0][download_bandwidth_source]", "0") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) // invalid download bandwidth source form.Set("src_bandwidth_limits[0][upload_bandwidth_source]", "256") form.Set("src_bandwidth_limits[0][download_bandwidth_source]", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("src_bandwidth_limits[0][download_bandwidth_source]", "512") form.Set("src_bandwidth_limits[1][download_bandwidth_source]", "1024") form.Set("src_bandwidth_limits[1][bandwidth_limit_sources]", "1.1.1") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorSourceBWLimitInvalid) form.Set("src_bandwidth_limits[1][bandwidth_limit_sources]", "127.0.0.1/32") form.Set("src_bandwidth_limits[1][upload_bandwidth_source]", "-1") // invalid 
external auth cache size form.Set("external_auth_cache_time", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("external_auth_cache_time", "0") form.Set(csrfFormToken, "invalid form token") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) dbUser, err := dataprovider.UserExists(user.Username, "") assert.NoError(t, err) assert.NotEmpty(t, dbUser.Password) assert.True(t, dbUser.IsPasswordHashed()) // the user already exists, was created with the above request b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) newUser := dataprovider.User{} err = render.DecodeJSON(rr.Body, &newUser) assert.NoError(t, err) assert.Equal(t, user.UID, newUser.UID) assert.Equal(t, 2, newUser.FsConfig.OSConfig.ReadBufferSize) assert.Equal(t, 3, newUser.FsConfig.OSConfig.WriteBufferSize) assert.Equal(t, user.UploadBandwidth, 
newUser.UploadBandwidth) assert.Equal(t, user.DownloadBandwidth, newUser.DownloadBandwidth) assert.Equal(t, user.UploadDataTransfer, newUser.UploadDataTransfer) assert.Equal(t, user.DownloadDataTransfer, newUser.DownloadDataTransfer) assert.Equal(t, user.TotalDataTransfer, newUser.TotalDataTransfer) assert.Equal(t, int64(1000), newUser.Filters.MaxUploadFileSize) assert.Equal(t, user.AdditionalInfo, newUser.AdditionalInfo) assert.Equal(t, user.Description, newUser.Description) assert.True(t, newUser.Filters.Hooks.ExternalAuthDisabled) assert.False(t, newUser.Filters.Hooks.PreLoginDisabled) assert.False(t, newUser.Filters.Hooks.CheckPasswordDisabled) assert.True(t, newUser.Filters.DisableFsChecks) assert.False(t, newUser.Filters.AllowAPIKeyAuth) assert.Equal(t, user.Email, newUser.Email) assert.Equal(t, len(user.Filters.AdditionalEmails), len(newUser.Filters.AdditionalEmails)) assert.Equal(t, "/start/dir", newUser.Filters.StartDirectory) assert.Equal(t, 0, newUser.Filters.FTPSecurity) assert.Equal(t, 10, newUser.Filters.DefaultSharesExpiration) assert.Equal(t, 30, newUser.Filters.MaxSharesExpiration) assert.Equal(t, 90, newUser.Filters.PasswordExpiration) assert.Equal(t, 60, newUser.Filters.PasswordStrength) assert.Greater(t, newUser.LastPasswordChange, int64(0)) assert.True(t, newUser.Filters.RequirePasswordChange) assert.True(t, slices.Contains(newUser.PublicKeys, testPubKey)) if val, ok := newUser.Permissions["/subdir"]; ok { assert.True(t, slices.Contains(val, dataprovider.PermListItems)) assert.True(t, slices.Contains(val, dataprovider.PermDownload)) } else { assert.Fail(t, "user permissions must contain /somedir", "actual: %v", newUser.Permissions) } assert.Len(t, newUser.PublicKeys, 2) assert.Len(t, newUser.VirtualFolders, 1) for _, v := range newUser.VirtualFolders { assert.Equal(t, v.VirtualPath, "/vdir") assert.Equal(t, v.Name, folderName) assert.Equal(t, v.MappedPath, mappedDir) assert.Equal(t, v.QuotaFiles, 2) assert.Equal(t, v.QuotaSize, int64(1024)) } 
assert.Len(t, newUser.Filters.FilePatterns, 3) for _, filter := range newUser.Filters.FilePatterns { switch filter.Path { case "/dir1": assert.Len(t, filter.DeniedPatterns, 1) assert.Len(t, filter.AllowedPatterns, 1) assert.True(t, slices.Contains(filter.AllowedPatterns, "*.png")) assert.True(t, slices.Contains(filter.DeniedPatterns, "*.zip")) assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy) case "/dir2": assert.Len(t, filter.DeniedPatterns, 1) assert.Len(t, filter.AllowedPatterns, 2) assert.True(t, slices.Contains(filter.AllowedPatterns, "*.jpg")) assert.True(t, slices.Contains(filter.AllowedPatterns, "*.png")) assert.True(t, slices.Contains(filter.DeniedPatterns, "*.mkv")) assert.Equal(t, sdk.DenyPolicyHide, filter.DenyPolicy) case "/dir3": assert.Len(t, filter.DeniedPatterns, 1) assert.Len(t, filter.AllowedPatterns, 0) assert.True(t, slices.Contains(filter.DeniedPatterns, "*.rar")) assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy) } } if assert.Len(t, newUser.Filters.BandwidthLimits, 2) { for _, bwLimit := range newUser.Filters.BandwidthLimits { if len(bwLimit.Sources) == 2 { assert.Equal(t, "192.168.1.0/24", bwLimit.Sources[0]) assert.Equal(t, "192.168.2.0/25", bwLimit.Sources[1]) assert.Equal(t, int64(256), bwLimit.UploadBandwidth) assert.Equal(t, int64(512), bwLimit.DownloadBandwidth) } else { assert.Equal(t, []string{"127.0.0.1/32"}, bwLimit.Sources) assert.Equal(t, int64(0), bwLimit.UploadBandwidth) assert.Equal(t, int64(1024), bwLimit.DownloadBandwidth) } } } if assert.Len(t, newUser.Filters.AccessTime, 1) { assert.Equal(t, 3, newUser.Filters.AccessTime[0].DayOfWeek) assert.Equal(t, "12:00", newUser.Filters.AccessTime[0].From) assert.Equal(t, "14:09", newUser.Filters.AccessTime[0].To) } assert.Len(t, newUser.Groups, 3) assert.Equal(t, sdk.TLSUsernameNone, newUser.Filters.TLSUsername) req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, newUser.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) 
checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group3, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebUserUpdateMock creates a user through the REST API, enables TOTP
// for it through the web client and then updates it through the web admin
// user form. It checks that:
//   - the TOTP save and the user form are rejected without a CSRF token;
//   - saving the TOTP config leaves the other user fields untouched;
//   - an empty password removes the stored password;
//   - a plain-text password is stored hashed;
//   - the redacted password placeholder keeps the previously stored hash;
//   - the posted fields are persisted and the TOTP config is preserved.
func TestWebUserUpdateMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	// Create the user through the REST API with a source bandwidth limit and
	// a total data transfer limit, so that we can verify below that saving
	// the TOTP config does not alter them.
	user := getTestUser()
	user.Filters.BandwidthLimits = []sdk.BandwidthLimit{
		{
			Sources:           []string{"10.8.0.0/16", "192.168.1.0/25"},
			UploadBandwidth:   256,
			DownloadBandwidth: 512,
		},
	}
	user.TotalDataTransfer = 4000
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	lastPwdChange := user.LastPasswordChange
	assert.Greater(t, lastPwdChange, int64(0))
	// Add a TOTP config through the web client.
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	userToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH, common.ProtocolFTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	// Without a CSRF header the TOTP save request is rejected.
	req, err = http.NewRequest(http.MethodPost, webClientTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, userToken)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, userToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// TOTP is now enabled and the fields set at creation are unchanged.
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.TOTPConfig.Enabled)
	assert.Equal(t, int64(4000), user.TotalDataTransfer)
	assert.Equal(t, lastPwdChange, user.LastPasswordChange)
	if assert.Len(t, user.Filters.BandwidthLimits, 1) {
		if assert.Len(t, user.Filters.BandwidthLimits[0].Sources, 2) {
			assert.Equal(t, "10.8.0.0/16", user.Filters.BandwidthLimits[0].Sources[0])
			assert.Equal(t, "192.168.1.0/25", user.Filters.BandwidthLimits[0].Sources[1])
		}
		assert.Equal(t, int64(256), user.Filters.BandwidthLimits[0].UploadBandwidth)
		assert.Equal(t, int64(512), user.Filters.BandwidthLimits[0].DownloadBandwidth)
	}
	dbUser, err := dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	// NOTE(review): rr still holds the web client TOTP save response here,
	// not a user object; decoding it into user presumably leaves user
	// unchanged. Verify whether this call is still needed.
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	// Values expected after the form update below.
	user.MaxSessions = 1
	user.QuotaFiles = 2
	user.QuotaSize = 1000 * 1000 * 1000
	user.GID = 1000
	user.Filters.AllowAPIKeyAuth = true
	user.AdditionalInfo = "new additional info"
	user.Email = "user@example.com"
	// Build the web admin form. No src_bandwidth_limits fields are posted, so
	// the existing bandwidth limits are expected to be cleared.
	form := make(url.Values)
	form.Set("username", user.Username)
	form.Set("email", user.Email)
	form.Set("password", "")
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Set("tls_certs[0][tls_cert]", httpsCert)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	// Expected to be stored as user.QuotaSize (1000 * 1000 * 1000).
	form.Set("quota_size", "1 GB")
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("permissions", "*")
	form.Set("directory_permissions[0][sub_perm_path]", "/otherdir")
	form.Set("directory_permissions[0][sub_perm_permissions][]", "list")
	form.Add("directory_permissions[0][sub_perm_permissions][]", "upload")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	// Surrounding whitespace is expected to be trimmed from the IP lists.
	form.Set("allowed_ip", " 192.168.1.3/32, 192.168.2.0/24 ")
	form.Set("denied_ip", " 10.0.0.2/32 ")
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.zip")
	form.Set("directory_patterns[0][pattern_type]", "denied")
	form.Set("denied_login_methods", dataprovider.SSHLoginMethodKeyboardInteractive)
	form.Set("denied_protocols", common.ProtocolFTP)
	form.Set("max_upload_file_size", "100")
	form.Set("default_shares_expiration", "30")
	form.Set("max_shares_expiration", "60")
	form.Set("password_expiration", "60")
	form.Set("password_strength", "40")
	form.Set("disconnect", "1")
	form.Set("additional_info", user.AdditionalInfo)
	form.Set("description", user.Description)
	form.Set("tls_username", string(sdk.TLSUsernameCN))
	form.Set("allow_api_key_auth", "1")
	form.Set("require_password_change", "1")
	form.Set("external_auth_cache_time", "120")
	// The first submission has no CSRF token and must be rejected.
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	// An empty password removes the stored password.
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.Empty(t, dbUser.Password)
	assert.False(t, dbUser.IsPasswordHashed())
	// A plain-text password is stored hashed.
	form.Set("password", defaultPassword)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	prevPwd := dbUser.Password
	// The redacted placeholder keeps the previously stored hash, and the
	// TOTP config survives the form updates.
	form.Set("password", redactedSecret)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	assert.Equal(t, prevPwd, dbUser.Password)
	assert.True(t, dbUser.Filters.TOTPConfig.Enabled)
	// Read the user back through the REST API and check the persisted values.
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var updateUser dataprovider.User
	err = render.DecodeJSON(rr.Body, &updateUser)
	assert.NoError(t, err)
	assert.Equal(t, user.Email, updateUser.Email)
	assert.Equal(t, user.HomeDir, updateUser.HomeDir)
	assert.Equal(t, user.MaxSessions, updateUser.MaxSessions)
	assert.Equal(t, user.QuotaFiles, updateUser.QuotaFiles)
	assert.Equal(t, user.QuotaSize, updateUser.QuotaSize)
	assert.Equal(t, user.UID, updateUser.UID)
	assert.Equal(t, user.GID, updateUser.GID)
	assert.Equal(t, user.AdditionalInfo, updateUser.AdditionalInfo)
	assert.Equal(t, user.Description, updateUser.Description)
	assert.Equal(t, int64(100), updateUser.Filters.MaxUploadFileSize)
	assert.Equal(t, sdk.TLSUsernameCN, updateUser.Filters.TLSUsername)
	assert.True(t, updateUser.Filters.AllowAPIKeyAuth)
	assert.True(t, updateUser.Filters.TOTPConfig.Enabled)
	assert.Equal(t, int64(0), updateUser.TotalDataTransfer)
	assert.Equal(t, int64(0), updateUser.DownloadDataTransfer)
	assert.Equal(t, int64(0), updateUser.UploadDataTransfer)
	// external_auth_cache_time was posted as "120" but 0 is expected;
	// presumably the value is reset when no external auth hook is
	// configured. TODO confirm.
	assert.Equal(t, int64(0), updateUser.Filters.ExternalAuthCacheTime)
	assert.Equal(t, 30, updateUser.Filters.DefaultSharesExpiration)
	assert.Equal(t, 60, updateUser.Filters.MaxSharesExpiration)
	assert.Equal(t, 60, updateUser.Filters.PasswordExpiration)
	assert.Equal(t, 40, updateUser.Filters.PasswordStrength)
	assert.True(t, updateUser.Filters.RequirePasswordChange)
	if val, ok := updateUser.Permissions["/otherdir"]; ok {
		assert.True(t, slices.Contains(val, dataprovider.PermListItems))
		assert.True(t, slices.Contains(val, dataprovider.PermUpload))
	} else {
		assert.Fail(t, "user permissions must contain /otherdir", "actual: %v", updateUser.Permissions)
	}
	assert.True(t, slices.Contains(updateUser.Filters.AllowedIP, "192.168.1.3/32"))
	assert.True(t, slices.Contains(updateUser.Filters.DeniedIP, "10.0.0.2/32"))
	assert.True(t, slices.Contains(updateUser.Filters.DeniedLoginMethods, dataprovider.SSHLoginMethodKeyboardInteractive))
	assert.True(t, slices.Contains(updateUser.Filters.DeniedProtocols, common.ProtocolFTP))
	// Guard the index: a missing pattern must fail the assertion instead of
	// panicking, which would abort the whole test binary.
	if assert.NotEmpty(t, updateUser.Filters.FilePatterns) {
		assert.True(t, slices.Contains(updateUser.Filters.FilePatterns[0].DeniedPatterns, "*.zip"))
	}
	assert.Len(t, updateUser.Filters.BandwidthLimits, 0)
	assert.Len(t, updateUser.Filters.TLSCerts, 1)
	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestRenderFolderTemplateMock checks that the folder template page returns
// 200 both without a source and with a "from" query parameter naming an
// existing folder, and 404 when the source folder does not exist.
func TestRenderFolderTemplateMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webTemplateFolder, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	folder := vfs.BaseVirtualFolder{
		Name:        "templatefolder",
		MappedPath:  filepath.Join(os.TempDir(), "mapped"),
		Description: "template folder desc",
	}
	folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webTemplateFolder+fmt.Sprintf("?from=%v", folder.Name), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webTemplateFolder+"?from=unknown-folder", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestRenderUserTemplateMock checks that the user template page returns 200
// both without a source and with a "from" query parameter naming an existing
// user, and 404 when the source user does not exist.
func TestRenderUserTemplateMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webTemplateUser, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webTemplateUser+fmt.Sprintf("?from=%v", user.Username), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webTemplateUser+"?from=unknown", nil)
	assert.NoError(t, err)
setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserSaveFromTemplateMock creates users "u1" and "u2" by posting the
// user template form. It checks that the form is rejected without a CSRF
// token, that RequirePasswordChange is false unless
// "tpl_require_password_change" is posted, and that a closed data provider
// results in a 500 error page. The data provider is re-initialized at the end.
func TestUserSaveFromTemplateMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	// This token comes from the folder template page; the user template form
	// below is expected to accept it.
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	user1 := "u1"
	user2 := "u2"
	form := make(url.Values)
	form.Set("username", "")
	form.Set("home_dir", filepath.Join(os.TempDir(), "%username%"))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("permissions", "*")
	form.Set("status", "1")
	form.Set("expiration_date", "")
	form.Set("fs_provider", "0")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("external_auth_cache_time", "0")
	// Two template users: u1 with a password and a blank public key field,
	// u2 with a public key only.
	form.Add("template_users[0][tpl_username]", user1)
	form.Add("template_users[0][tpl_password]", "password1")
	form.Add("template_users[0][tpl_public_keys]", " ")
	form.Add("template_users[1][tpl_username]", user2)
	form.Add("template_users[1][tpl_public_keys]", testPubKey)
	// Without the CSRF token the form is rejected.
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ := http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	// With the token both users are created; tpl_require_password_change is
	// not posted, so RequirePasswordChange is expected to be false.
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	u1, _, err := httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, u1.Filters.RequirePasswordChange)
	u2, _, err := httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, u2.Filters.RequirePasswordChange)
	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)
	// Posting tpl_require_password_change sets RequirePasswordChange on every
	// generated user.
	form.Add("tpl_require_password_change", "checked")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	u1, _, err = httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, u1.Filters.RequirePasswordChange)
	u2, _, err = httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, u2.Filters.RequirePasswordChange)
	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)
	// Close the data provider so that saving fails with a 500 error page,
	// then reload the configuration and re-initialize the provider.
	err = dataprovider.Close()
	assert.NoError(t, err)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, err = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestUserTemplateErrors posts the user template form with
// form_action=export_from_template and the s3_* fields populated, and checks
// that each of these invalid inputs is rejected with 400: a non-numeric
// s3_upload_part_size, no template users, an invalid public key and a blank
// template username.
func TestUserTemplateErrors(t *testing.T) {
	token, err :=
getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	// The S3 sizes/concurrency below are used later to post valid values.
	user := getTestUser()
	user.FsConfig.Provider = sdk.S3FilesystemProvider
	user.FsConfig.S3Config.Bucket = "test"
	user.FsConfig.S3Config.Region = "eu-central-1"
	user.FsConfig.S3Config.AccessKey = "%username%"
	user.FsConfig.S3Config.KeyPrefix = "somedir/subdir/"
	user.FsConfig.S3Config.UploadPartSize = 5
	user.FsConfig.S3Config.UploadConcurrency = 4
	user.FsConfig.S3Config.DownloadPartSize = 6
	user.FsConfig.S3Config.DownloadConcurrency = 3
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("home_dir", filepath.Join(os.TempDir(), "%username%"))
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "1")
	form.Set("s3_bucket", user.FsConfig.S3Config.Bucket)
	form.Set("s3_region", user.FsConfig.S3Config.Region)
	form.Set("s3_access_key", "%username%")
	form.Set("s3_access_secret", "%password%")
	form.Set("s3_sse_customer_key", "%password%")
	form.Set("s3_key_prefix", "base/%username%")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Add("hooks", "external_auth_disabled")
	form.Add("hooks", "check_password_disabled")
	form.Set("disable_fs_checks", "checked")
	form.Set("s3_download_part_max_time", "0")
	form.Set("s3_upload_part_max_time", "0")
	// test invalid s3_upload_part_size: a non-numeric value is rejected
	form.Set("s3_upload_part_size", "a")
	form.Set("form_action", "export_from_template")
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ := http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	form.Set("s3_upload_part_size", strconv.FormatInt(user.FsConfig.S3Config.UploadPartSize, 10))
	form.Set("s3_upload_concurrency", strconv.Itoa(user.FsConfig.S3Config.UploadConcurrency))
	form.Set("s3_download_part_size", strconv.FormatInt(user.FsConfig.S3Config.DownloadPartSize, 10))
	form.Set("s3_download_concurrency", strconv.Itoa(user.FsConfig.S3Config.DownloadConcurrency))
	// no user defined: valid S3 settings but no template users
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorUserTemplate)
	// A template user with an invalid public key.
	form.Set("template_users[0][tpl_username]", "user1")
	form.Set("template_users[0][tpl_password]", "password1")
	form.Set("template_users[0][tpl_public_keys]", "invalid-pkey")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	require.Contains(t, rr.Body.String(), util.I18nErrorPubKeyInvalid)
	// A template user with a blank username.
	form.Set("template_users[0][tpl_username]", " ")
	form.Set("template_users[0][tpl_password]", "pwd")
	form.Set("template_users[0][tpl_public_keys]", testPubKey)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	require.Contains(t, rr.Body.String(), util.I18nErrorUserTemplate)
}

// TestUserTemplateRoleAndPermissions uses an admin bound to role1 whose
// permissions initially lack PermAdminAddUsers. It checks that the user
// template page and form are forbidden without that permission. Once the
// permission is granted, users created from the template get the admin's
// role, even when the form posts a different role or an empty one.
func TestUserTemplateRoleAndPermissions(t *testing.T) {
	r1 := getTestRole()
	r2 := getTestRole()
	r2.Name += "_mod"
	role1, resp, err := httpdtest.AddRole(r1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	role2, resp, err := httpdtest.AddRole(r2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	// An admin bound to role1, without the permission to add users.
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Role = role1.Name
	admin.Permissions = []string{dataprovider.PermAdminManageFolders, dataprovider.PermAdminChangeUsers, dataprovider.PermAdminViewUsers}
	admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	// The user template page is forbidden without PermAdminAddUsers.
	req, _ := http.NewRequest(http.MethodGet, webTemplateUser, nil)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	user1 := "u1"
	user2 := "u2"
	// The form asks for role2, which differs from the admin's role.
	form := make(url.Values)
	form.Set("username", "")
	form.Set("role", role2.Name)
	form.Set("home_dir", filepath.Join(os.TempDir(), "%username%"))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("permissions", "*")
	form.Set("status", "1")
	form.Set("expiration_date", "")
	form.Set("fs_provider", "0")
	form.Set("max_upload_file_size", "0")
form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("external_auth_cache_time", "0")
	form.Add("template_users[0][tpl_username]", user1)
	form.Add("template_users[0][tpl_password]", "password1")
	form.Add("template_users[0][tpl_public_keys]", " ")
	form.Add("template_users[1][tpl_username]", user2)
	form.Add("template_users[1][tpl_public_keys]", testPubKey)
	form.Set(csrfFormToken, csrfToken)
	// Posting the template is forbidden as well without PermAdminAddUsers.
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// Add the required permissions, then log in again and fetch a new CSRF
	// token from the now accessible user template page.
	admin.Permissions = append(admin.Permissions, dataprovider.PermAdminAddUsers)
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	token, err = getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, webTemplateUser, nil)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	csrfToken, err = getCSRFTokenFromInternalPageMock(webTemplateUser, token)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	// The generated users get the admin's role, not the posted role2.
	u1, _, err := httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u1.Role)
	u2, _, err := httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u2.Role)
	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)
	// Set an empty role: the admin's role is still applied.
	form.Set("role", "")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	u1, _, err = httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u1.Role)
	u2, _, err = httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u2.Role)
	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role2, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserPlaceholders checks how the web user form expands placeholders in
// home_dir. On creation both %username% and %password% are replaced. On an
// update that posts the redacted password, only %username% is replaced and
// the stored password hash is left unchanged.
func TestUserPlaceholders(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, token)
	assert.NoError(t, err)
	u := getTestUser()
	u.HomeDir = filepath.Join(os.TempDir(), "%username%_%password%")
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", u.Username)
	form.Set("home_dir", u.HomeDir)
	form.Set("password", u.Password)
	form.Set("status", strconv.Itoa(u.Status))
	form.Set("expiration_date", "")
	form.Set("permissions", "*")
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Set("public_keys[1][public_key]", testPubKey1)
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("total_data_transfer", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
form.Set("max_upload_file_size", "0") form.Set("default_shares_expiration", "0") form.Set("max_shares_expiration", "0") form.Set("password_expiration", "0") form.Set("password_strength", "0") b, contentType, _ := getMultipartFormData(form, "", "") req, _ := http.NewRequest(http.MethodPost, webUserPath, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, filepath.Join(os.TempDir(), fmt.Sprintf("%v_%v", defaultUsername, defaultPassword)), user.HomeDir) dbUser, err := dataprovider.UserExists(defaultUsername, "") assert.NoError(t, err) assert.True(t, dbUser.IsPasswordHashed()) hashedPwd := dbUser.Password form.Set("password", redactedSecret) b, contentType, _ = getMultipartFormData(form, "", "") req, err = http.NewRequest(http.MethodPost, path.Join(webUserPath, defaultUsername), &b) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, filepath.Join(os.TempDir(), defaultUsername+"_%password%"), user.HomeDir) // check that the password was unchanged dbUser, err = dataprovider.UserExists(defaultUsername, "") assert.NoError(t, err) assert.True(t, dbUser.IsPasswordHashed()) assert.Equal(t, hashedPwd, dbUser.Password) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestFolderPlaceholders(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, token) assert.NoError(t, err) folderName := "folderName" form := 
make(url.Values) form.Set("name", folderName) form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%")) form.Set("description", "desc folder %name%") form.Set(csrfFormToken, csrfToken) b, contentType, _ := getMultipartFormData(form, "", "") req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) folderGet, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, filepath.Join(os.TempDir(), folderName), folderGet.MappedPath) assert.Equal(t, fmt.Sprintf("desc folder %v", folderName), folderGet.Description) form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%_%name%")) b, contentType, _ = getMultipartFormData(form, "", "") req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, filepath.Join(os.TempDir(), fmt.Sprintf("%v_%v", folderName, folderName)), folderGet.MappedPath) assert.Equal(t, fmt.Sprintf("desc folder %v", folderName), folderGet.Description) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) } func TestFolderSaveFromTemplateMock(t *testing.T) { folder1 := "f1" folder2 := "f2" token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) assert.NoError(t, err) form := make(url.Values) form.Set("name", "name") form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%")) form.Set("description", "desc folder %name%") 
form.Set("template_folders[0][tpl_foldername]", folder1) form.Set("template_folders[1][tpl_foldername]", folder2) form.Set(csrfFormToken, csrfToken) b, contentType, _ := getMultipartFormData(form, "", "") req, err := http.NewRequest(http.MethodPost, webTemplateFolder, &b) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) _, _, err = httpdtest.GetFolderByName(folder1, http.StatusOK) assert.NoError(t, err) _, _, err = httpdtest.GetFolderByName(folder2, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder1}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder2}, http.StatusOK) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) b, contentType, _ = getMultipartFormData(form, "", "") req, err = http.NewRequest(http.MethodPost, webTemplateFolder, &b) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } func TestFolderTemplateErrors(t *testing.T) { folderName := "vfolder-template" mappedPath := filepath.Join(os.TempDir(), "%name%mapped%name%path") token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) assert.NoError(t, err) form := make(url.Values) form.Set("name", folderName) form.Set("mapped_path", mappedPath) form.Set("description", "desc folder %name%") 
form.Set("template_folders[0][tpl_foldername]", "folder1") form.Set("template_folders[1][tpl_foldername]", "folder2") form.Set("template_folders[2][tpl_foldername]", "folder3") form.Set("template_folders[3][tpl_foldername]", "folder1 ") form.Add("template_folders[3][tpl_foldername]", " ") b, contentType, _ := getMultipartFormData(form, "", "") req, _ := http.NewRequest(http.MethodPost, webTemplateFolder, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webTemplateFolder+"?param=p%C3%AO%GG", &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) form.Set("fs_provider", "1") form.Set("s3_bucket", "bucket") form.Set("s3_region", "us-east-1") form.Set("s3_access_key", "%name%") form.Set("s3_access_secret", "pwd%name%") form.Set("s3_sse_customer_key", "key%name%") form.Set("s3_key_prefix", "base/%name%") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("s3_upload_part_size", "5") form.Set("s3_upload_concurrency", "4") form.Set("s3_download_part_max_time", "0") form.Set("s3_upload_part_max_time", "0") form.Set("s3_download_part_size", "6") form.Set("s3_download_concurrency", "2") form.Set("template_folders[0][tpl_foldername]", " ") form.Set("template_folders[1][tpl_foldername]", "") form.Set("template_folders[2][tpl_foldername]", "") 
form.Set("template_folders[3][tpl_foldername]", " ") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorFolderTemplate) form.Set("template_folders[0][tpl_foldername]", "name") form.Set("mapped_path", "relative-path") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, &b) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidHomeDir) } func TestFolderTemplatePermission(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword admin.Permissions = []string{dataprovider.PermAdminChangeUsers, dataprovider.PermAdminAddUsers, dataprovider.PermAdminViewUsers} admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) assert.NoError(t, err) // no permission to view or add folders from templates token, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateUser, token) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webTemplateFolder, nil) assert.NoError(t, err) req.RequestURI = webTemplateFolder setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) form := make(url.Values) form.Set("name", "name") form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%")) form.Set("description", "desc folder %name%") form.Set("template_folders[0][tpl_foldername]", "folder1") form.Set("template_folders[1][tpl_foldername]", "folder2") form.Set(csrfFormToken, csrfToken) b, contentType, _ := 
getMultipartFormData(form, "", "") req, err = http.NewRequest(http.MethodPost, webTemplateFolder, &b) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) admin.Permissions = append(admin.Permissions, dataprovider.PermAdminManageFolders) _, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) assert.NoError(t, err) token, err = getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) _, err = getCSRFTokenFromInternalPageMock(webTemplateUser, token) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webTemplateFolder, nil) assert.NoError(t, err) req.RequestURI = webTemplateFolder setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) } func TestWebUserS3Mock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, apiToken) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) lastPwdChange := user.LastPasswordChange assert.Greater(t, lastPwdChange, int64(0)) user.FsConfig.Provider = sdk.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test" user.FsConfig.S3Config.Region = "eu-west-1" user.FsConfig.S3Config.AccessKey = "access-key" user.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("access-secret") user.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("enc-key") 
user.FsConfig.S3Config.RoleARN = "arn:aws:iam::123456789012:user/Development/product_1234/*" user.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?a=b" user.FsConfig.S3Config.StorageClass = "Standard" user.FsConfig.S3Config.KeyPrefix = "somedir/subdir/" user.FsConfig.S3Config.UploadPartSize = 5 user.FsConfig.S3Config.UploadConcurrency = 4 user.FsConfig.S3Config.DownloadPartMaxTime = 60 user.FsConfig.S3Config.UploadPartMaxTime = 120 user.FsConfig.S3Config.DownloadPartSize = 6 user.FsConfig.S3Config.DownloadConcurrency = 3 user.FsConfig.S3Config.ForcePathStyle = true user.FsConfig.S3Config.SkipTLSVerify = true user.FsConfig.S3Config.ACL = "public-read" user.Description = "s3 tèst user" form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) form.Set("password", redactedSecret) form.Set("home_dir", user.HomeDir) form.Set("uid", "0") form.Set("gid", strconv.FormatInt(int64(user.GID), 10)) form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10)) form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") form.Set("upload_data_transfer", "0") form.Set("download_data_transfer", "0") form.Set("total_data_transfer", "0") form.Set("external_auth_cache_time", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") form.Set("allowed_ip", "") form.Set("denied_ip", "") form.Set("fs_provider", "1") form.Set("s3_bucket", user.FsConfig.S3Config.Bucket) form.Set("s3_region", user.FsConfig.S3Config.Region) form.Set("s3_access_key", user.FsConfig.S3Config.AccessKey) form.Set("s3_access_secret", user.FsConfig.S3Config.AccessSecret.GetPayload()) form.Set("s3_sse_customer_key", user.FsConfig.S3Config.SSECustomerKey.GetPayload()) form.Set("s3_role_arn", user.FsConfig.S3Config.RoleARN) 
form.Set("s3_storage_class", user.FsConfig.S3Config.StorageClass) form.Set("s3_acl", user.FsConfig.S3Config.ACL) form.Set("s3_endpoint", user.FsConfig.S3Config.Endpoint) form.Set("s3_key_prefix", user.FsConfig.S3Config.KeyPrefix) form.Set("directory_patterns[0][pattern_path]", "/dir1") form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") form.Set("directory_patterns[0][pattern_type]", "allowed") form.Set("directory_patterns[0][pattern_policy]", "0") form.Set("directory_patterns[1][pattern_path]", "/dir2") form.Set("directory_patterns[1][patterns]", "*.zip") form.Set("directory_patterns[1][pattern_type]", "denied") form.Set("directory_patterns[1][pattern_policy]", "1") form.Set("max_upload_file_size", "0") form.Set("default_shares_expiration", "0") form.Set("max_shares_expiration", "0") form.Set("password_expiration", "0") form.Set("password_strength", "0") form.Set("ftp_security", "1") form.Set("s3_force_path_style", "checked") form.Set("s3_skip_tls_verify", "checked") form.Set("description", user.Description) form.Add("hooks", "pre_login_disabled") form.Add("allow_api_key_auth", "1") // test invalid s3_upload_part_size form.Set("s3_upload_part_size", "a") b, contentType, _ := getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // test invalid s3_upload_concurrency form.Set("s3_upload_part_size", strconv.FormatInt(user.FsConfig.S3Config.UploadPartSize, 10)) form.Set("s3_upload_concurrency", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // test invalid s3_download_part_size form.Set("s3_upload_concurrency", 
strconv.Itoa(user.FsConfig.S3Config.UploadConcurrency)) form.Set("s3_download_part_size", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // test invalid s3_download_concurrency form.Set("s3_download_part_size", strconv.FormatInt(user.FsConfig.S3Config.DownloadPartSize, 10)) form.Set("s3_download_concurrency", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // test invalid s3_download_part_max_time form.Set("s3_download_concurrency", strconv.Itoa(user.FsConfig.S3Config.DownloadConcurrency)) form.Set("s3_download_part_max_time", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // test invalid s3_upload_part_max_time form.Set("s3_download_part_max_time", strconv.Itoa(user.FsConfig.S3Config.DownloadPartMaxTime)) form.Set("s3_upload_part_max_time", "a") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // now add the user form.Set("s3_upload_part_max_time", strconv.Itoa(user.FsConfig.S3Config.UploadPartMaxTime)) b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, 
user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var updateUser dataprovider.User err = render.DecodeJSON(rr.Body, &updateUser) assert.NoError(t, err) assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) assert.Equal(t, updateUser.FsConfig.S3Config.Bucket, user.FsConfig.S3Config.Bucket) assert.Equal(t, updateUser.FsConfig.S3Config.Region, user.FsConfig.S3Config.Region) assert.Equal(t, updateUser.FsConfig.S3Config.AccessKey, user.FsConfig.S3Config.AccessKey) assert.Equal(t, updateUser.FsConfig.S3Config.RoleARN, user.FsConfig.S3Config.RoleARN) assert.Equal(t, updateUser.FsConfig.S3Config.StorageClass, user.FsConfig.S3Config.StorageClass) assert.Equal(t, updateUser.FsConfig.S3Config.ACL, user.FsConfig.S3Config.ACL) assert.Equal(t, updateUser.FsConfig.S3Config.Endpoint, user.FsConfig.S3Config.Endpoint) assert.Equal(t, updateUser.FsConfig.S3Config.KeyPrefix, user.FsConfig.S3Config.KeyPrefix) assert.Equal(t, updateUser.FsConfig.S3Config.UploadPartSize, user.FsConfig.S3Config.UploadPartSize) assert.Equal(t, updateUser.FsConfig.S3Config.UploadConcurrency, user.FsConfig.S3Config.UploadConcurrency) assert.Equal(t, updateUser.FsConfig.S3Config.DownloadPartMaxTime, user.FsConfig.S3Config.DownloadPartMaxTime) assert.Equal(t, updateUser.FsConfig.S3Config.UploadPartMaxTime, user.FsConfig.S3Config.UploadPartMaxTime) assert.Equal(t, updateUser.FsConfig.S3Config.DownloadPartSize, user.FsConfig.S3Config.DownloadPartSize) assert.Equal(t, updateUser.FsConfig.S3Config.DownloadConcurrency, user.FsConfig.S3Config.DownloadConcurrency) assert.Equal(t, lastPwdChange, updateUser.LastPasswordChange) assert.True(t, updateUser.FsConfig.S3Config.ForcePathStyle) assert.True(t, 
updateUser.FsConfig.S3Config.SkipTLSVerify) if assert.Equal(t, 2, len(updateUser.Filters.FilePatterns)) { for _, filter := range updateUser.Filters.FilePatterns { switch filter.Path { case "/dir1": assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy) case "/dir2": assert.Equal(t, sdk.DenyPolicyHide, filter.DenyPolicy) } } } assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.S3Config.AccessSecret.GetStatus()) assert.NotEmpty(t, updateUser.FsConfig.S3Config.AccessSecret.GetPayload()) assert.Empty(t, updateUser.FsConfig.S3Config.AccessSecret.GetKey()) assert.Empty(t, updateUser.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.S3Config.SSECustomerKey.GetStatus()) assert.NotEmpty(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetPayload()) assert.Empty(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetKey()) assert.Empty(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) assert.Equal(t, user.Description, updateUser.Description) assert.True(t, updateUser.Filters.Hooks.PreLoginDisabled) assert.False(t, updateUser.Filters.Hooks.ExternalAuthDisabled) assert.False(t, updateUser.Filters.Hooks.CheckPasswordDisabled) assert.False(t, updateUser.Filters.DisableFsChecks) assert.True(t, updateUser.Filters.AllowAPIKeyAuth) assert.Equal(t, 1, updateUser.Filters.FTPSecurity) // now check that a redacted password is not saved form.Set("s3_access_secret", redactedSecret) form.Set("s3_sse_customer_key", redactedSecret) b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var 
lastUpdatedUser dataprovider.User err = render.DecodeJSON(rr.Body, &lastUpdatedUser) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetStatus()) assert.Equal(t, updateUser.FsConfig.S3Config.AccessSecret.GetPayload(), lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetPayload()) assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetKey()) assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetStatus()) assert.Equal(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetPayload(), lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetPayload()) assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetKey()) assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) assert.Equal(t, lastPwdChange, lastUpdatedUser.LastPasswordChange) // now clear credentials form.Set("s3_access_key", "") form.Set("s3_access_secret", "") form.Set("s3_sse_customer_key", "") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var userGet dataprovider.User err = render.DecodeJSON(rr.Body, &userGet) assert.NoError(t, err) assert.Nil(t, userGet.FsConfig.S3Config.AccessSecret) assert.Nil(t, userGet.FsConfig.S3Config.SSECustomerKey) req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) } func TestWebUserGCSMock(t *testing.T) { 
webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) assert.NoError(t, err) setBearerForReq(req, apiToken) rr := executeRequest(req) checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) credentialsFilePath := filepath.Join(os.TempDir(), "gcs.json") err = createTestFile(credentialsFilePath, 0) assert.NoError(t, err) user.FsConfig.Provider = sdk.GCSFilesystemProvider user.FsConfig.GCSConfig.Bucket = "test" user.FsConfig.GCSConfig.KeyPrefix = "somedir/subdir/" user.FsConfig.GCSConfig.StorageClass = "standard" user.FsConfig.GCSConfig.ACL = "publicReadWrite" user.FsConfig.GCSConfig.UploadPartSize = 16 user.FsConfig.GCSConfig.UploadPartMaxTime = 32 form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) form.Set("password", redactedSecret) form.Set("home_dir", user.HomeDir) form.Set("uid", "0") form.Set("gid", strconv.FormatInt(int64(user.GID), 10)) form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10)) form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") form.Set("upload_data_transfer", "0") form.Set("download_data_transfer", "0") form.Set("total_data_transfer", "0") form.Set("external_auth_cache_time", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") form.Set("allowed_ip", "") form.Set("denied_ip", "") form.Set("fs_provider", 
"2") form.Set("gcs_bucket", user.FsConfig.GCSConfig.Bucket) form.Set("gcs_storage_class", user.FsConfig.GCSConfig.StorageClass) form.Set("gcs_acl", user.FsConfig.GCSConfig.ACL) form.Set("gcs_key_prefix", user.FsConfig.GCSConfig.KeyPrefix) form.Set("gcs_upload_part_size", strconv.FormatInt(user.FsConfig.GCSConfig.UploadPartSize, 10)) form.Set("gcs_upload_part_max_time", strconv.FormatInt(int64(user.FsConfig.GCSConfig.UploadPartMaxTime), 10)) form.Set("directory_patterns[0][pattern_path]", "/dir1") form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") form.Set("directory_patterns[0][pattern_type]", "allowed") form.Set("max_upload_file_size", "0") form.Set("default_shares_expiration", "0") form.Set("max_shares_expiration", "0") form.Set("password_expiration", "0") form.Set("password_strength", "0") form.Set("ftp_security", "1") b, contentType, _ := getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) b, contentType, _ = getMultipartFormData(form, "gcs_credential_file", credentialsFilePath) req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = createTestFile(credentialsFilePath, 4096) assert.NoError(t, err) b, contentType, _ = getMultipartFormData(form, "gcs_credential_file", credentialsFilePath) req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, 
http.StatusOK, rr) var updateUser dataprovider.User err = render.DecodeJSON(rr.Body, &updateUser) assert.NoError(t, err) assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) assert.Equal(t, user.FsConfig.Provider, updateUser.FsConfig.Provider) assert.Equal(t, user.FsConfig.GCSConfig.Bucket, updateUser.FsConfig.GCSConfig.Bucket) assert.Equal(t, user.FsConfig.GCSConfig.StorageClass, updateUser.FsConfig.GCSConfig.StorageClass) assert.Equal(t, user.FsConfig.GCSConfig.ACL, updateUser.FsConfig.GCSConfig.ACL) assert.Equal(t, user.FsConfig.GCSConfig.KeyPrefix, updateUser.FsConfig.GCSConfig.KeyPrefix) assert.Equal(t, user.FsConfig.GCSConfig.UploadPartSize, updateUser.FsConfig.GCSConfig.UploadPartSize) assert.Equal(t, user.FsConfig.GCSConfig.UploadPartMaxTime, updateUser.FsConfig.GCSConfig.UploadPartMaxTime) if assert.Len(t, updateUser.Filters.FilePatterns, 1) { assert.Equal(t, "/dir1", updateUser.Filters.FilePatterns[0].Path) assert.Len(t, updateUser.Filters.FilePatterns[0].AllowedPatterns, 2) assert.Contains(t, updateUser.Filters.FilePatterns[0].AllowedPatterns, "*.png") assert.Contains(t, updateUser.Filters.FilePatterns[0].AllowedPatterns, "*.jpg") } assert.Equal(t, 1, updateUser.Filters.FTPSecurity) form.Set("gcs_auto_credentials", "on") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) updateUser = dataprovider.User{} err = render.DecodeJSON(rr.Body, &updateUser) assert.NoError(t, err) assert.Equal(t, 1, updateUser.FsConfig.GCSConfig.AutomaticCredentials) req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) 
setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = os.Remove(credentialsFilePath)
	assert.NoError(t, err)
}

// TestWebUserHTTPFsMock creates a user through the REST API, switches it to the
// HTTP filesystem provider (fs_provider "6") through the web admin update form
// and verifies that the submitted password and API key are stored with
// SecretBox status, that redacted placeholders submitted later keep the stored
// secrets unchanged and that the equality check mode can be toggled.
func TestWebUserHTTPFsMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.HTTPFilesystemProvider
	user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint:      "https://127.0.0.1:9999/api/v1",
			Username:      defaultUsername,
			SkipTLSVerify: true,
		},
		Password: kms.NewPlainSecret(defaultPassword),
		APIKey:   kms.NewPlainSecret(defaultTokenAuthPass),
	}
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("password", redactedSecret)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "6")
	form.Set("http_endpoint", user.FsConfig.HTTPConfig.Endpoint)
	form.Set("http_username", user.FsConfig.HTTPConfig.Username)
	form.Set("http_password", user.FsConfig.HTTPConfig.Password.GetPayload())
	form.Set("http_api_key", user.FsConfig.HTTPConfig.APIKey.GetPayload())
	form.Set("http_skip_tls_verify", "checked")
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.jpg,*.png")
	form.Set("directory_patterns[0][pattern_type]", "allowed")
	form.Set("directory_patterns[1][pattern_path]", "/dir2")
	form.Set("directory_patterns[1][patterns]", "*.zip")
	form.Set("directory_patterns[1][pattern_type]", "denied")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("http_equality_check_mode", "true")

	// submitUserForm posts the current form, as multipart data, to the web
	// update page of the user and checks the returned status code.
	submitUserForm := func(expectedStatusCode int) {
		t.Helper()
		body, contentType, formErr := getMultipartFormData(form, "", "")
		assert.NoError(t, formErr)
		formReq, formErr := http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &body)
		assert.NoError(t, formErr)
		setJWTCookieForReq(formReq, webToken)
		formReq.Header.Set("Content-Type", contentType)
		checkResponseCode(t, expectedStatusCode, executeRequest(formReq))
	}
	// getUserFromAPI reads the user back through the REST API and decodes it.
	getUserFromAPI := func() dataprovider.User {
		t.Helper()
		getReq, getErr := http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
		assert.NoError(t, getErr)
		setBearerForReq(getReq, apiToken)
		getRR := executeRequest(getReq)
		checkResponseCode(t, http.StatusOK, getRR)
		var fetched dataprovider.User
		getErr = render.DecodeJSON(getRR.Body, &fetched)
		assert.NoError(t, getErr)
		return fetched
	}

	submitUserForm(http.StatusSeeOther)
	// check the updated user
	updateUser := getUserFromAPI()
	assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate)
	assert.Len(t, updateUser.Filters.FilePatterns, 2)
	assert.Equal(t, user.FsConfig.HTTPConfig.Endpoint, updateUser.FsConfig.HTTPConfig.Endpoint)
	assert.Equal(t, user.FsConfig.HTTPConfig.Username, updateUser.FsConfig.HTTPConfig.Username)
	assert.Equal(t, user.FsConfig.HTTPConfig.SkipTLSVerify, updateUser.FsConfig.HTTPConfig.SkipTLSVerify)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.HTTPConfig.Password.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.HTTPConfig.Password.GetPayload())
	assert.Empty(t, updateUser.FsConfig.HTTPConfig.Password.GetKey())
	assert.Empty(t, updateUser.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.HTTPConfig.APIKey.GetPayload())
	assert.Empty(t, updateUser.FsConfig.HTTPConfig.APIKey.GetKey())
	assert.Empty(t, updateUser.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.Equal(t, 1, updateUser.FsConfig.HTTPConfig.EqualityCheckMode)
	// now check that a redacted password and API key are not saved; the
	// space-padded placeholder must be recognized as redacted as well
	form.Set("http_equality_check_mode", "")
	form.Set("http_password", " "+redactedSecret+" ")
	form.Set("http_api_key", redactedSecret)
	submitUserForm(http.StatusSeeOther)
	lastUpdatedUser := getUserFromAPI()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.HTTPConfig.Password.GetStatus())
	assert.Equal(t, updateUser.FsConfig.HTTPConfig.Password.GetPayload(), lastUpdatedUser.FsConfig.HTTPConfig.Password.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.Password.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.Password.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetStatus())
	assert.Equal(t, updateUser.FsConfig.HTTPConfig.APIKey.GetPayload(), lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetAdditionalData())
	assert.Equal(t, 0, lastUpdatedUser.FsConfig.HTTPConfig.EqualityCheckMode)

	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestWebUserAzureBlobMock switches a user to the Azure Blob provider
// (fs_provider "3") through the web admin update form. Non-numeric part size
// and concurrency values must be answered with 200 instead of a redirect. The
// test then checks the saved settings, the SecretBox status of the account key
// and of the SAS URL, and that redacted placeholders keep the stored secrets.
func TestWebUserAzureBlobMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	user.FsConfig.AzBlobConfig.Container = "container"
	user.FsConfig.AzBlobConfig.AccountName = "aname"
	user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("access-skey")
	user.FsConfig.AzBlobConfig.Endpoint = "http://127.0.0.1:9000/path?b=c"
	user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir/"
	user.FsConfig.AzBlobConfig.UploadPartSize = 5
	user.FsConfig.AzBlobConfig.UploadConcurrency = 4
	user.FsConfig.AzBlobConfig.DownloadPartSize = 3
	user.FsConfig.AzBlobConfig.DownloadConcurrency = 6
	user.FsConfig.AzBlobConfig.UseEmulator = true
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("password", redactedSecret)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "3")
	form.Set("az_container", user.FsConfig.AzBlobConfig.Container)
	form.Set("az_account_name", user.FsConfig.AzBlobConfig.AccountName)
	form.Set("az_account_key", user.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	form.Set("az_endpoint", user.FsConfig.AzBlobConfig.Endpoint)
	form.Set("az_key_prefix", user.FsConfig.AzBlobConfig.KeyPrefix)
	form.Set("az_use_emulator", "checked")
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.jpg,*.png")
	form.Set("directory_patterns[0][pattern_type]", "allowed")
	form.Set("directory_patterns[1][pattern_path]", "/dir2")
	form.Set("directory_patterns[1][patterns]", "*.zip")
	form.Set("directory_patterns[1][pattern_type]", "denied")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")

	// submitUserForm posts the current form, as multipart data, to the web
	// update page of the user and checks the returned status code.
	submitUserForm := func(expectedStatusCode int) {
		t.Helper()
		body, contentType, formErr := getMultipartFormData(form, "", "")
		assert.NoError(t, formErr)
		formReq, formErr := http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &body)
		assert.NoError(t, formErr)
		setJWTCookieForReq(formReq, webToken)
		formReq.Header.Set("Content-Type", contentType)
		checkResponseCode(t, expectedStatusCode, executeRequest(formReq))
	}
	// getUserFromAPI reads the user back through the REST API and decodes it.
	getUserFromAPI := func() dataprovider.User {
		t.Helper()
		getReq, getErr := http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
		assert.NoError(t, getErr)
		setBearerForReq(getReq, apiToken)
		getRR := executeRequest(getReq)
		checkResponseCode(t, http.StatusOK, getRR)
		var fetched dataprovider.User
		getErr = render.DecodeJSON(getRR.Body, &fetched)
		assert.NoError(t, getErr)
		return fetched
	}

	// test invalid az_upload_part_size
	form.Set("az_upload_part_size", "a")
	submitUserForm(http.StatusOK)
	// test invalid az_upload_concurrency
	form.Set("az_upload_part_size", strconv.FormatInt(user.FsConfig.AzBlobConfig.UploadPartSize, 10))
	form.Set("az_upload_concurrency", "a")
	submitUserForm(http.StatusOK)
	// test invalid az_download_part_size
	form.Set("az_upload_concurrency", strconv.Itoa(user.FsConfig.AzBlobConfig.UploadConcurrency))
	form.Set("az_download_part_size", "a")
	submitUserForm(http.StatusOK)
	// test invalid az_download_concurrency
	form.Set("az_download_part_size", strconv.FormatInt(user.FsConfig.AzBlobConfig.DownloadPartSize, 10))
	form.Set("az_download_concurrency", "a")
	submitUserForm(http.StatusOK)
	// now update the user with all values valid
	form.Set("az_download_concurrency", strconv.Itoa(user.FsConfig.AzBlobConfig.DownloadConcurrency))
	submitUserForm(http.StatusSeeOther)
	updateUser := getUserFromAPI()
	assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate)
	assert.Equal(t, user.FsConfig.AzBlobConfig.Container, updateUser.FsConfig.AzBlobConfig.Container)
	assert.Equal(t, user.FsConfig.AzBlobConfig.AccountName, updateUser.FsConfig.AzBlobConfig.AccountName)
	assert.Equal(t, user.FsConfig.AzBlobConfig.Endpoint, updateUser.FsConfig.AzBlobConfig.Endpoint)
	assert.Equal(t, user.FsConfig.AzBlobConfig.KeyPrefix, updateUser.FsConfig.AzBlobConfig.KeyPrefix)
	assert.Equal(t, user.FsConfig.AzBlobConfig.UploadPartSize, updateUser.FsConfig.AzBlobConfig.UploadPartSize)
	assert.Equal(t, user.FsConfig.AzBlobConfig.UploadConcurrency, updateUser.FsConfig.AzBlobConfig.UploadConcurrency)
	assert.Equal(t, user.FsConfig.AzBlobConfig.DownloadPartSize, updateUser.FsConfig.AzBlobConfig.DownloadPartSize)
	assert.Equal(t, user.FsConfig.AzBlobConfig.DownloadConcurrency, updateUser.FsConfig.AzBlobConfig.DownloadConcurrency)
	assert.Len(t, updateUser.Filters.FilePatterns, 2)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	assert.Empty(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetKey())
	assert.Empty(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	// now check that a redacted account key is not saved
	form.Set("az_account_key", redactedSecret+" ")
	submitUserForm(http.StatusSeeOther)
	lastUpdatedUser := getUserFromAPI()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetStatus())
	assert.Equal(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetPayload(), lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData())
	// test SAS url
	user.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("sasurl")
	form.Set("az_account_name", "")
	form.Set("az_account_key", "")
	form.Set("az_container", "")
	form.Set("az_sas_url", user.FsConfig.AzBlobConfig.SASURL.GetPayload())
	submitUserForm(http.StatusSeeOther)
	updateUser = getUserFromAPI()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.AzBlobConfig.SASURL.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetPayload())
	assert.Empty(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetKey())
	assert.Empty(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetAdditionalData())
	// now check that a redacted sas url is not saved
	form.Set("az_sas_url", redactedSecret)
	submitUserForm(http.StatusSeeOther)
	lastUpdatedUser = getUserFromAPI()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetStatus())
	assert.Equal(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetPayload(), lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetAdditionalData())

	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestWebUserCryptMock switches a user to the encrypted local filesystem
// provider (fs_provider "4") through the web admin update form. An empty
// passphrase must be answered with 200 instead of a redirect. The test then
// checks the saved buffer sizes, the SecretBox status of the passphrase and
// that a redacted placeholder keeps the stored passphrase.
func TestWebUserCryptMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.CryptedFilesystemProvider
	user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("crypted passphrase")
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("password", redactedSecret)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "4")
	form.Set("crypt_passphrase", "")
	form.Set("cryptfs_read_buffer_size", "1")
	form.Set("cryptfs_write_buffer_size", "2")
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.jpg,*.png")
	form.Set("directory_patterns[0][pattern_type]", "allowed")
	form.Set("directory_patterns[1][pattern_path]", "/dir2")
	form.Set("directory_patterns[1][patterns]", "*.zip")
	form.Set("directory_patterns[1][pattern_type]", "denied")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")

	// submitUserForm posts the current form, as multipart data, to the web
	// update page of the user and checks the returned status code.
	submitUserForm := func(expectedStatusCode int) {
		t.Helper()
		body, contentType, formErr := getMultipartFormData(form, "", "")
		assert.NoError(t, formErr)
		formReq, formErr := http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &body)
		assert.NoError(t, formErr)
		setJWTCookieForReq(formReq, webToken)
		formReq.Header.Set("Content-Type", contentType)
		checkResponseCode(t, expectedStatusCode, executeRequest(formReq))
	}
	// getUserFromAPI reads the user back through the REST API and decodes it.
	getUserFromAPI := func() dataprovider.User {
		t.Helper()
		getReq, getErr := http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
		assert.NoError(t, getErr)
		setBearerForReq(getReq, apiToken)
		getRR := executeRequest(getReq)
		checkResponseCode(t, http.StatusOK, getRR)
		var fetched dataprovider.User
		getErr = render.DecodeJSON(getRR.Body, &fetched)
		assert.NoError(t, getErr)
		return fetched
	}

	// passphrase cannot be empty
	submitUserForm(http.StatusOK)
	form.Set("crypt_passphrase", user.FsConfig.CryptConfig.Passphrase.GetPayload())
	submitUserForm(http.StatusSeeOther)
	updateUser := getUserFromAPI()
	assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate)
	assert.Len(t, updateUser.Filters.FilePatterns, 2)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.CryptConfig.Passphrase.GetPayload())
	assert.Empty(t, updateUser.FsConfig.CryptConfig.Passphrase.GetKey())
	assert.Empty(t, updateUser.FsConfig.CryptConfig.Passphrase.GetAdditionalData())
	assert.Equal(t, 1, updateUser.FsConfig.CryptConfig.ReadBufferSize)
	assert.Equal(t, 2, updateUser.FsConfig.CryptConfig.WriteBufferSize)
	// now check that a redacted passphrase is not saved
	form.Set("crypt_passphrase", redactedSecret+" ")
	submitUserForm(http.StatusSeeOther)
	lastUpdatedUser := getUserFromAPI()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetStatus())
	assert.Equal(t, updateUser.FsConfig.CryptConfig.Passphrase.GetPayload(), lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetAdditionalData())

	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestWebUserSFTPFsMock switches a user to the SFTP filesystem provider
// (fs_provider "5") through the web admin update form. A submit without any
// sftp_* field must be answered with 200 instead of a redirect. The test then
// checks the saved settings, the SecretBox status of password, private key and
// key passphrase, that redacted placeholders keep the stored secrets and that
// the equality check mode can be toggled.
func TestWebUserSFTPFsMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:22"
	user.FsConfig.SFTPConfig.Username = "sftpuser"
	user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("pwd")
	user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKeyPwd)
	user.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret(privateKeyPwd)
	user.FsConfig.SFTPConfig.Fingerprints = []string{sftpPkeyFingerprint}
	user.FsConfig.SFTPConfig.Prefix = "/home/sftpuser"
	user.FsConfig.SFTPConfig.DisableCouncurrentReads = true
	user.FsConfig.SFTPConfig.BufferSize = 5
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("password", redactedSecret)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "5")
	form.Set("crypt_passphrase", "")
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.jpg,*.png")
	form.Set("directory_patterns[0][pattern_type]", "allowed")
	form.Set("directory_patterns[1][pattern_path]", "/dir2")
	form.Set("directory_patterns[1][patterns]", "*.zip")
	form.Set("directory_patterns[1][pattern_type]", "denied")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")

	// submitUserForm posts the current form, as multipart data, to the web
	// update page of the user and checks the returned status code.
	submitUserForm := func(expectedStatusCode int) {
		t.Helper()
		body, contentType, formErr := getMultipartFormData(form, "", "")
		assert.NoError(t, formErr)
		formReq, formErr := http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &body)
		assert.NoError(t, formErr)
		setJWTCookieForReq(formReq, webToken)
		formReq.Header.Set("Content-Type", contentType)
		checkResponseCode(t, expectedStatusCode, executeRequest(formReq))
	}
	// getUserFromAPI reads the user back through the REST API and decodes it.
	getUserFromAPI := func() dataprovider.User {
		t.Helper()
		getReq, getErr := http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
		assert.NoError(t, getErr)
		setBearerForReq(getReq, apiToken)
		getRR := executeRequest(getReq)
		checkResponseCode(t, http.StatusOK, getRR)
		var fetched dataprovider.User
		getErr = render.DecodeJSON(getRR.Body, &fetched)
		assert.NoError(t, getErr)
		return fetched
	}

	// empty sftpconfig
	submitUserForm(http.StatusOK)
	form.Set("sftp_endpoint", user.FsConfig.SFTPConfig.Endpoint)
	form.Set("sftp_username", user.FsConfig.SFTPConfig.Username)
	form.Set("sftp_password", user.FsConfig.SFTPConfig.Password.GetPayload())
	form.Set("sftp_private_key", user.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	form.Set("sftp_key_passphrase", user.FsConfig.SFTPConfig.KeyPassphrase.GetPayload())
	form.Set("sftp_fingerprints", user.FsConfig.SFTPConfig.Fingerprints[0])
	form.Set("sftp_prefix", user.FsConfig.SFTPConfig.Prefix)
	form.Set("sftp_disable_concurrent_reads", "true")
	form.Set("sftp_equality_check_mode", "true")
	form.Set("sftp_buffer_size", strconv.FormatInt(user.FsConfig.SFTPConfig.BufferSize, 10))
	submitUserForm(http.StatusSeeOther)
	updateUser := getUserFromAPI()
	assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate)
	assert.Len(t, updateUser.Filters.FilePatterns, 2)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.SFTPConfig.Password.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.SFTPConfig.Password.GetPayload())
	assert.Empty(t, updateUser.FsConfig.SFTPConfig.Password.GetKey())
	assert.Empty(t, updateUser.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	assert.Empty(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetKey())
	assert.Empty(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload())
	assert.Empty(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetKey())
	assert.Empty(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetAdditionalData())
	assert.Equal(t, user.FsConfig.SFTPConfig.Prefix, updateUser.FsConfig.SFTPConfig.Prefix)
	assert.Equal(t, user.FsConfig.SFTPConfig.Username, updateUser.FsConfig.SFTPConfig.Username)
	assert.Equal(t, user.FsConfig.SFTPConfig.Endpoint, updateUser.FsConfig.SFTPConfig.Endpoint)
	assert.True(t, updateUser.FsConfig.SFTPConfig.DisableCouncurrentReads)
	assert.Len(t, updateUser.FsConfig.SFTPConfig.Fingerprints, 1)
	assert.Equal(t, user.FsConfig.SFTPConfig.BufferSize, updateUser.FsConfig.SFTPConfig.BufferSize)
	assert.Contains(t, updateUser.FsConfig.SFTPConfig.Fingerprints, sftpPkeyFingerprint)
	assert.Equal(t, 1, updateUser.FsConfig.SFTPConfig.EqualityCheckMode)
	// now check that redacted credentials are not saved
	form.Set("sftp_password", redactedSecret+" ")
	form.Set("sftp_private_key", redactedSecret)
	form.Set("sftp_key_passphrase", redactedSecret)
	form.Set("sftp_equality_check_mode", "")
	submitUserForm(http.StatusSeeOther)
	lastUpdatedUser := getUserFromAPI()
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.SFTPConfig.Password.GetStatus())
	assert.Equal(t, updateUser.FsConfig.SFTPConfig.Password.GetPayload(), lastUpdatedUser.FsConfig.SFTPConfig.Password.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.Password.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.Password.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetStatus())
	assert.Equal(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetPayload(), lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetStatus())
	assert.Equal(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload(), lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetAdditionalData())
	assert.Equal(t, 0, lastUpdatedUser.FsConfig.SFTPConfig.EqualityCheckMode)

	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestWebUserRole logs in as an admin bound to a role and checks two things.
// A user added through the web form gets that admin's role. Submitting an
// update with an empty "role" field does not remove the role from the user.
func TestWebUserRole(t *testing.T) {
	role, resp, err := httpdtest.AddRole(getTestRole(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	a.Role = role.Name
	a.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers,
		dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers}
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	user := getTestUser()
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("home_dir", user.HomeDir)
	form.Set("password", user.Password)
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("permissions", "*")
	form.Set("external_auth_cache_time", "0")
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("upload_bandwidth", strconv.FormatInt(user.UploadBandwidth, 10))
	form.Set("download_bandwidth", strconv.FormatInt(user.DownloadBandwidth, 10))
	form.Set("upload_data_transfer", strconv.FormatInt(user.UploadDataTransfer, 10))
	form.Set("download_data_transfer", strconv.FormatInt(user.DownloadDataTransfer, 10))
	form.Set("total_data_transfer", strconv.FormatInt(user.TotalDataTransfer, 10))
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "10")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")

	// submitUserForm posts the current form, as multipart data, to reqPath
	// and checks the returned status code.
	submitUserForm := func(reqPath string, expectedStatusCode int) {
		t.Helper()
		body, contentType, formErr := getMultipartFormData(form, "", "")
		assert.NoError(t, formErr)
		formReq, formErr := http.NewRequest(http.MethodPost, reqPath, &body)
		assert.NoError(t, formErr)
		setJWTCookieForReq(formReq, webToken)
		formReq.Header.Set("Content-Type", contentType)
		checkResponseCode(t, expectedStatusCode, executeRequest(formReq))
	}

	// add the user: it must inherit the role of the admin
	submitUserForm(webUserPath, http.StatusSeeOther)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, user.Role)
	// an empty role submitted by a role-bound admin must not clear the role
	form.Set("role", "")
	submitUserForm(path.Join(webUserPath, user.Username), http.StatusSeeOther)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, user.Role)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
}

func TestWebEventAction(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminEventActionPath, webToken)
	assert.NoError(t, err)
	action := dataprovider.BaseEventAction{
		Name:        "web_action_http",
		Description: "http web action",
		Type:        dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint: "https://localhost:4567/action",
				Username: defaultUsername,
				Headers: []dataprovider.KeyValue{
					{
						Key:   "Content-Type",
						Value: "application/json",
					},
				},
				Password: kms.NewPlainSecret(defaultPassword),
				Timeout:  10,
SkipTLSVerify: true, Method: http.MethodPost, QueryParameters: []dataprovider.KeyValue{ { Key: "param1", Value: "value1", }, }, Body: `{"event":"{{.Event}}","name":"{{.Name}}"}`, }, }, } form := make(url.Values) form.Set("name", action.Name) form.Set("description", action.Description) form.Set("fs_action_type", "0") form.Set("type", "a") req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("http_timeout", "b") req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("cmd_timeout", "20") form.Set("pwd_expiration_threshold", "10") form.Set("http_timeout", fmt.Sprintf("%d", action.Options.HTTPConfig.Timeout)) form.Set("http_headers[0][http_header_key]", action.Options.HTTPConfig.Headers[0].Key) form.Set("http_headers[0][http_header_value]", action.Options.HTTPConfig.Headers[0].Value) form.Set("http_headers[1][http_header_key]", action.Options.HTTPConfig.Headers[0].Key) // ignored form.Set("query_parameters[0][http_query_key]", action.Options.HTTPConfig.QueryParameters[0].Key) form.Set("query_parameters[0][http_query_value]", action.Options.HTTPConfig.QueryParameters[0].Value) form.Set("http_body", action.Options.HTTPConfig.Body) form.Set("http_skip_tls_verify", "1") form.Set("http_username", action.Options.HTTPConfig.Username) form.Set("http_password", 
action.Options.HTTPConfig.Password.GetPayload()) form.Set("http_method", action.Options.HTTPConfig.Method) req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorURLRequired) form.Set("http_endpoint", action.Options.HTTPConfig.Endpoint) req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // a new add will fail req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // list actions req, err = http.NewRequest(http.MethodGet, webAdminEventActionsPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webAdminEventActionsPath+jsonAPISuffix, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // render add 
page req, err = http.NewRequest(http.MethodGet, webAdminEventActionPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // render action page req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventActionPath, action.Name), nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // missing action req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventActionPath, action.Name+"1"), nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) // check the action actionGet, _, err := httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Equal(t, action.Description, actionGet.Description) assert.Equal(t, action.Options.HTTPConfig.Body, actionGet.Options.HTTPConfig.Body) assert.Equal(t, action.Options.HTTPConfig.Endpoint, actionGet.Options.HTTPConfig.Endpoint) assert.Equal(t, action.Options.HTTPConfig.Headers, actionGet.Options.HTTPConfig.Headers) assert.Equal(t, action.Options.HTTPConfig.Method, actionGet.Options.HTTPConfig.Method) assert.Equal(t, action.Options.HTTPConfig.SkipTLSVerify, actionGet.Options.HTTPConfig.SkipTLSVerify) assert.Equal(t, action.Options.HTTPConfig.Timeout, actionGet.Options.HTTPConfig.Timeout) assert.Equal(t, action.Options.HTTPConfig.Username, actionGet.Options.HTTPConfig.Username) assert.Equal(t, sdkkms.SecretStatusSecretBox, actionGet.Options.HTTPConfig.Password.GetStatus()) assert.NotEmpty(t, actionGet.Options.HTTPConfig.Password.GetPayload()) assert.Empty(t, actionGet.Options.HTTPConfig.Password.GetKey()) assert.Empty(t, actionGet.Options.HTTPConfig.Password.GetAdditionalData()) // update and check that the password is preserved and the multipart fields form.Set("http_password", redactedSecret) form.Set("http_body", "") 
form.Set("http_timeout", "0") form.Del("http_headers[0][http_header_key]") form.Del("http_headers[0][http_header_val]") form.Set("multipart_body[0][http_part_name]", "part1") form.Set("multipart_body[0][http_part_file]", "{{.VirtualPath}}") form.Set("multipart_body[0][http_part_body]", "") form.Set("multipart_body[0][http_part_headers]", "X-MyHeader: a:b,c") form.Set("multipart_body[12][http_part_name]", "part2") form.Set("multipart_body[12][http_part_headers]", "Content-Type:application/json \r\n") form.Set("multipart_body[12][http_part_body]", "{{.ObjectData}}") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) dbAction, err := dataprovider.EventActionExists(action.Name) assert.NoError(t, err) err = dbAction.Options.HTTPConfig.Password.Decrypt() assert.NoError(t, err) assert.Equal(t, defaultPassword, dbAction.Options.HTTPConfig.Password.GetPayload()) assert.Empty(t, dbAction.Options.HTTPConfig.Body) assert.Equal(t, 0, dbAction.Options.HTTPConfig.Timeout) if assert.Len(t, dbAction.Options.HTTPConfig.Parts, 2) { assert.Equal(t, "part1", dbAction.Options.HTTPConfig.Parts[0].Name) assert.Equal(t, "/{{.VirtualPath}}", dbAction.Options.HTTPConfig.Parts[0].Filepath) assert.Empty(t, dbAction.Options.HTTPConfig.Parts[0].Body) assert.Equal(t, "X-MyHeader", dbAction.Options.HTTPConfig.Parts[0].Headers[0].Key) assert.Equal(t, "a:b,c", dbAction.Options.HTTPConfig.Parts[0].Headers[0].Value) assert.Equal(t, "part2", dbAction.Options.HTTPConfig.Parts[1].Name) assert.Equal(t, "{{.ObjectData}}", dbAction.Options.HTTPConfig.Parts[1].Body) assert.Empty(t, dbAction.Options.HTTPConfig.Parts[1].Filepath) assert.Equal(t, "Content-Type", dbAction.Options.HTTPConfig.Parts[1].Headers[0].Key) assert.Equal(t, 
"application/json", dbAction.Options.HTTPConfig.Parts[1].Headers[0].Value) } // change action type action.Type = dataprovider.ActionTypeCommand action.Options.CmdConfig = dataprovider.EventActionCommandConfig{ Cmd: filepath.Join(os.TempDir(), "cmd"), Args: []string{"arg1", "arg2"}, Timeout: 20, EnvVars: []dataprovider.KeyValue{ { Key: "key", Value: "val", }, }, } dataprovider.EnabledActionCommands = []string{action.Options.CmdConfig.Cmd} defer func() { dataprovider.EnabledActionCommands = nil }() form.Set("type", fmt.Sprintf("%d", action.Type)) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorCommandRequired) form.Set("cmd_path", action.Options.CmdConfig.Cmd) form.Set("cmd_timeout", "a") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("cmd_timeout", fmt.Sprintf("%d", action.Options.CmdConfig.Timeout)) form.Set("env_vars[0][cmd_env_key]", action.Options.CmdConfig.EnvVars[0].Key) form.Set("env_vars[0][cmd_env_value]", action.Options.CmdConfig.EnvVars[0].Value) form.Set("cmd_arguments", "arg1 ,arg2 ") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, 
http.StatusSeeOther, rr) // update a missing action req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name+"1"), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) // update with no csrf token form.Del(csrfFormToken) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Equal(t, action.Options.CmdConfig.Cmd, actionGet.Options.CmdConfig.Cmd) assert.Equal(t, action.Options.CmdConfig.Args, actionGet.Options.CmdConfig.Args) assert.Equal(t, action.Options.CmdConfig.Timeout, actionGet.Options.CmdConfig.Timeout) assert.Equal(t, action.Options.CmdConfig.EnvVars, actionGet.Options.CmdConfig.EnvVars) assert.Equal(t, dataprovider.EventActionHTTPConfig{}, actionGet.Options.HTTPConfig) assert.Equal(t, dataprovider.EventActionPasswordExpiration{}, actionGet.Options.PwdExpirationConfig) // change action type again action.Type = dataprovider.ActionTypeEmail action.Options.EmailConfig = dataprovider.EventActionEmailConfig{ Recipients: []string{"address1@example.com", "address2@example.com"}, Bcc: []string{"address3@example.com"}, Subject: "subject", ContentType: 1, Body: "body", Attachments: []string{"/file1.txt", "/file2.txt"}, } form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("email_recipients", 
"address1@example.com, address2@example.com") form.Set("email_bcc", "address3@example.com") form.Set("email_subject", action.Options.EmailConfig.Subject) form.Set("email_content_type", fmt.Sprintf("%d", action.Options.EmailConfig.ContentType)) form.Set("email_body", action.Options.EmailConfig.Body) form.Set("email_attachments", "file1.txt, file2.txt") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Equal(t, action.Options.EmailConfig.Recipients, actionGet.Options.EmailConfig.Recipients) assert.Equal(t, action.Options.EmailConfig.Bcc, actionGet.Options.EmailConfig.Bcc) assert.Equal(t, action.Options.EmailConfig.Subject, actionGet.Options.EmailConfig.Subject) assert.Equal(t, action.Options.EmailConfig.ContentType, actionGet.Options.EmailConfig.ContentType) assert.Equal(t, action.Options.EmailConfig.Body, actionGet.Options.EmailConfig.Body) assert.Equal(t, action.Options.EmailConfig.Attachments, actionGet.Options.EmailConfig.Attachments) assert.Equal(t, dataprovider.EventActionHTTPConfig{}, actionGet.Options.HTTPConfig) assert.Empty(t, actionGet.Options.CmdConfig.Cmd) assert.Equal(t, 0, actionGet.Options.CmdConfig.Timeout) assert.Len(t, actionGet.Options.CmdConfig.EnvVars, 0) // change action type to data retention check action.Type = dataprovider.ActionTypeDataRetentionCheck form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("data_retention[10][folder_retention_path]", "p1") form.Set("data_retention[10][folder_retention_val]", "a") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, 
action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("data_retention[10][folder_retention_val]", "24") form.Set("data_retention[10][folder_retention_options][]", "1") form.Set("data_retention[11][folder_retention_path]", "../p2") form.Set("data_retention[11][folder_retention_val]", "48") form.Set("data_retention[11][folder_retention_options][]", "1") form.Set("data_retention[13][folder_retention_options][]", "1") // ignored req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) if assert.Len(t, actionGet.Options.RetentionConfig.Folders, 2) { for _, folder := range actionGet.Options.RetentionConfig.Folders { switch folder.Path { case "/p1": assert.Equal(t, 24, folder.Retention) assert.True(t, folder.DeleteEmptyDirs) case "/p2": assert.Equal(t, 48, folder.Retention) assert.True(t, folder.DeleteEmptyDirs) default: t.Errorf("unexpected folder path %v", folder.Path) } } } action.Type = dataprovider.ActionTypeFilesystem action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionMkdirs, MkDirs: []string{"a ", " a/b"}, } form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("fs_mkdir_paths", strings.Join(action.Options.FsConfig.MkDirs, ",")) form.Set("fs_action_type", "invalid") req, err = http.NewRequest(http.MethodPost, 
path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) if assert.Len(t, actionGet.Options.FsConfig.MkDirs, 2) { for _, dir := range actionGet.Options.FsConfig.MkDirs { switch dir { case "/a": case "/a/b": default: t.Errorf("unexpected dir path %v", dir) } } } action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionExist, Exist: []string{"b ", " c/d"}, } form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) form.Set("fs_exist_paths", strings.Join(action.Options.FsConfig.Exist, ",")) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) if assert.Len(t, actionGet.Options.FsConfig.Exist, 2) { for _, p := range 
actionGet.Options.FsConfig.Exist { switch p { case "/b": case "/c/d": default: t.Errorf("unexpected path %v", p) } } } action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionRename, Renames: []dataprovider.RenameConfig{ { KeyValue: dataprovider.KeyValue{ Key: "/src", Value: "/target", }, }, }, } form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) form.Set("fs_rename[0][fs_rename_source]", action.Options.FsConfig.Renames[0].Key) form.Set("fs_rename[0][fs_rename_target]", action.Options.FsConfig.Renames[0].Value) form.Set("fs_rename[0][fs_rename_options][]", "1") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) if assert.Len(t, actionGet.Options.FsConfig.Renames, 1) { assert.True(t, actionGet.Options.FsConfig.Renames[0].UpdateModTime) } action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ Type: dataprovider.FilesystemActionCopy, Copy: []dataprovider.KeyValue{ { Key: "/copy_src", Value: "/copy_target", }, }, } form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) form.Set("fs_copy[0][fs_copy_source]", action.Options.FsConfig.Copy[0].Key) form.Set("fs_copy[0][fs_copy_target]", action.Options.FsConfig.Copy[0].Value) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, 
http.StatusSeeOther, rr) // check the update actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Len(t, actionGet.Options.FsConfig.Copy, 1) action.Type = dataprovider.ActionTypePasswordExpirationCheck action.Options.PwdExpirationConfig.Threshold = 15 form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("pwd_expiration_threshold", "a") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("pwd_expiration_threshold", strconv.Itoa(action.Options.PwdExpirationConfig.Threshold)) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Equal(t, action.Options.PwdExpirationConfig.Threshold, actionGet.Options.PwdExpirationConfig.Threshold) assert.Equal(t, 0, actionGet.Options.CmdConfig.Timeout) assert.Len(t, actionGet.Options.CmdConfig.EnvVars, 0) action.Type = dataprovider.ActionTypeUserInactivityCheck action.Options.UserInactivityConfig = dataprovider.EventActionUserInactivity{ DisableThreshold: 10, DeleteThreshold: 15, } form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("inactivity_disable_threshold", strconv.Itoa(action.Options.UserInactivityConfig.DisableThreshold)) 
form.Set("inactivity_delete_threshold", strconv.Itoa(action.Options.UserInactivityConfig.DeleteThreshold)) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Equal(t, 0, actionGet.Options.PwdExpirationConfig.Threshold) assert.Equal(t, action.Options.UserInactivityConfig.DisableThreshold, actionGet.Options.UserInactivityConfig.DisableThreshold) assert.Equal(t, action.Options.UserInactivityConfig.DeleteThreshold, actionGet.Options.UserInactivityConfig.DeleteThreshold) action.Type = dataprovider.ActionTypeIDPAccountCheck form.Set("type", fmt.Sprintf("%d", action.Type)) form.Set("idp_mode", "1") form.Set("idp_user", `{"username":"user"}`) form.Set("idp_admin", `{"username":"admin"}`) form.Set("pwd_expiration_threshold", strconv.Itoa(action.Options.PwdExpirationConfig.Threshold)) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, action.Type, actionGet.Type) assert.Equal(t, 1, actionGet.Options.IDPConfig.Mode) assert.Contains(t, actionGet.Options.IDPConfig.TemplateUser, `"user"`) assert.Contains(t, actionGet.Options.IDPConfig.TemplateAdmin, `"admin"`) req, err = http.NewRequest(http.MethodDelete, path.Join(webAdminEventActionPath, action.Name), nil) 
assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	req, err = http.NewRequest(http.MethodDelete, path.Join(webAdminEventActionPath, action.Name), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webAdminEventActionsPath+jsonAPISuffix, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Equal(t, `[]`, rr.Body.String())
}

// TestWebEventRule exercises the WebAdmin event rule form end to end.
// It first adds a rule, walking through the errors for:
//   - an invalid status;
//   - an invalid trigger;
//   - a missing min/max file size;
//   - a missing CSRF token.
//
// It then renders the list, add and edit pages and verifies the stored rule
// through the REST API. Next it switches the rule to a filesystem-event
// trigger and then to an IDP-login trigger, and checks the update error
// paths (missing rule, no CSRF token, no action, invalid trigger). Finally it
// deletes the rule and the action it references.
func TestWebEventRule(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminEventRulePath, webToken)
	assert.NoError(t, err)

	// The rule must reference an existing action, so create one first.
	action, _, err := httpdtest.AddEventAction(dataprovider.BaseEventAction{
		Name: "web_action",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type:  dataprovider.FilesystemActionExist,
				Exist: []string{"/dir1"},
			},
		},
	}, http.StatusCreated)
	assert.NoError(t, err)

	// inverseMatch builds a single-element, inverse-matching pattern list.
	inverseMatch := func(pattern string) []dataprovider.ConditionPattern {
		return []dataprovider.ConditionPattern{{Pattern: pattern, InverseMatch: true}}
	}
	rule := dataprovider.EventRule{
		Name:        "test_web_rule",
		Status:      1,
		Description: "rule added using web API",
		Trigger:     dataprovider.EventTriggerSchedule,
		Conditions: dataprovider.EventConditions{
			Schedules: []dataprovider.Schedule{
				{Hours: "0", DayOfWeek: "*", DayOfMonth: "*", Month: "*"},
			},
			Options: dataprovider.ConditionOptions{
				Names:      inverseMatch("u*"),
				GroupNames: inverseMatch("g*"),
				RoleNames:  inverseMatch("r*"),
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{Name: action.Name},
				Order:           1,
			},
		},
	}
	rulePath := path.Join(webAdminEventRulePath, rule.Name)

	form := make(url.Values)
	form.Set("name", rule.Name)
	form.Set("description", rule.Description)
	form.Set("status", "a")
	// The first submission is written out inline because it declares req and rr.
	req, err := http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	// submitForm POSTs the current form values, URL-encoded, to target with
	// the admin web token and stores the response in rr.
	submitForm := func(target string) {
		r, reqErr := http.NewRequest(http.MethodPost, target, bytes.NewBuffer([]byte(form.Encode())))
		assert.NoError(t, reqErr)
		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		setJWTCookieForReq(r, webToken)
		rr = executeRequest(r)
	}
	// loadPage GETs target with the admin web token and stores the response in rr.
	loadPage := func(target string) {
		r, reqErr := http.NewRequest(http.MethodGet, target, nil)
		assert.NoError(t, reqErr)
		setJWTCookieForReq(r, webToken)
		rr = executeRequest(r)
	}
	// assertStoredRule reads the rule back through the REST API and compares
	// it with the current expected definition held in rule.
	assertStoredRule := func() {
		stored, _, getErr := httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
		assert.NoError(t, getErr)
		assert.Equal(t, rule.Trigger, stored.Trigger)
		assert.Equal(t, rule.Status, stored.Status)
		assert.Equal(t, rule.Description, stored.Description)
		assert.Equal(t, rule.Conditions, stored.Conditions)
		if assert.Len(t, stored.Actions, 1) {
			assert.Equal(t, rule.Actions[0].Name, stored.Actions[0].Name)
			assert.Equal(t, rule.Actions[0].Order, stored.Actions[0].Order)
		}
	}

	form.Set("status", fmt.Sprintf("%d", rule.Status))
	form.Set("trigger", "a")
	submitForm(webAdminEventRulePath)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	schedule := rule.Conditions.Schedules[0]
	filters := rule.Conditions.Options
	form.Set("trigger", fmt.Sprintf("%d", rule.Trigger))
	form.Set("schedules[0][schedule_hour]", schedule.Hours)
	form.Set("schedules[0][schedule_day_of_week]", schedule.DayOfWeek)
	form.Set("schedules[0][schedule_day_of_month]", schedule.DayOfMonth)
	form.Set("schedules[0][schedule_month]", schedule.Month)
	form.Set("name_filters[0][name_pattern]", filters.Names[0].Pattern)
	form.Set("name_filters[0][type_name_pattern]", "inverse")
	form.Set("group_name_filters[0][group_name_pattern]", filters.GroupNames[0].Pattern)
	form.Set("group_name_filters[0][type_group_name_pattern]", "inverse")
	form.Set("role_name_filters[0][role_name_pattern]", filters.RoleNames[0].Pattern)
	form.Set("role_name_filters[0][type_role_name_pattern]", "inverse")
	submitForm(webAdminEventRulePath)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidMinSize)

	form.Set("fs_min_size", "0")
	submitForm(webAdminEventRulePath)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidMaxSize)

	form.Set("fs_max_size", "0")
	form.Set("actions[0][action_name]", action.Name)
	// The form is complete but still lacks the CSRF token.
	submitForm(webAdminEventRulePath)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	form.Set(csrfFormToken, csrfToken)
	submitForm(webAdminEventRulePath)
	checkResponseCode(t, http.StatusSeeOther, rr)
	// Adding the same rule again must not redirect.
	submitForm(webAdminEventRulePath)
	checkResponseCode(t, http.StatusOK, rr)

	// List the rules (HTML and JSON), then render the add page and the edit page.
	for _, page := range []string{
		webAdminEventRulesPath,
		webAdminEventRulesPath + jsonAPISuffix,
		webAdminEventRulePath,
		rulePath,
	} {
		loadPage(page)
		checkResponseCode(t, http.StatusOK, rr)
	}
	// Request a rule that does not exist.
	loadPage(path.Join(webAdminEventRulePath, rule.Name+"1"))
	checkResponseCode(t, http.StatusNotFound, rr)

	assertStoredRule()

	// Disable the rule and switch it to a filesystem-event trigger.
	rule.Status = 0
	rule.Trigger = dataprovider.EventTriggerFsEvent
	rule.Conditions = dataprovider.EventConditions{
		FsEvents: []string{"upload", "download"},
		Options: dataprovider.ConditionOptions{
			Names:      inverseMatch("u*"),
			GroupNames: inverseMatch("g*"),
			RoleNames:  inverseMatch("r*"),
			FsPaths: []dataprovider.ConditionPattern{
				{Pattern: "/subdir/*.txt"},
			},
			Protocols:   []string{common.ProtocolSFTP, common.ProtocolHTTP},
			MinFileSize: 1024 * 1024,
			MaxFileSize: 5 * 1024 * 1024,
		},
	}
	form.Set("status", fmt.Sprintf("%d", rule.Status))
	form.Set("trigger", fmt.Sprintf("%d", rule.Trigger))
	for _, fsEvent := range rule.Conditions.FsEvents {
		form.Add("fs_events", fsEvent)
	}
	form.Set("path_filters[0][fs_path_pattern]", rule.Conditions.Options.FsPaths[0].Pattern)
	for _, proto := range rule.Conditions.Options.Protocols {
		form.Add("fs_protocols", proto)
	}
	form.Set("fs_min_size", fmt.Sprintf("%d", rule.Conditions.Options.MinFileSize))
	form.Set("fs_max_size", fmt.Sprintf("%d", rule.Conditions.Options.MaxFileSize))
	submitForm(rulePath)
	checkResponseCode(t, http.StatusSeeOther, rr)
	assertStoredRule()

	// Switch to the IDP login trigger and try both login event values.
	rule.Trigger = dataprovider.EventTriggerIDPLogin
	form.Set("trigger", fmt.Sprintf("%d", rule.Trigger))
	for _, loginEvent := range []int{1, 2} {
		form.Set("idp_login_event", fmt.Sprintf("%d", loginEvent))
		submitForm(rulePath)
		checkResponseCode(t, http.StatusSeeOther, rr)
		stored, _, getErr := httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
		assert.NoError(t, getErr)
		assert.Equal(t, rule.Trigger, stored.Trigger)
		assert.Equal(t, loginEvent, stored.Conditions.IDPLoginEvent)
	}

	// Updating a rule that does not exist.
	submitForm(path.Join(webAdminEventRulePath, rule.Name+"1"))
	checkResponseCode(t, http.StatusNotFound, rr)
	// Updating without the CSRF token, then restoring it.
	form.Del(csrfFormToken)
	submitForm(rulePath)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	form.Set(csrfFormToken, csrfToken)
	// Updating with no action defined.
	form.Del("actions[0][action_name]")
	submitForm(rulePath)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorRuleActionRequired)
	// Updating with an invalid trigger.
	form.Set("trigger", "a")
	submitForm(rulePath)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	// Cleanup: delete the rule first, then the action it references.
	for _, target := range []string{
		rulePath,
		path.Join(webAdminEventActionPath, action.Name),
	} {
		req, err = http.NewRequest(http.MethodDelete, target, nil)
		assert.NoError(t, err)
		setJWTCookieForReq(req, webToken)
		setCSRFHeaderForReq(req, csrfToken)
		rr = executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
	}
}

func TestWebIPListEntries(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webIPListPath+"/mode", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, webIPListPath+"/mode/a", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	req, err = http.NewRequest(http.MethodGet, path.Join(webIPListPath, "/1/a"), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, webIPListPath+"/1", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet,
webIPListsPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) entry := dataprovider.IPListEntry{ IPOrNet: "12.34.56.78/20", Type: dataprovider.IPListTypeDefender, Mode: dataprovider.ListModeDeny, Description: "note", Protocols: 5, } form := make(url.Values) form.Set("ipornet", entry.IPOrNet) form.Set("description", entry.Description) form.Set("mode", "a") req, err = http.NewRequest(http.MethodPost, webIPListPath+"/mode", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), util.I18nError400Message) req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webIPListPath+"/2", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("mode", "2") form.Set("protocols", "a") form.Add("protocols", "1") form.Add("protocols", "4") req, err = http.NewRequest(http.MethodPost, webIPListPath+"/"+strconv.Itoa(int(entry.Type)), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) 
entry1, _, err := httpdtest.GetIPListEntry(entry.IPOrNet, dataprovider.IPListTypeDefender, http.StatusOK) assert.NoError(t, err) assert.Equal(t, entry.Description, entry1.Description) assert.Equal(t, entry.Mode, entry1.Mode) assert.Equal(t, entry.Protocols, entry1.Protocols) form.Set("ipornet", "1111.11.11.11") req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorIPInvalid) form.Set("ipornet", entry.IPOrNet) form.Set("mode", "invalid") // ignored for list type 1 req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) entry2, _, err := httpdtest.GetIPListEntry(entry.IPOrNet, dataprovider.IPListTypeAllowList, http.StatusOK) assert.NoError(t, err) assert.Equal(t, entry.Description, entry2.Description) assert.Equal(t, dataprovider.ListModeAllow, entry2.Mode) assert.Equal(t, entry.Protocols, entry2.Protocols) req, err = http.NewRequest(http.MethodGet, path.Join(webIPListPath, "1", url.PathEscape(entry2.IPOrNet)), nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("protocols", "1") req, err = http.NewRequest(http.MethodPost, path.Join(webIPListPath, "1", url.PathEscape(entry.IPOrNet)), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) entry2, _, err = 
httpdtest.GetIPListEntry(entry.IPOrNet, dataprovider.IPListTypeAllowList, http.StatusOK) assert.NoError(t, err) assert.Equal(t, entry.Description, entry2.Description) assert.Equal(t, dataprovider.ListModeAllow, entry2.Mode) assert.Equal(t, 1, entry2.Protocols) form.Del(csrfFormToken) req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1/"+url.PathEscape(entry.IPOrNet), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webIPListPath+"/a/"+url.PathEscape(entry.IPOrNet), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1/"+url.PathEscape(entry.IPOrNet)+"a", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) form.Set("mode", "a") req, err = http.NewRequest(http.MethodPost, webIPListPath+"/2/"+url.PathEscape(entry.IPOrNet), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("mode", "100") req, err = http.NewRequest(http.MethodPost, webIPListPath+"/2/"+url.PathEscape(entry.IPOrNet), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) _, err = httpdtest.RemoveIPListEntry(entry1, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveIPListEntry(entry2, http.StatusOK) assert.NoError(t, err) } func TestWebRole(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminRolePath, webToken) assert.NoError(t, err) role := getTestRole() form := make(url.Values) form.Set("name", "") form.Set("description", role.Description) req, err := http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorNameRequired) form.Set("name", role.Name) req, err = http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // a new add will fail req, err = http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // list roles req, err = http.NewRequest(http.MethodGet, webAdminRolesPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webAdminRolesPath+jsonAPISuffix, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // render the new role page req, err = http.NewRequest(http.MethodGet, webAdminRolePath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webAdminRolePath, role.Name), nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webAdminRolePath, "missing_role"), nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) // parse form error req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name)+"?param=p%C4%AO%GH", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) // update role form.Set("description", "new desc") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, 
http.StatusSeeOther, rr) // check the changes role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, "new desc", role.Description) // no CSRF token form.Set(csrfFormToken, "") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) // missing role form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, "missing"), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) _, err = httpdtest.RemoveRole(role, http.StatusOK) assert.NoError(t, err) } func TestAddWebGroup(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken) assert.NoError(t, err) group := getTestGroup() group.UserSettings = dataprovider.GroupUserSettings{ BaseGroupUserSettings: sdk.BaseGroupUserSettings{ HomeDir: filepath.Join(os.TempDir(), util.GenerateUniqueID()), Permissions: make(map[string][]string), MaxSessions: 2, QuotaSize: 123, QuotaFiles: 10, UploadBandwidth: 128, DownloadBandwidth: 256, ExpiresIn: 10, }, } form := make(url.Values) form.Set("name", group.Name) form.Set("description", group.Description) form.Set("home_dir", group.UserSettings.HomeDir) b, contentType, err := getMultipartFormData(form, "", "") assert.NoError(t, err) req, err 
:= http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("max_sessions", strconv.FormatInt(int64(group.UserSettings.MaxSessions), 10)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidQuotaSize) form.Set("quota_files", strconv.FormatInt(int64(group.UserSettings.QuotaFiles), 10)) form.Set("quota_size", strconv.FormatInt(group.UserSettings.QuotaSize, 10)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("upload_bandwidth", strconv.FormatInt(group.UserSettings.UploadBandwidth, 10)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("download_bandwidth", strconv.FormatInt(group.UserSettings.DownloadBandwidth, 10)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) 
req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("upload_data_transfer", "0") form.Set("download_data_transfer", "0") form.Set("total_data_transfer", "0") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("expires_in", strconv.Itoa(group.UserSettings.ExpiresIn)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidMaxFilesize) form.Set("max_upload_file_size", "0") form.Set("default_shares_expiration", "0") form.Set("max_shares_expiration", "0") form.Set("password_expiration", "0") form.Set("password_strength", "0") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nError500Message) form.Set("external_auth_cache_time", "0") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = 
executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath+"?b=%2", &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // error parsing the multipart form b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // a new add will fail b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // list groups req, err = http.NewRequest(http.MethodGet, webGroupsPath, nil) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, webGroupsPath+jsonAPISuffix, nil) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // render the new group page req, err = http.NewRequest(http.MethodGet, path.Join(webGroupPath, group.Name), nil) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // check the added group groupGet, _, err := 
httpdtest.GetGroupByName(group.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, group.UserSettings, groupGet.UserSettings) assert.Equal(t, group.Name, groupGet.Name) assert.Equal(t, group.Description, groupGet.Description) // cleanup req, err = http.NewRequest(http.MethodDelete, path.Join(groupPath, group.Name), nil) assert.NoError(t, err) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) req, err = http.NewRequest(http.MethodGet, path.Join(webGroupPath, group.Name), nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) } func TestAddWebFoldersMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) mappedPath := filepath.Clean(os.TempDir()) folderName := filepath.Base(mappedPath) folderDesc := "a simple desc" form := make(url.Values) form.Set("mapped_path", mappedPath) form.Set("name", folderName) form.Set("description", folderDesc) form.Set("osfs_read_buffer_size", "3") form.Set("osfs_write_buffer_size", "4") b, contentType, err := getMultipartFormData(form, "", "") assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) form.Set(csrfFormToken, csrfToken) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) req.Header.Set("Content-Type", contentType) 
setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // adding the same folder will fail since the name must be unique b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // invalid form req, err = http.NewRequest(http.MethodPost, webFolderPath, strings.NewReader(form.Encode())) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", "text/plain; boundary=") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) // now render the add folder page req, err = http.NewRequest(http.MethodGet, webFolderPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) var folder vfs.BaseVirtualFolder req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = render.DecodeJSON(rr.Body, &folder) assert.NoError(t, err) assert.Equal(t, mappedPath, folder.MappedPath) assert.Equal(t, folderName, folder.Name) assert.Equal(t, folderDesc, folder.Description) assert.Equal(t, 3, folder.FsConfig.OSConfig.ReadBufferSize) assert.Equal(t, 4, folder.FsConfig.OSConfig.WriteBufferSize) // cleanup req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) } func TestHTTPFsWebFolderMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := 
getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) mappedPath := filepath.Clean(os.TempDir()) folderName := filepath.Base(mappedPath) httpfsConfig := vfs.HTTPFsConfig{ BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ Endpoint: "https://127.0.0.1:9998/api/v1", Username: folderName, SkipTLSVerify: true, }, Password: kms.NewPlainSecret(defaultPassword), APIKey: kms.NewPlainSecret(defaultTokenAuthPass), } form := make(url.Values) form.Set("mapped_path", mappedPath) form.Set("name", folderName) form.Set("fs_provider", "6") form.Set("http_endpoint", httpfsConfig.Endpoint) form.Set("http_username", "%name%") form.Set("http_password", httpfsConfig.Password.GetPayload()) form.Set("http_api_key", httpfsConfig.APIKey.GetPayload()) form.Set("http_skip_tls_verify", "checked") form.Set(csrfFormToken, csrfToken) b, contentType, err := getMultipartFormData(form, "", "") assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check var folder vfs.BaseVirtualFolder req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = render.DecodeJSON(rr.Body, &folder) assert.NoError(t, err) assert.Equal(t, mappedPath, folder.MappedPath) assert.Equal(t, folderName, folder.Name) assert.Equal(t, sdk.HTTPFilesystemProvider, folder.FsConfig.Provider) assert.Equal(t, httpfsConfig.Endpoint, folder.FsConfig.HTTPConfig.Endpoint) assert.Equal(t, httpfsConfig.Username, folder.FsConfig.HTTPConfig.Username) assert.Equal(t, httpfsConfig.SkipTLSVerify, folder.FsConfig.HTTPConfig.SkipTLSVerify) assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.HTTPConfig.Password.GetStatus()) assert.NotEmpty(t, folder.FsConfig.HTTPConfig.Password.GetPayload()) 
assert.Empty(t, folder.FsConfig.HTTPConfig.Password.GetKey()) assert.Empty(t, folder.FsConfig.HTTPConfig.Password.GetAdditionalData()) assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.HTTPConfig.APIKey.GetStatus()) assert.NotEmpty(t, folder.FsConfig.HTTPConfig.APIKey.GetPayload()) assert.Empty(t, folder.FsConfig.HTTPConfig.APIKey.GetKey()) assert.Empty(t, folder.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) // update form.Set("http_password", redactedSecret) form.Set("http_api_key", redactedSecret) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) // check var updateFolder vfs.BaseVirtualFolder req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = render.DecodeJSON(rr.Body, &updateFolder) assert.NoError(t, err) assert.Equal(t, mappedPath, updateFolder.MappedPath) assert.Equal(t, folderName, updateFolder.Name) assert.Equal(t, sdkkms.SecretStatusSecretBox, updateFolder.FsConfig.HTTPConfig.Password.GetStatus()) assert.Equal(t, folder.FsConfig.HTTPConfig.Password.GetPayload(), updateFolder.FsConfig.HTTPConfig.Password.GetPayload()) assert.Empty(t, updateFolder.FsConfig.HTTPConfig.Password.GetKey()) assert.Empty(t, updateFolder.FsConfig.HTTPConfig.Password.GetAdditionalData()) assert.Equal(t, sdkkms.SecretStatusSecretBox, updateFolder.FsConfig.HTTPConfig.APIKey.GetStatus()) assert.Equal(t, folder.FsConfig.HTTPConfig.APIKey.GetPayload(), updateFolder.FsConfig.HTTPConfig.APIKey.GetPayload()) assert.Empty(t, updateFolder.FsConfig.HTTPConfig.APIKey.GetKey()) assert.Empty(t, updateFolder.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) // cleanup req, 
_ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) } func TestS3WebFolderMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) mappedPath := filepath.Clean(os.TempDir()) folderName := filepath.Base(mappedPath) folderDesc := "a simple desc" S3Bucket := "test" S3Region := "eu-west-1" S3AccessKey := "access-key" S3AccessSecret := kms.NewPlainSecret("folder-access-secret") S3SSEKey := kms.NewPlainSecret("folder-sse-key") S3SessionToken := "fake session token" S3RoleARN := "arn:aws:iam::123456789012:user/Development/product_1234/*" S3Endpoint := "http://127.0.0.1:9000/path?b=c" S3StorageClass := "Standard" S3ACL := "public-read-write" S3KeyPrefix := "somedir/subdir/" S3UploadPartSize := 5 S3UploadConcurrency := 4 S3MaxPartDownloadTime := 120 S3MaxPartUploadTime := 60 S3DownloadPartSize := 6 S3DownloadConcurrency := 3 form := make(url.Values) form.Set("mapped_path", mappedPath) form.Set("name", folderName) form.Set("description", folderDesc) form.Set("fs_provider", "1") form.Set("s3_bucket", S3Bucket) form.Set("s3_region", S3Region) form.Set("s3_access_key", S3AccessKey) form.Set("s3_access_secret", S3AccessSecret.GetPayload()) form.Set("s3_sse_customer_key", S3SSEKey.GetPayload()) form.Set("s3_session_token", S3SessionToken) form.Set("s3_role_arn", S3RoleARN) form.Set("s3_storage_class", S3StorageClass) form.Set("s3_acl", S3ACL) form.Set("s3_endpoint", S3Endpoint) form.Set("s3_key_prefix", S3KeyPrefix) form.Set("s3_upload_part_size", strconv.Itoa(S3UploadPartSize)) form.Set("s3_download_part_max_time", strconv.Itoa(S3MaxPartDownloadTime)) 
form.Set("s3_download_part_size", strconv.Itoa(S3DownloadPartSize)) form.Set("s3_download_concurrency", strconv.Itoa(S3DownloadConcurrency)) form.Set("s3_upload_part_max_time", strconv.Itoa(S3MaxPartUploadTime)) form.Set("s3_upload_concurrency", "a") form.Set(csrfFormToken, csrfToken) b, contentType, err := getMultipartFormData(form, "", "") assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("s3_upload_concurrency", strconv.Itoa(S3UploadConcurrency)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webFolderPath, &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) var folder vfs.BaseVirtualFolder req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = render.DecodeJSON(rr.Body, &folder) assert.NoError(t, err) assert.Equal(t, mappedPath, folder.MappedPath) assert.Equal(t, folderName, folder.Name) assert.Equal(t, folderDesc, folder.Description) assert.Equal(t, sdk.S3FilesystemProvider, folder.FsConfig.Provider) assert.Equal(t, S3Bucket, folder.FsConfig.S3Config.Bucket) assert.Equal(t, S3Region, folder.FsConfig.S3Config.Region) assert.Equal(t, S3AccessKey, folder.FsConfig.S3Config.AccessKey) assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) assert.NotEmpty(t, folder.FsConfig.S3Config.SSECustomerKey.GetPayload()) assert.Equal(t, S3Endpoint, folder.FsConfig.S3Config.Endpoint) assert.Equal(t, S3StorageClass, folder.FsConfig.S3Config.StorageClass) assert.Equal(t, S3ACL, folder.FsConfig.S3Config.ACL) assert.Equal(t, 
S3KeyPrefix, folder.FsConfig.S3Config.KeyPrefix) assert.Equal(t, S3UploadConcurrency, folder.FsConfig.S3Config.UploadConcurrency) assert.Equal(t, int64(S3UploadPartSize), folder.FsConfig.S3Config.UploadPartSize) assert.Equal(t, S3MaxPartDownloadTime, folder.FsConfig.S3Config.DownloadPartMaxTime) assert.Equal(t, S3MaxPartUploadTime, folder.FsConfig.S3Config.UploadPartMaxTime) assert.Equal(t, S3DownloadConcurrency, folder.FsConfig.S3Config.DownloadConcurrency) assert.Equal(t, int64(S3DownloadPartSize), folder.FsConfig.S3Config.DownloadPartSize) assert.False(t, folder.FsConfig.S3Config.ForcePathStyle) assert.False(t, folder.FsConfig.S3Config.SkipTLSVerify) // update S3UploadConcurrency = 10 form.Set("s3_upload_concurrency", "b") b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("s3_upload_concurrency", strconv.Itoa(S3UploadConcurrency)) b, contentType, err = getMultipartFormData(form, "", "") assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) assert.NoError(t, err) setJWTCookieForReq(req, webToken) req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusSeeOther, rr) folder = vfs.BaseVirtualFolder{} req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) setBearerForReq(req, apiToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err = render.DecodeJSON(rr.Body, &folder) assert.NoError(t, err) assert.Equal(t, mappedPath, folder.MappedPath) assert.Equal(t, folderName, folder.Name) assert.Equal(t, folderDesc, folder.Description) assert.Equal(t, sdk.S3FilesystemProvider, folder.FsConfig.Provider) assert.Equal(t, S3Bucket, 
folder.FsConfig.S3Config.Bucket)
	assert.Equal(t, S3Region, folder.FsConfig.S3Config.Region)
	assert.Equal(t, S3AccessKey, folder.FsConfig.S3Config.AccessKey)
	assert.Equal(t, S3RoleARN, folder.FsConfig.S3Config.RoleARN)
	assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.NotEmpty(t, folder.FsConfig.S3Config.SSECustomerKey.GetPayload())
	assert.Equal(t, S3Endpoint, folder.FsConfig.S3Config.Endpoint)
	assert.Equal(t, S3StorageClass, folder.FsConfig.S3Config.StorageClass)
	assert.Equal(t, S3KeyPrefix, folder.FsConfig.S3Config.KeyPrefix)
	assert.Equal(t, S3UploadConcurrency, folder.FsConfig.S3Config.UploadConcurrency)
	assert.Equal(t, int64(S3UploadPartSize), folder.FsConfig.S3Config.UploadPartSize)
	// cleanup
	req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestUpdateWebGroupMock exercises the web admin group update form
// (POST to webGroupPath/<group name>) for a group that uses the SFTP
// filesystem provider.
//
// The same form is posted repeatedly, with one field added per step, so each
// step checks a different response:
//   - sftp_buffer_size and the CSRF token both missing: HTTP 200 with
//     util.I18nError500Message in the body;
//   - sftp_buffer_size set, CSRF token still missing: HTTP 403 with
//     util.I18nErrorInvalidCSRF;
//   - CSRF token set, no SFTP password: HTTP 200 with
//     util.I18nErrorFsCredentialsRequired;
//   - SFTP password set: HTTP 303;
//   - after the group is deleted through the REST API: HTTP 404.
func TestUpdateWebGroupMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken)
	assert.NoError(t, err)
	group, _, err := httpdtest.AddGroup(getTestGroup(), http.StatusCreated)
	assert.NoError(t, err)
	// UserSettings is only changed on the local copy. It supplies the form
	// values below; the stored group is changed only through the web form.
	group.UserSettings = dataprovider.GroupUserSettings{
		BaseGroupUserSettings: sdk.BaseGroupUserSettings{
			HomeDir:     filepath.Join(os.TempDir(), util.GenerateUniqueID()),
			Permissions: make(map[string][]string),
		},
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint:   sftpServerAddr,
					Username:   defaultUsername,
					BufferSize: 1,
				},
			},
		},
	}
	form := make(url.Values)
	form.Set("name", group.Name)
	form.Set("description", group.Description)
	form.Set("home_dir", group.UserSettings.HomeDir)
	form.Set("max_sessions", strconv.FormatInt(int64(group.UserSettings.MaxSessions), 10))
	form.Set("quota_files", strconv.FormatInt(int64(group.UserSettings.QuotaFiles), 10))
	form.Set("quota_size", strconv.FormatInt(group.UserSettings.QuotaSize, 10))
	form.Set("upload_bandwidth", strconv.FormatInt(group.UserSettings.UploadBandwidth, 10))
	form.Set("download_bandwidth", strconv.FormatInt(group.UserSettings.DownloadBandwidth, 10))
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("expires_in", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("fs_provider", strconv.FormatInt(int64(group.UserSettings.FsConfig.Provider), 10))
	form.Set("sftp_endpoint", group.UserSettings.FsConfig.SFTPConfig.Endpoint)
	form.Set("sftp_username", group.UserSettings.FsConfig.SFTPConfig.Username)
	// sftp_buffer_size and the CSRF token are not set yet; both are added in
	// the following steps
	b, contentType, err := getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b)
	assert.NoError(t, err)
	req.Header.Set("Content-Type", contentType)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("sftp_buffer_size", strconv.FormatInt(group.UserSettings.FsConfig.SFTPConfig.BufferSize, 10))
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b)
	assert.NoError(t, err)
	req.Header.Set("Content-Type", contentType)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b)
	assert.NoError(t, err)
	req.Header.Set("Content-Type", contentType)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorFsCredentialsRequired)
	form.Set("sftp_password", defaultPassword)
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b)
	assert.NoError(t, err)
	req.Header.Set("Content-Type", contentType)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	// delete the group through the REST API, then post the same form again:
	// the web update must now return 404
	req, err = http.NewRequest(http.MethodDelete, path.Join(groupPath, group.Name), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b)
	assert.NoError(t, err)
	req.Header.Set("Content-Type", contentType)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestUpdateWebFolderMock exercises the web admin folder update form
// (POST to webFolderPath/<folder name>) and folder deletion from the web UI.
//
// It checks, in order:
//   - an empty CSRF token returns 403;
//   - a valid update returns 303, and the new mapped path and description are
//     read back through the REST API;
//   - an unparsable query string returns a page containing
//     util.I18nErrorInvalidForm;
//   - an unknown folder returns 404;
//   - a relative mapped_path returns 200 instead of the 303 returned on
//     success;
//   - the update page renders for an existing folder (200) and returns 404 for
//     an unknown one;
//   - DELETE through the web UI returns 403 without the CSRF header, a 302
//     redirect to webLoginPath when the JWT cookie holds an API token, and 200
//     with a web token plus the CSRF header.
func TestUpdateWebFolderMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken)
	assert.NoError(t, err)
	folderName := "vfolderupdate"
	folderDesc := "updated desc"
	folder := vfs.BaseVirtualFolder{
		Name:        folderName,
		MappedPath:  filepath.Join(os.TempDir(), "folderupdate"),
		Description: "dsc",
	}
	_, _, err = httpdtest.AddFolder(folder, http.StatusCreated)
	newMappedPath := folder.MappedPath + "1"
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("mapped_path", newMappedPath)
	form.Set("name", folderName)
	form.Set("description", folderDesc)
	form.Set(csrfFormToken, "")
	b, contentType, err := getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	// read the folder back through the REST API to verify the update
	req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = render.DecodeJSON(rr.Body, &folder)
	assert.NoError(t, err)
	assert.Equal(t, newMappedPath, folder.MappedPath)
	assert.Equal(t, folderName, folder.Name)
	assert.Equal(t, folderDesc, folder.Description)
	// parse form error
	// (the doubled "?" and the invalid percent-escape "%G3" make the query
	// string unparsable)
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName)+"??a=a%B3%A2%G3", &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)
	// updating a folder that does not exist
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName+"1"), &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// a relative mapped path is not accepted: 200 instead of 303
	form.Set("mapped_path", "arelative/path")
	b, contentType, err = getMultipartFormData(form, "", "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// render update folder page
	req, err = http.NewRequest(http.MethodGet, path.Join(webFolderPath, folderName), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, path.Join(webFolderPath, folderName+"1"), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// DELETE without the CSRF header is rejected
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webFolderPath, folderName), nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webFolderPath, folderName), nil)
	setJWTCookieForReq(req, apiToken) // api token is not accepted
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webFolderPath, folderName), nil)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestWebFoldersMock creates two folders through the REST API and checks that
// the folder list returns both with the expected mapped path and description.
// It then verifies that the web folders page and its JSON variant
// (webFoldersPath+jsonAPISuffix) respond with 200, and finally deletes the two
// folders through the REST API.
func TestWebFoldersMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	mappedPath1 := filepath.Join(os.TempDir(), "vfolder1")
	mappedPath2 := filepath.Join(os.TempDir(), "vfolder2")
	folderName1 := filepath.Base(mappedPath1)
	folderName2 := filepath.Base(mappedPath2)
	folderDesc1 := "vfolder1 desc"
	folderDesc2 := "vfolder2 desc"
	folders := []vfs.BaseVirtualFolder{
		{
			Name:        folderName1,
			MappedPath:  mappedPath1,
			Description: folderDesc1,
		},
		{
			Name:        folderName2,
			MappedPath:  mappedPath2,
			Description: folderDesc2,
		},
	}
	for _, folder := range folders {
		folderAsJSON, err := json.Marshal(folder)
		assert.NoError(t, err)
		req, err := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON))
		assert.NoError(t, err)
		setBearerForReq(req, apiToken)
		rr := executeRequest(req)
		checkResponseCode(t, http.StatusCreated, rr)
	}
	req, err := http.NewRequest(http.MethodGet, folderPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var foldersGet []vfs.BaseVirtualFolder
	err = render.DecodeJSON(rr.Body, &foldersGet)
	assert.NoError(t, err)
	// the list may contain other folders: count only the two created above
	numFound := 0
	for _, f := range foldersGet {
		if f.Name == folderName1 {
			assert.Equal(t, mappedPath1, f.MappedPath)
			assert.Equal(t, folderDesc1, f.Description)
			numFound++
		}
		if f.Name == folderName2 {
			assert.Equal(t, mappedPath2, f.MappedPath)
			assert.Equal(t, folderDesc2, f.Description)
			numFound++
		}
	}
	assert.Equal(t, 2, numFound)
	req, err = http.NewRequest(http.MethodGet, webFoldersPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webFoldersPath+jsonAPISuffix, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	for _, folder := range folders {
		req, _ := http.NewRequest(http.MethodDelete, path.Join(folderPath, folder.Name), nil)
		setBearerForReq(req, apiToken)
		rr := executeRequest(req)
		checkResponseCode(t, http.StatusOK, rr)
	}
}

// TestAdminForgotPassword exercises the web admin "forgot password" and
// "reset password" pages (webAdminForgotPwdPath, webAdminResetPwdPath).
//
// NOTE(review): this assumes test setup elsewhere provides an SMTP server on
// 127.0.0.1:3525 and fills the package-level lastResetCode with the emailed
// reset code (presumably via a mock SMTP server). Confirm against the suite
// setup.
//
// Covered cases:
//   - missing CSRF token (403) and empty username;
//   - a disabled admin (302, but no code is generated) and a successful code
//     request;
//   - reset attempts with no password, no code, or a disabled admin, all
//     rendering util.I18nErrorChangePwdGeneric, followed by a successful
//     reset (302);
//   - an unreachable SMTP server (util.I18nErrorPwdResetSendEmail);
//   - an empty SMTP configuration (util.I18nErrorPwdResetGeneric on POST, 404
//     on the GET pages).
//
// At the end it checks that Filters.RequirePasswordChange, set to true at
// creation, is now false.
func TestAdminForgotPassword(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	a.Filters.RequirePasswordChange = true
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	// with SMTP configured, the forgot/reset pages are available
	req, err := http.NewRequest(http.MethodGet, webAdminForgotPwdPath, nil)
	assert.NoError(t, err)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webAdminResetPwdPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webLoginPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("username", "")
	// no csrf token
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	// empty username
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorUsernameRequired)
	lastResetCode = ""
	form.Set("username", altAdminUsername)
	// disable the admin
	admin.Status = 0
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	// a disabled admin still gets a redirect, but no reset code is generated
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Len(t, lastResetCode, 0)
	admin.Status = 1
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// reset password: an empty form (no csrf token) is rejected
	form = make(url.Values)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	// no password
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
	// no code
	form.Set("password", defaultPassword)
	form.Set("confirm_password", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
	// disable the admin
	admin.Status = 0
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	form.Set("code", lastResetCode)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
	admin.Status = 1
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// ok
	// (the form still carries the csrf token, password, confirm_password and
	// code set in the previous steps)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", altAdminUsername)
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// not working smtp server
	smtpCfg = smtp.Config{
		Host:          "127.0.0.1",
		Port:          3526,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	form = make(url.Values)
	form.Set("username", altAdminUsername)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorPwdResetSendEmail)
	// empty SMTP configuration
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	form.Set("username", altAdminUsername)
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorPwdResetGeneric)
	req, err = http.NewRequest(http.MethodGet, webAdminForgotPwdPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, webAdminResetPwdPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserForgotPassword exercises the web client "forgot password" and
// "reset password" pages (webClientForgotPwdPath, webClientResetPwdPath).
//
// NOTE(review): like the admin variant, this relies on an SMTP server on
// 127.0.0.1:3525 and on lastResetCode being filled by test setup elsewhere.
// Confirm against the suite setup.
//
// Besides the missing-CSRF, empty-username, missing-login-cookie,
// missing-password and missing-code cases, it checks that:
//   - a user with sdk.WebClientPasswordResetDisabled gets
//     util.I18nErrorPwdResetForbidded;
//   - an expired user gets a 302 but no reset code;
//   - mismatched passwords render util.I18nErrorChangePwdNoMatch;
//   - a user denied the HTTP protocol cannot complete the reset;
//   - a code is rejected after the user has been removed.
func TestUserForgotPassword(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	u := getTestUser()
	u.Email = "user@test.com"
	u.Filters.WebClient = []string{sdk.WebClientPasswordResetDisabled}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webClientForgotPwdPath, nil)
	assert.NoError(t, err)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientResetPwdPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("username", "")
	// no csrf token
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	// empty username
	form.Set(csrfFormToken, csrfToken)
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorUsernameRequired)
	// user cannot reset the password
	form.Set("username", user.Username)
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorPwdResetForbidded)
	// allow password reset again (replace the filter) and expire the user
	user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Hour))
	user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// user is expired
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Len(t, lastResetCode, 0)
	user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour))
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// no login token
	form = make(url.Values)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	// no password
	form.Set(csrfFormToken, csrfToken)
	form.Set("password", "")
	form.Set("confirm_password", "")
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
	// passwords mismatch
	form.Set("password", altAdminPassword)
	form.Set("code", lastResetCode)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoMatch)
	// no code
	form.Del("code")
	form.Set("confirm_password", altAdminPassword)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
	// Invalid login condition
	form.Set("code", lastResetCode)
	user.Filters.DeniedProtocols = []string{common.ProtocolHTTP}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
	// ok
	user.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	// NOTE(review): this login cookie/CSRF pair comes from webLoginPath (the
	// admin login page), not webClientLoginPath as at the start of the test.
	// The client forgot-password request below is still expected to succeed.
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// without an SMTP configuration the GET pages are no longer available
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientForgotPwdPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, err = http.NewRequest(http.MethodGet, webClientResetPwdPath, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// user does not exist anymore
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("code", lastResetCode)
	form.Set("password", "pwd")
	form.Set("confirm_password", "pwd")
	req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)
}

// TestAPIForgotPassword exercises the REST endpoints
// <adminPath|userPath>/<username>/forgot-password and .../reset-password.
//
// NOTE(review): this relies on an SMTP server on 127.0.0.1:3525 and on
// lastResetCode being filled by test setup elsewhere. Confirm against the
// suite setup.
//
// It checks that:
//   - an account without an email address cannot request a code (400);
//   - invalid JSON is rejected (400);
//   - an admin code cannot reset a user's password;
//   - a code can be used only once ("confirmation code not found");
//   - a user with sdk.WebClientPasswordResetDisabled cannot reset;
//   - requests for missing admins/users return 200 without generating a code;
//   - without an SMTP configuration both forgot-password endpoints return 400
//     ("No SMTP configuration");
//   - a code belonging to a removed admin is rejected.
func TestAPIForgotPassword(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	a.Email = ""
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	// no email, forgot pwd will not work
	lastResetCode = ""
	req, err := http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "Your account does not have an email address")
	admin.Email = "admin@test.com"
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// invalid JSON
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"),
		bytes.NewBuffer([]byte(`{`)))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	resetReq := make(map[string]string)
	resetReq["code"] = lastResetCode
	resetReq["password"] = defaultPassword
	asJSON, err := json.Marshal(resetReq)
	assert.NoError(t, err)
	// a user cannot use an admin code
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "invalid confirmation code")
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the same code cannot be reused
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "confirmation code not found")
	// the admin password was changed by the successful reset
	admin, err = dataprovider.AdminExists(altAdminUsername)
	assert.NoError(t, err)
	match, err := admin.CheckPassword(defaultPassword)
	assert.NoError(t, err)
	assert.True(t, match)
	lastResetCode = ""
	// now the same for a user
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "Your account does not have an email address")
	user.Email = "user@test.com"
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	// invalid JSON
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"),
		bytes.NewBuffer([]byte(`{`)))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	// remove the reset password permission
	user.Filters.WebClient = []string{sdk.WebClientPasswordResetDisabled}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	resetReq["code"] = lastResetCode
	resetReq["password"] = altAdminPassword
	asJSON, err = json.Marshal(resetReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "you are not allowed to reset your password")
	// replace the filter so that password reset is allowed again
	user.Filters.WebClient = []string{sdk.WebClientSharesDisabled}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the same code cannot be reused
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "confirmation code not found")
	// the stored user password now matches the new one
	user, err = dataprovider.UserExists(defaultUsername, "")
	assert.NoError(t, err)
	err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(altAdminPassword))
	assert.NoError(t, err)
	// clear lastResetCode so the assert.Empty checks below detect any new code
	lastResetCode = ""
	// a request for a missing admin/user will be silently ignored
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, "missing-admin", "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Empty(t, lastResetCode)
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, "missing-user", "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Empty(t, lastResetCode)
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	// without an smtp configuration reset password is not available
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "No SMTP configuration")
	req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/forgot-password"), nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "No SMTP configuration")
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// the admin does not exist anymore
	// (lastResetCode still holds the code generated for the admin above)
	resetReq["code"] = lastResetCode
	resetReq["password"] = altAdminPassword
	asJSON, err = json.Marshal(resetReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"),
		bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "unable to associate the confirmation code with an existing admin")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestProviderClosedMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, token)
	assert.NoError(t, err)
	// create a role admin
	role, resp, err := httpdtest.AddRole(getTestRole(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	a.Role = role.Name
	a.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers,
		dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers}
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	altToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
dataprovider.Close() testReq := make(map[string]any) testReq["password"] = redactedSecret asJSON, err := json.Marshal(testReq) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, path.Join(webConfigsPath, "smtp", "test"), bytes.NewBuffer(asJSON)) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("X-CSRF-TOKEN", csrfToken) rr := executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) testReq["base_redirect_url"] = "http://localhost" testReq["client_secret"] = redactedSecret asJSON, err = json.Marshal(testReq) assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("X-CSRF-TOKEN", csrfToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webConfigsPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, webConfigsPath, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) getJSONFolders := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, _ := http.NewRequest(http.MethodGet, webFoldersPath+jsonAPISuffix, nil) setJWTCookieForReq(req, token) executeRequest(req) } getJSONFolders() getJSONGroups := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, _ := http.NewRequest(http.MethodGet, webGroupsPath+jsonAPISuffix, nil) setJWTCookieForReq(req, token) executeRequest(req) } getJSONGroups() getJSONUsers := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, _ := http.NewRequest(http.MethodGet, 
webUsersPath+jsonAPISuffix, nil) setJWTCookieForReq(req, token) executeRequest(req) } getJSONUsers() req, _ = http.NewRequest(http.MethodGet, webUserPath+"/0", nil) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", "test") req, _ = http.NewRequest(http.MethodPost, webUserPath+"/0", strings.NewReader(form.Encode())) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, _ = http.NewRequest(http.MethodGet, path.Join(webAdminPath, defaultTokenAuthUser), nil) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), strings.NewReader(form.Encode())) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) getJSONAdmins := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, _ := http.NewRequest(http.MethodGet, webAdminsPath+jsonAPISuffix, nil) setJWTCookieForReq(req, token) executeRequest(req) } getJSONAdmins() req, _ = http.NewRequest(http.MethodGet, path.Join(webFolderPath, defaultTokenAuthUser), nil) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, _ = http.NewRequest(http.MethodPost, path.Join(webFolderPath, defaultTokenAuthUser), strings.NewReader(form.Encode())) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, _ = http.NewRequest(http.MethodPost, webUserPath, strings.NewReader(form.Encode())) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, _ = http.NewRequest(http.MethodPost, webUserPath, 
strings.NewReader(form.Encode())) setJWTCookieForReq(req, altToken) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodGet, webIPListPath+"/1/a", nil) assert.NoError(t, err) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1/a", nil) assert.NoError(t, err) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) getJSONRoles := func() { defer func() { rcv := recover() assert.Equal(t, http.ErrAbortHandler, rcv) }() req, err := http.NewRequest(http.MethodGet, webAdminRolesPath+jsonAPISuffix, nil) assert.NoError(t, err) setJWTCookieForReq(req, token) executeRequest(req) } getJSONRoles() req, err = http.NewRequest(http.MethodGet, path.Join(webAdminRolePath, role.Name), nil) assert.NoError(t, err) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), strings.NewReader(form.Encode())) assert.NoError(t, err) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.BackupsPath = backupsPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) if config.GetProviderConf().Driver != dataprovider.MemoryDataProviderName { _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveRole(role, http.StatusOK) assert.NoError(t, err) } } func TestWebConnectionsMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, webConnectionsPath, nil) 
setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// DELETE without a CSRF header is rejected
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	// DELETE with an invalid CSRF header is rejected too
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil)
	setJWTCookieForReq(req, token)
	setCSRFHeaderForReq(req, "csrfToken")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	// with a valid CSRF token the request gets past the CSRF check; 404 is
	// expected, presumably because no connection has the id "id"
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, token)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil)
	setJWTCookieForReq(req, token)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
}

// TestGetWebStatusMock renders the web status page while an FTP rate limiter
// is configured. Afterwards it re-initializes common with the configuration
// read before the change.
func TestGetWebStatusMock(t *testing.T) {
	oldConfig := config.GetCommonConfig()
	cfg := config.GetCommonConfig()
	cfg.RateLimitersConfig = []common.RateLimiterConfig{
		{
			Average:   1,
			Period:    1000,
			Burst:     1,
			Type:      1,
			Protocols: []string{common.ProtocolFTP},
		},
	}
	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, webStatusPath, nil)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestStaticFilesMock checks that static and OpenAPI files are served and
// that the bare directory paths redirect to their trailing-slash form.
func TestStaticFilesMock(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, "/static/favicon.png", nil)
	assert.NoError(t, err)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, "/openapi/openapi.yaml", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// "/static" redirects to "/static/", which answers with 404
	req, err = http.NewRequest(http.MethodGet, "/static", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusMovedPermanently, rr)
	location := rr.Header().Get("Location")
	assert.Equal(t, "/static/", location)
	req, err = http.NewRequest(http.MethodGet, location, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// "/openapi" redirects to "/openapi/", which answers with 200
	req, err = http.NewRequest(http.MethodGet, "/openapi", nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusMovedPermanently, rr)
	location = rr.Header().Get("Location")
	assert.Equal(t, "/openapi/", location)
	req, err = http.NewRequest(http.MethodGet, location, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestPasswordChangeRequired checks that MustChangePassword reports true when
// RequirePasswordChange is set or when the last password change is older than
// PasswordExpiration days (2 days here: -49h expires, +49h does not).
func TestPasswordChangeRequired(t *testing.T) {
	user := getTestUser()
	assert.False(t, user.MustChangePassword())
	user.Filters.RequirePasswordChange = true
	assert.True(t, user.MustChangePassword())
	user.Filters.RequirePasswordChange = false
	assert.False(t, user.MustChangePassword())
	user.Filters.PasswordExpiration = 2
	user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now())
	assert.False(t, user.MustChangePassword())
	user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(49 * time.Hour))
	assert.False(t, user.MustChangePassword())
	user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(-49 * time.Hour))
	assert.True(t, user.MustChangePassword())
}

// TestPasswordExpiresIn checks the number of days returned by
// PasswordExpiresIn, including zero and negative values once expired.
func TestPasswordExpiresIn(t *testing.T) {
	user := getTestUser()
	user.Filters.PasswordExpiration = 30
	user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(-15*24*time.Hour + 1*time.Hour))
	res := user.PasswordExpiresIn()
	assert.Equal(t, 15, res)
	user.Filters.PasswordExpiration = 15
	res = user.PasswordExpiresIn()
	assert.Equal(t, 1, res)
	user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(-15*24*time.Hour - 1*time.Hour))
	res = user.PasswordExpiresIn()
	assert.Equal(t, 0, res)
	user.Filters.PasswordExpiration = 5
	res = user.PasswordExpiresIn()
	assert.Equal(t, -10, res)
}

// TestSecondFactorRequirements checks that every protocol listed in
// TwoFactorAuthProtocols requires a second factor until TOTP is enabled with
// that protocol in TOTPConfig.Protocols; protocols not listed never require it.
func TestSecondFactorRequirements(t *testing.T) {
	user := getTestUser()
	user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolSSH}
	assert.True(t, user.MustSetSecondFactor())
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP))
	assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP))
	assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH))

	// enabling TOTP without protocols does not satisfy the requirement
	user.Filters.TOTPConfig.Enabled = true
	assert.True(t, user.MustSetSecondFactor())
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP))
	assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP))
	assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH))

	user.Filters.TOTPConfig.Protocols = []string{common.ProtocolHTTP}
	assert.True(t, user.MustSetSecondFactor())
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP))
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP))
	assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH))

	user.Filters.TOTPConfig.Protocols = []string{common.ProtocolHTTP, common.ProtocolSSH}
	assert.False(t, user.MustSetSecondFactor())
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP))
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP))
	assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH))
}

// TestIsNameValid is a table-driven test for util.IsNameValid covering
// unicode names, control characters, path separators, dot names, trailing
// dots and spaces, and Windows reserved device names (with and without
// extension).
func TestIsNameValid(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected bool
	}{
		{"simple name", "user", true},
		{"alphanumeric", "User123", true},
		{"unicode allowed", "你好", true},
		{"emoji allowed", "user😊", true},
		{"name with dot", "file.txt", true},
		{"name with multiple dots", "archive.tar.gz", true},
		{"control char", "abc\u0001", false},
		{"newline", "abc\n", false},
		{"tab", "abc\t", false},
		{"slash", "user/name", false},
		{"backslash", "user\\name", false},
		{"colon", "user:name", false},
		{"single dot", ".", false},
		{"double dot", "..", false},
		{"dot with suffix allowed", ".hidden", true},
		{"name ending with dot", "file.", false},
		{"name ending with space", "file ", false},
		{"CON", "CON", false},
		{"con lowercase", "con", false},
		{"con with extension", "con.txt", false},
		{"LPT1", "LPT1", false},
		{"lpt1 lowercase", "lpt1", false},
		{"COM5 uppercase", "COM5", false},
		{"com9 with extension", "com9.log", false},
		{"NUL", "NUL", false},
		{"Valid because suffix changes base", "con123", true},
		{"base name split", "aux.pdf", false},
		{"valid long name", "auxiliary", true},
		{"space only", " ", false},
		{"dot inside", "ab.cd.ef", true},
		{"unicode that ends with dot", "你好.", false},
		{"unicode that ends with space", "你好 ", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := util.IsNameValid(tt.input)
			if result != tt.expected {
				t.Errorf("IsNameValid(%q) = %v, expected %v", tt.input, result, tt.expected)
			}
		})
	}
}

// startOIDCMockServer registers, on the default ServeMux, a plain "OK" root
// handler, a static OpenID Connect discovery document for the "sftpgo" realm
// and a "/404" handler. It serves them on oidcMockAddr from a goroutine and
// waits until the address accepts TCP connections; the process exits if the
// listener cannot be started.
func startOIDCMockServer() {
	go func() {
		http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
			fmt.Fprintf(w, "OK\n")
		})
		// static discovery document; every endpoint points at 127.0.0.1:11111
		http.HandleFunc("/auth/realms/sftpgo/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprintf(w, 
`{"issuer":"http://127.0.0.1:11111/auth/realms/sftpgo","authorization_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/auth","token_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token","introspection_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token/introspect","userinfo_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/userinfo","end_session_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/logout","frontchannel_logout_session_supported":true,"frontchannel_logout_supported":true,"jwks_uri":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/certs","check_session_iframe":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/login-status-iframe.html","grant_types_supported":["authorization_code","implicit","refresh_token","password","client_credentials","urn:ietf:params:oauth:grant-type:device_code","urn:openid:params:grant-type:ciba"],"response_types_supported":["code","none","id_token","token","id_token token","code id_token","code token","code id_token 
token"],"subject_types_supported":["public","pairwise"],"id_token_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"id_token_encryption_alg_values_supported":["RSA-OAEP","RSA-OAEP-256","RSA1_5"],"id_token_encryption_enc_values_supported":["A256GCM","A192GCM","A128GCM","A128CBC-HS256","A192CBC-HS384","A256CBC-HS512"],"userinfo_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512","none"],"request_object_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512","none"],"request_object_encryption_alg_values_supported":["RSA-OAEP","RSA-OAEP-256","RSA1_5"],"request_object_encryption_enc_values_supported":["A256GCM","A192GCM","A128GCM","A128CBC-HS256","A192CBC-HS384","A256CBC-HS512"],"response_modes_supported":["query","fragment","form_post","query.jwt","fragment.jwt","form_post.jwt","jwt"],"registration_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/clients-registrations/openid-connect","token_endpoint_auth_methods_supported":["private_key_jwt","client_secret_basic","client_secret_post","tls_client_auth","client_secret_jwt"],"token_endpoint_auth_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"introspection_endpoint_auth_methods_supported":["private_key_jwt","client_secret_basic","client_secret_post","tls_client_auth","client_secret_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"authorization_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"authorization_encryption_alg_values_supported":["RSA-OAEP","RSA-OAEP-256","RSA1_5"],"authorization_encryption_enc_values_supported":["A256
GCM","A192GCM","A128GCM","A128CBC-HS256","A192CBC-HS384","A256CBC-HS512"],"claims_supported":["aud","sub","iss","auth_time","name","given_name","family_name","preferred_username","email","acr"],"claim_types_supported":["normal"],"claims_parameter_supported":true,"scopes_supported":["openid","phone","email","web-origins","offline_access","microprofile-jwt","profile","address","roles"],"request_parameter_supported":true,"request_uri_parameter_supported":true,"require_request_uri_registration":true,"code_challenge_methods_supported":["plain","S256"],"tls_client_certificate_bound_access_tokens":true,"revocation_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/revoke","revocation_endpoint_auth_methods_supported":["private_key_jwt","client_secret_basic","client_secret_post","tls_client_auth","client_secret_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"backchannel_logout_supported":true,"backchannel_logout_session_supported":true,"device_authorization_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/auth/device","backchannel_token_delivery_modes_supported":["poll","ping"],"backchannel_authentication_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/ciba/auth","backchannel_authentication_request_signing_alg_values_supported":["PS384","ES384","RS384","ES256","RS256","ES512","PS256","PS512","RS512"],"require_pushed_authorization_requests":false,"pushed_authorization_request_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/par/request","mtls_endpoint_aliases":{"token_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token","revocation_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/revoke","introspection_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token/intros
pect","device_authorization_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/auth/device","registration_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/clients-registrations/openid-connect","userinfo_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/userinfo","pushed_authorization_request_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/par/request","backchannel_authentication_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/ciba/auth"}}`) }) http.HandleFunc("/404", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, "Not found\n") }) if err := http.ListenAndServe(oidcMockAddr, nil); err != nil { logger.ErrorToConsole("could not start HTTP notification server: %v", err) os.Exit(1) } }() waitTCPListening(oidcMockAddr) } func waitForUsersQuotaScan(t *testing.T, token string) { for { var scans []common.ActiveQuotaScan req, _ := http.NewRequest(http.MethodGet, quotaScanPath, nil) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err := render.DecodeJSON(rr.Body, &scans) if !assert.NoError(t, err, "Error getting active scans") { break } if len(scans) == 0 { break } time.Sleep(100 * time.Millisecond) } } func waitForFoldersQuotaScanPath(t *testing.T, token string) { var scans []common.ActiveVirtualFolderQuotaScan for { req, _ := http.NewRequest(http.MethodGet, quotaScanVFolderPath, nil) setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) err := render.DecodeJSON(rr.Body, &scans) if !assert.NoError(t, err, "Error getting active folders scans") { break } if len(scans) == 0 { break } time.Sleep(100 * time.Millisecond) } } func waitTCPListening(address string) { for { conn, err := net.Dial("tcp", address) if err != nil { logger.WarnToConsole("tcp server %v not listening: %v", address, err) time.Sleep(100 * time.Millisecond) 
continue } logger.InfoToConsole("tcp server %v now listening", address) conn.Close() break } } func startSMTPServer() { go func() { if err := smtpd.ListenAndServe(smtpServerAddr, func(_ net.Addr, _ string, _ []string, data []byte) error { re := regexp.MustCompile(`code is ".*?"`) code := strings.TrimPrefix(string(re.Find(data)), "code is ") lastResetCode = strings.ReplaceAll(code, "\"", "") return nil }, "SFTPGo test", "localhost"); err != nil { logger.ErrorToConsole("could not start SMTP server: %v", err) os.Exit(1) } }() waitTCPListening(smtpServerAddr) } func getTestAdmin() dataprovider.Admin { return dataprovider.Admin{ Username: defaultTokenAuthUser, Password: defaultTokenAuthPass, Status: 1, Permissions: []string{dataprovider.PermAdminAny}, Email: "admin@example.com", Description: "test admin", } } func getTestGroup() dataprovider.Group { return dataprovider.Group{ BaseGroup: sdk.BaseGroup{ Name: "test_group", Description: "test group description", }, } } func getTestRole() dataprovider.Role { return dataprovider.Role{ Name: "test_role", Description: "test role description", } } func getTestUser() dataprovider.User { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: defaultUsername, Password: defaultPassword, HomeDir: filepath.Join(homeBasePath, defaultUsername), Status: 1, Description: "test user", }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = defaultPerms return user } func getTestSFTPUser() dataprovider.User { u := getTestUser() u.Username = u.Username + "_sftp" u.FsConfig.Provider = sdk.SFTPFilesystemProvider u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr u.FsConfig.SFTPConfig.Username = defaultUsername u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) return u } func getUserAsJSON(t *testing.T, user dataprovider.User) []byte { json, err := json.Marshal(user) assert.NoError(t, err) return json } func getCSRFTokenFromInternalPageMock(urlPath, token string) (string, error) { req, err := 
http.NewRequest(http.MethodGet, urlPath, nil) if err != nil { return "", err } req.RequestURI = urlPath setJWTCookieForReq(req, token) rr := executeRequest(req) if rr.Code != http.StatusOK { return "", fmt.Errorf("unexpected status code: %d", rr.Code) } return getCSRFTokenFromBody(rr.Body) } func getCSRFTokenMock(loginURLPath, remoteAddr string) (string, string, error) { req, err := http.NewRequest(http.MethodGet, loginURLPath, nil) if err != nil { return "", "", err } req.RemoteAddr = remoteAddr rr := executeRequest(req) cookie := rr.Header().Get("Set-Cookie") if cookie == "" { return "", "", errors.New("unable to get login cookie") } token, err := getCSRFTokenFromBody(bytes.NewBuffer(rr.Body.Bytes())) return cookie, token, err } func getCSRFToken(url string) (string, string, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return "", "", err } resp, err := httpclient.GetHTTPClient().Do(req) if err != nil { return "", "", err } cookie := resp.Header.Get("Set-Cookie") if cookie == "" { return "", "", errors.New("no login cookie") } defer resp.Body.Close() token, err := getCSRFTokenFromBody(resp.Body) return cookie, token, err } func getCSRFTokenFromBody(body io.Reader) (string, error) { doc, err := html.Parse(body) if err != nil { return "", err } var csrfToken string var f func(*html.Node) f = func(n *html.Node) { if n.Type == html.ElementNode && n.Data == "input" { var name, value string for _, attr := range n.Attr { if attr.Key == "value" { value = attr.Val } if attr.Key == "name" { name = attr.Val } } if name == csrfFormToken { csrfToken = value return } } for c := n.FirstChild; c != nil; c = c.NextSibling { f(c) } } f(doc) if csrfToken == "" { return "", errors.New("CSRF token not found") } return csrfToken, nil } func getLoginForm(username, password, csrfToken string) url.Values { form := make(url.Values) form.Set("username", username) form.Set("password", password) form.Set(csrfFormToken, csrfToken) return form } func 
setCSRFHeaderForReq(req *http.Request, csrfToken string) { req.Header.Set("X-CSRF-TOKEN", csrfToken) } func setBearerForReq(req *http.Request, jwtToken string) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", jwtToken)) } func setAPIKeyForReq(req *http.Request, apiKey, username string) { if username != "" { apiKey += "." + username } req.Header.Set("X-SFTPGO-API-KEY", apiKey) } func setLoginCookie(req *http.Request, cookie string) { req.Header.Set("Cookie", cookie) } func setJWTCookieForReq(req *http.Request, jwtToken string) { req.RemoteAddr = defaultRemoteAddr req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", jwtToken)) } func getJWTAPITokenFromTestServer(username, password string) (string, error) { return getJWTAPITokenFromTestServerWithPasscode(username, password, "") } func getJWTAPITokenFromTestServerWithPasscode(username, password, passcode string) (string, error) { req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) req.SetBasicAuth(username, password) if passcode != "" { req.Header.Set("X-SFTPGO-OTP", passcode) } rr := executeRequest(req) if rr.Code != http.StatusOK { return "", fmt.Errorf("unexpected status code %v", rr.Code) } responseHolder := make(map[string]any) err := render.DecodeJSON(rr.Body, &responseHolder) if err != nil { return "", err } return responseHolder["access_token"].(string), nil } func getJWTAPIUserTokenFromTestServer(username, password string) (string, error) { req, _ := http.NewRequest(http.MethodGet, userTokenPath, nil) req.SetBasicAuth(username, password) rr := executeRequest(req) if rr.Code != http.StatusOK { return "", fmt.Errorf("unexpected status code %v", rr.Code) } responseHolder := make(map[string]any) err := render.DecodeJSON(rr.Body, &responseHolder) if err != nil { return "", err } return responseHolder["access_token"].(string), nil } func getJWTWebToken(username, password string) (string, error) { loginCookie, csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) if err != nil { return "", err } form 
:= getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, httpBaseURL+webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") client := &http.Client{ Timeout: 10 * time.Second, CheckRedirect: func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse }, } resp, err := client.Do(req) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusFound { return "", fmt.Errorf("unexpected status code %v", resp.StatusCode) } cookie := resp.Header.Get("Set-Cookie") if strings.HasPrefix(cookie, "jwt=") { return cookie[4:], nil } return "", errors.New("no cookie found") } func getCookieFromResponse(rr *httptest.ResponseRecorder) (string, error) { cookie := strings.Split(rr.Header().Get("Set-Cookie"), ";") if strings.HasPrefix(cookie[0], "jwt=") { return cookie[0][4:], nil } return "", errors.New("no cookie found") } func getJWTWebClientTokenFromTestServerWithAddr(username, password, remoteAddr string) (string, error) { loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, remoteAddr) if err != nil { return "", err } form := getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = remoteAddr setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) if rr.Code != http.StatusFound { return "", fmt.Errorf("unexpected status code %v", rr) } return getCookieFromResponse(rr) } func getJWTWebClientTokenFromTestServer(username, password string) (string, error) { loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) if err != nil { return "", err } form := getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, 
bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr := executeRequest(req)
	if rr.Code != http.StatusFound {
		return "", fmt.Errorf("unexpected status code %v", rr.Code)
	}
	return getCookieFromResponse(rr)
}

// getJWTWebTokenFromTestServer logs in to the web admin through the in-process
// handler from defaultRemoteAddr and returns the "jwt" cookie value.
func getJWTWebTokenFromTestServer(username, password string) (string, error) {
	loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	if err != nil {
		return "", err
	}
	form := getLoginForm(username, password, csrfToken)
	req, _ := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr := executeRequest(req)
	if rr.Code != http.StatusFound {
		return "", fmt.Errorf("unexpected status code %v", rr.Code)
	}
	return getCookieFromResponse(rr)
}

// executeRequest serves req with the test server handler and returns the
// recorded response.
func executeRequest(req *http.Request) *httptest.ResponseRecorder {
	rr := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	return rr
}

// checkResponseCode asserts the response status, using the body as the
// failure message.
func checkResponseCode(t *testing.T, expected int, rr *httptest.ResponseRecorder) {
	assert.Equal(t, expected, rr.Code, rr.Body.String())
}

// getSftpClient opens an SSH connection to sftpServerAddr as user (falling
// back to defaultPassword when the user has no password) and starts an SFTP
// client on it. When the SFTP client cannot be created the SSH connection is
// closed before returning.
func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) {
	password := user.Password
	if password == "" {
		password = defaultPassword
	}
	// named clientConfig so the config package is not shadowed
	clientConfig := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         5 * time.Second,
		Auth:            []ssh.AuthMethod{ssh.Password(password)},
	}
	conn, err := ssh.Dial("tcp", sftpServerAddr, clientConfig)
	if err != nil {
		return conn, nil, err
	}
	sftpClient, err := sftp.NewClient(conn)
	if err != nil {
		conn.Close()
	}
	return conn, sftpClient, err
}

// createTestFile writes size random bytes to path, creating the parent
// directory first when it does not exist.
func createTestFile(path string, size int64) error {
	baseDir := filepath.Dir(path)
	if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) {
		err = os.MkdirAll(baseDir, os.ModePerm)
		if err != nil {
			return err
		}
	}
	content := make([]byte, size)
	if size > 0 {
		_, err := rand.Read(content)
		if err != nil {
			return err
		}
	}
	return os.WriteFile(path, content, os.ModePerm)
}

// getExitCodeScriptContent returns a /bin/sh script that exits with exitCode.
func getExitCodeScriptContent(exitCode int) []byte {
	content := []byte("#!/bin/sh\n\n")
	content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...)
	return content
}

// getMultipartFormData builds a multipart body containing every value in
// values as a form field and, when both fileFieldName and filePath are set,
// the content of filePath as a file part. It returns the body, its content
// type and the first error encountered (the content type is empty on early
// errors).
func getMultipartFormData(values url.Values, fileFieldName, filePath string) (bytes.Buffer, string, error) {
	var b bytes.Buffer
	w := multipart.NewWriter(&b)
	for k, v := range values {
		for _, s := range v {
			if err := w.WriteField(k, s); err != nil {
				return b, "", err
			}
		}
	}
	if len(fileFieldName) > 0 && len(filePath) > 0 {
		fw, err := w.CreateFormFile(fileFieldName, filepath.Base(filePath))
		if err != nil {
			return b, "", err
		}
		f, err := os.Open(filePath)
		if err != nil {
			return b, "", err
		}
		defer f.Close()
		if _, err = io.Copy(fw, f); err != nil {
			return b, "", err
		}
	}
	err := w.Close()
	return b, w.FormDataContentType(), err
}

// generateTOTPPasscode returns the current six digit SHA1 TOTP code for secret
// using a 30 second period.
func generateTOTPPasscode(secret string) (string, error) {
	return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: otp.AlgorithmSHA1,
	})
}

// isDbDefenderSupported reports whether the configured data provider driver
// supports the database-backed defender.
func isDbDefenderSupported() bool {
	// SQLite shares the implementation with the other SQL-based providers, but
	// it makes no sense to use it outside test cases
	switch dataprovider.GetProviderStatus().Driver {
	case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
		dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
		return true
	default:
		return false
	}
}

// createTestPNG writes to name a width x height PNG filled with imgColor.
// Errors from encoding and from closing the written file are both returned.
func createTestPNG(name string, width, height int, imgColor color.Color) error {
	img := image.NewRGBA(image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: width, Y: height}})
	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			img.Set(x, y, imgColor)
		}
	}
	f, err := os.Create(name)
	if err != nil {
		return err
	}
	if err := png.Encode(f, img); err != nil {
		f.Close()
		return err
	}
	// a failed Close can mean the data was not fully written: report it
	return f.Close()
}

// BenchmarkSecretDecryption measures decrypting a clone of an encrypted
// secret; the one-time encryption is excluded from the timing.
func BenchmarkSecretDecryption(b *testing.B) {
	s := kms.NewPlainSecret("test data")
	s.SetAdditionalData("username")
	err := s.Encrypt()
	require.NoError(b, err)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err = s.Clone().Decrypt()
		require.NoError(b, err)
	}
}
================================================ FILE: internal/httpd/internal_test.go ================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd import ( "bytes" "context" "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "html/template" "io" "io/fs" "net/http" "net/http/httptest" "net/url" "os" "path" "path/filepath" "runtime" "strings" "testing" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-jose/go-jose/v4" josejwt "github.com/go-jose/go-jose/v4/jwt" "github.com/klauspost/compress/zip" "github.com/rs/xid" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" "github.com/sftpgo/sdk/plugin/notifier" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/html" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( httpdCert = `-----BEGIN CERTIFICATE----- MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY /8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI dV4vKmHUzwK/eIx+8Ay3neE= -----END CERTIFICATE-----` httpdKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 
UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= -----END EC PRIVATE KEY-----` caCRT = `-----BEGIN CERTIFICATE----- MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U 4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY /tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w kK8u1YiiVenmoQ== -----END CERTIFICATE-----` caKey = `-----BEGIN RSA PRIVATE 
KEY----- MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj 7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY 00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz +465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc 9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM 0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN +jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 /hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn 
v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz 1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN 38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ 2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== -----END RSA PRIVATE KEY-----` caCRL = `-----BEGIN X509 CRL----- MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV +jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE 
ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn w7bXJGfJcIMrrKs= -----END X509 CRL-----` client1Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au MU3Bo0A= -----END CERTIFICATE-----` client1Key = `-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi 
85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== -----END RSA PRIVATE KEY-----` // client 2 crt is revoked client2Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ 
cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB 4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv /ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR E0+r2+4= -----END CERTIFICATE-----` client2Key = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu 4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 
4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl /owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz 3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D 8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl -----END RSA PRIVATE KEY-----`
	defaultAdminUsername = "admin"
	defaultAdminPass     = "password"
	// NOTE(review): "defeault" is a typo; renaming it presumably requires
	// updating every reference in the package — verify before changing.
	defeaultUsername = "test_user"
)

var (
	configDir = filepath.Join(".", "..", "..")
)

// failingWriter provides the http.ResponseWriter method set; every Write
// fails with a "write error", so tests can exercise response-write errors.
type failingWriter struct {
}

func (r *failingWriter) Write(_ []byte) (n int, err error) {
	return 0, errors.New("write error")
}

func (r *failingWriter) WriteHeader(_ int) {}

func (r *failingWriter) Header() http.Header {
	return make(http.Header)
}

// failingJoseSigner provides the go-jose Signer method set; Sign always
// fails with "sign test error", so tests can exercise token-signing errors.
type failingJoseSigner struct{}

func (s *failingJoseSigner) Sign(_ []byte) (*jose.JSONWebSignature, error) {
	return nil, errors.New("sign test error")
}

func (s *failingJoseSigner) Options() jose.SignerOptions {
	return jose.SignerOptions{}
}

func TestShouldBind(t *testing.T) {
	c := Conf{
		Bindings: []Binding{
			{
				Port: 10000,
			},
		},
	}
	require.False(t, c.ShouldBind())
	c.Bindings[0].EnableRESTAPI = true
	require.True(t, c.ShouldBind())

	c.Bindings[0].Port = 0
	require.False(t, c.ShouldBind())

	if runtime.GOOS != osWindows {
		c.Bindings[0].Address = "/absolute/path"
		require.True(t, c.ShouldBind())
	}
}

func TestBrandingValidation(t *testing.T) {
	b := Binding{
		Branding: Branding{
			WebAdmin: UIBranding{
				LogoPath:   "path1",
				DefaultCSS: []string{"my.css"},
			},
			WebClient: UIBranding{
				FaviconPath:    "favicon1.ico",
				DisclaimerPath: "../path2",
				ExtraCSS:       []string{"1.css"},
			},
		},
	}
	b.checkBranding()
	assert.Equal(t, "/favicon.png", b.Branding.WebAdmin.FaviconPath)
	assert.Equal(t, "/path1", b.Branding.WebAdmin.LogoPath)
	assert.Equal(t, []string{"/my.css"}, b.Branding.WebAdmin.DefaultCSS)
	assert.Len(t, b.Branding.WebAdmin.ExtraCSS, 0)
	assert.Equal(t, "/favicon1.ico", b.Branding.WebClient.FaviconPath)
	assert.Equal(t, path.Join(webStaticFilesPath, "/path2"), b.Branding.WebClient.DisclaimerPath)
	if assert.Len(t, b.Branding.WebClient.ExtraCSS, 1) {
		assert.Equal(t, "/1.css", b.Branding.WebClient.ExtraCSS[0])
	}
	b.Branding.WebAdmin.DisclaimerPath = "https://example.com"
	b.checkBranding()
	assert.Equal(t, "https://example.com", b.Branding.WebAdmin.DisclaimerPath)
}

func TestRedactedConf(t *testing.T) {
	c := Conf{
		SigningPassphrase: "passphrase",
		Setup: SetupConfig{
			InstallationCode: "123",
		},
	}
	redactedField := "[redacted]"
	redactedConf := c.getRedacted()
	assert.Equal(t, redactedField, redactedConf.SigningPassphrase)
	assert.Equal(t, redactedField, redactedConf.Setup.InstallationCode)
	assert.NotEqual(t, c.SigningPassphrase, redactedConf.SigningPassphrase)
	assert.NotEqual(t, c.Setup.InstallationCode, redactedConf.Setup.InstallationCode)
}

func TestGetRespStatus(t *testing.T) {
	var err error
	err = util.NewMethodDisabledError("")
	respStatus := getRespStatus(err)
	assert.Equal(t, http.StatusForbidden, respStatus)
	err = fmt.Errorf("generic error")
	respStatus = getRespStatus(err)
	assert.Equal(t, http.StatusInternalServerError, respStatus)
	respStatus = getRespStatus(plugin.ErrNoSearcher)
	assert.Equal(t, http.StatusNotImplemented, respStatus)
}

func TestMappedStatusCode(t *testing.T) {
	err := os.ErrPermission
	code := getMappedStatusCode(err)
	assert.Equal(t, http.StatusForbidden, code)
	err = os.ErrNotExist
	code = getMappedStatusCode(err)
	assert.Equal(t, http.StatusNotFound, code)
	err = common.ErrQuotaExceeded
	code = getMappedStatusCode(err)
	assert.Equal(t, http.StatusRequestEntityTooLarge, code)
	err = os.ErrClosed
	code = getMappedStatusCode(err)
	assert.Equal(t, http.StatusInternalServerError, code)
	err = &http.MaxBytesError{}
	code = getMappedStatusCode(err)
	assert.Equal(t, http.StatusRequestEntityTooLarge, code)
}

func TestGCSWebInvalidFormFile(t *testing.T) {
	form := make(url.Values)
	form.Set("username", "test_username")
	form.Set("fs_provider", "2")
	req, _ := http.NewRequest(http.MethodPost, webUserPath, strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	err := req.ParseForm()
	assert.NoError(t, err)
	_, err = getFsConfigFromPostFields(req)
	assert.EqualError(t, err, http.ErrNotMultipart.Error())
}

func TestBrandingInvalidFormFile(t *testing.T) {
	form := make(url.Values)
	req, _ := http.NewRequest(http.MethodPost, webConfigsPath, strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	err := req.ParseForm()
	assert.NoError(t, err)
	_, err = getBrandingConfigFromPostFields(req, &dataprovider.BrandingConfigs{})
	assert.EqualError(t, err, http.ErrNotMultipart.Error())
}

func TestTokenDuration(t *testing.T) {
	// updateTokensDuration and the direct assignment below mutate package-level
	// state. Restore the durations this test touches so that tests running
	// later in the package do not inherit the modified values.
	origAPI, origCookie, origCSRF, origShare := apiTokenDuration, cookieTokenDuration,
		csrfTokenDuration, shareTokenDuration
	t.Cleanup(func() {
		apiTokenDuration = origAPI
		cookieTokenDuration = origCookie
		csrfTokenDuration = origCSRF
		shareTokenDuration = origShare
	})

	assert.Equal(t, shareTokenDuration, getTokenDuration(tokenAudienceWebShare))
	assert.Equal(t, apiTokenDuration, getTokenDuration(tokenAudienceAPI))
	assert.Equal(t, apiTokenDuration, getTokenDuration(tokenAudienceAPIUser))
	assert.Equal(t, cookieTokenDuration, getTokenDuration(tokenAudienceWebAdmin))
	assert.Equal(t, csrfTokenDuration, getTokenDuration(tokenAudienceCSRF))
	assert.Equal(t, 20*time.Minute, getTokenDuration(""))

	updateTokensDuration(30, 660, 360)
	assert.Equal(t, 30*time.Minute, apiTokenDuration)
	assert.Equal(t, 11*time.Hour, cookieTokenDuration)
	assert.Equal(t, 11*time.Hour, csrfTokenDuration)
	assert.Equal(t, 6*time.Hour, shareTokenDuration)
	assert.Equal(t, 11*time.Hour, getMaxCookieDuration())
	csrfTokenDuration = 1 * time.Hour
	assert.Equal(t, 11*time.Hour, getMaxCookieDuration())
}

func TestVerifyCSRFToken(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
	require.NoError(t, err)
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, fs.ErrPermission))

	rr := httptest.NewRecorder()
	tokenString := createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)
	assert.NotEmpty(t, tokenString)

	claims, err := jwt.VerifyToken(server.csrfTokenAuth, tokenString)
	require.NoError(t, err)
	assert.Empty(t, claims.Ref)

	req.Form = url.Values{}
	req.Form.Set(csrfFormToken, tokenString)
	err = verifyCSRFToken(req, server.csrfTokenAuth)
	assert.ErrorIs(t, err, fs.ErrPermission)

	req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
	require.NoError(t, err)
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil))
	req.Form = url.Values{}
	req.Form.Set(csrfFormToken, tokenString)
	err = verifyCSRFToken(req, server.csrfTokenAuth)
	assert.ErrorContains(t, err, "unexpected form token")

	claims = jwt.NewClaims(tokenAudienceCSRF, "", getTokenDuration(tokenAudienceCSRF))
	tokenString, err = josejwt.Signed(server.csrfTokenAuth.Signer()).Claims(claims).Serialize()
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
	require.NoError(t, err)
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil))
	req.Form = url.Values{}
	req.Form.Set(csrfFormToken, tokenString)
	err = verifyCSRFToken(req, server.csrfTokenAuth)
	assert.ErrorContains(t, err, "the form token is not valid")
}

func TestInvalidToken(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)
	admin := dataprovider.Admin{
		Username: "admin",
	}
	errFake := errors.New("fake error")
	asJSON, err := json.Marshal(admin)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodPut, path.Join(adminPath,
admin.Username), bytes.NewBuffer(asJSON)) rctx := chi.NewRouteContext() rctx.URLParams.Add("username", admin.Username) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req = req.WithContext(context.WithValue(req.Context(), jwt.ErrorCtxKey, errFake)) rr := httptest.NewRecorder() updateAdmin(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) rr = httptest.NewRecorder() deleteAdmin(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) adminPwd := pwdChange{ CurrentPassword: "old", NewPassword: "new", } asJSON, err = json.Marshal(adminPwd) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPut, "", bytes.NewBuffer(asJSON)) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req = req.WithContext(context.WithValue(req.Context(), jwt.ErrorCtxKey, errFake)) rr = httptest.NewRecorder() changeAdminPassword(rr, req) assert.Equal(t, http.StatusInternalServerError, rr.Code) adm := getAdminFromToken(req) assert.Empty(t, adm.Username) rr = httptest.NewRecorder() readUserFolder(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getUserFile(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getUserFilesAsZipStream(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getShares(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getShareByID(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addShare(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() 
updateShare(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteShare(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() generateTOTPSecret(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() saveTOTPConfig(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getRecoveryCodes(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() generateRecoveryCodes(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getUserProfile(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateUserProfile(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getWebTask(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getAdminProfile(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateAdminProfile(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() loadData(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() loadDataFromRequest(rr, req) assert.Equal(t, 
http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addUser(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() disableUser2FA(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateUser(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteUser(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getActiveConnections(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() handleCloseConnection(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() server.handleWebRestore(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebAddUserPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateUserPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebTemplateFolderPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebTemplateUserPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() getAllAdmins(rr, 
req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() getAllUsers(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() addFolder(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateFolder(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getFolderByName(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteFolder(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() server.handleWebAddFolderPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateFolderPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebGetConnections(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebConfigsPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() addAdmin(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() disableAdmin2FA(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addAPIKey(rr, req) 
assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateAPIKey(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteAPIKey(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addGroup(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateGroup(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getGroupByName(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteGroup(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addEventAction(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getEventActionByName(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateEventAction(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteEventAction(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getEventRuleByName(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addEventRule(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, 
rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateEventRule(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteEventRule(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getUsersQuotaScans(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateUserTransferQuotaUsage(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() doUpdateUserQuotaUsage(rr, req, "", quotaUsage{}) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() doStartUserQuotaScan(rr, req, "") assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getRetentionChecks(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addRole(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateRole(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteRole(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getUsers(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() getUserByUsername(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), 
"Invalid token claims") rr = httptest.NewRecorder() searchFsEvents(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() searchProviderEvents(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() searchLogEvents(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() addIPListEntry(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() updateIPListEntry(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() deleteIPListEntry(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") rr = httptest.NewRecorder() server.handleGetWebUsers(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateUserGet(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateRolePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebAddRolePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebAddAdminPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebAddGroupPost(rr, req) assert.Equal(t, http.StatusForbidden, 
rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateGroupPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebAddEventActionPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateEventActionPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebAddEventRulePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateEventRulePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebUpdateIPListEntryPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() server.handleWebClientTwoFactorRecoveryPost(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() server.handleWebClientTwoFactorPost(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() server.handleWebAdminTwoFactorRecoveryPost(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() server.handleWebAdminTwoFactorPost(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() server.handleWebUpdateIPListEntryPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) form := make(url.Values) req, _ = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rctx = chi.NewRouteContext() rctx.URLParams.Add("type", "1") req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) rr = httptest.NewRecorder() server.handleWebAddIPListEntryPost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code, rr.Body.String()) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) } func TestTokenSignatureValidation(t *testing.T) { tokenValidationMode = 0 server := httpdServer{ binding: Binding{ Address: "", Port: 8080, EnableWebAdmin: true, EnableWebClient: true, EnableRESTAPI: true, }, enableWebAdmin: true, enableWebClient: true, enableRESTAPI: true, } err := server.initializeRouter() require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, tokenPath, nil) require.NoError(t, err) req.SetBasicAuth(defaultAdminUsername, defaultAdminPass) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) var resp map[string]any err = json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(t, err) accessToken := resp["access_token"] require.NotEmpty(t, accessToken) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, versionPath, nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) // change the token validation mode tokenValidationMode = 2 rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, versionPath, nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) // Now update the admin admin, err := dataprovider.AdminExists(defaultAdminUsername) assert.NoError(t, err) err = dataprovider.UpdateAdmin(&admin, "", "", "") 
assert.NoError(t, err) // token validation mode is 0, the old token is still valid tokenValidationMode = 0 rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, versionPath, nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) // change the token validation mode tokenValidationMode = 2 rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, versionPath, nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) // the token is invalidated, changing the validation mode has no effect tokenValidationMode = 0 rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, versionPath, nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) userPwd := "pwd" user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: defeaultUsername, Password: userPwd, HomeDir: filepath.Join(os.TempDir(), defeaultUsername), Status: 1, }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) defer func() { dataprovider.DeleteUser(defeaultUsername, "", "", "") //nolint:errcheck }() tokenValidationMode = 2 req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) require.NoError(t, err) rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) loginCookie := strings.Split(rr.Header().Get("Set-Cookie"), ";")[0] assert.NotEmpty(t, loginCookie) csrfToken, err := getCSRFTokenFromBody(rr.Body) assert.NoError(t, err) assert.NotEmpty(t, csrfToken) // Now login form := 
make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", defeaultUsername) form.Set("password", userPwd) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusFound, rr.Code) userCookie := strings.Split(rr.Header().Get("Set-Cookie"), ";")[0] assert.NotEmpty(t, userCookie) // Test a WebClient page and a JSON API rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) require.NoError(t, err) req.Header.Set("Cookie", userCookie) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil) require.NoError(t, err) req.Header.Set("Cookie", userCookie) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) csrfToken, err = getCSRFTokenFromBody(rr.Body) assert.NoError(t, err) assert.NotEmpty(t, csrfToken) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, webClientFilePath+"?path=missing.txt", nil) require.NoError(t, err) req.Header.Set("Cookie", userCookie) req.Header.Set(csrfHeaderToken, csrfToken) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) tokenValidationMode = 0 err = dataprovider.DeleteUser(defeaultUsername, "", "", "") assert.NoError(t, err) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, webClientFilePath+"?path=missing.txt", nil) require.NoError(t, err) req.Header.Set("Cookie", userCookie) req.Header.Set(csrfHeaderToken, csrfToken) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) tokenValidationMode = 2 rr = httptest.NewRecorder() req, err = 
http.NewRequest(http.MethodGet, webClientFilePath+"?path=missing.txt", nil) require.NoError(t, err) req.Header.Set("Cookie", userCookie) req.Header.Set(csrfHeaderToken, csrfToken) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusFound, rr.Code) tokenValidationMode = 0 } func TestUpdateWebAdminInvalidClaims(t *testing.T) { server := httpdServer{} err := server.initializeRouter() require.NoError(t, err) rr := httptest.NewRecorder() admin := dataprovider.Admin{ Username: "", Password: "password", } c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, } c.Subject = admin.GetSignature() token, err := server.tokenAuth.SignWithParams(c, tokenAudienceWebAdmin, "", 10*time.Minute) assert.NoError(t, err) resp := c.BuildTokenResponse(token) req, err := http.NewRequest(http.MethodGet, webAdminPath, nil) assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) parsedToken, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx := req.Context() ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form := make(url.Values) form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)) form.Set("status", "1") form.Set("default_users_expiration", "30") req, err = http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) rctx := chi.NewRouteContext() rctx.URLParams.Add("username", "admin") req = req.WithContext(ctx) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) server.handleWebUpdateAdminPost(rr, req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) } func TestUpdateSMTPSecrets(t *testing.T) { currentConfigs := &dataprovider.SMTPConfigs{ 
OAuth2: dataprovider.SMTPOAuth2{
			ClientSecret: kms.NewPlainSecret("client secret"),
			RefreshToken: kms.NewPlainSecret("refresh token"),
		},
	}
	redactedClientSecret := kms.NewPlainSecret("secret")
	redactedRefreshToken := kms.NewPlainSecret("token")
	redactedClientSecret.SetStatus(sdkkms.SecretStatusRedacted)
	redactedRefreshToken.SetStatus(sdkkms.SecretStatusRedacted)
	newConfigs := &dataprovider.SMTPConfigs{
		Password: kms.NewPlainSecret("pwd"),
		OAuth2: dataprovider.SMTPOAuth2{
			ClientSecret: redactedClientSecret,
			RefreshToken: redactedRefreshToken,
		},
	}
	updateSMTPSecrets(newConfigs, currentConfigs)
	// Redacted OAuth2 secrets are replaced with the current values; the
	// non-redacted password in newConfigs is kept and currentConfigs still
	// has no password.
	assert.Nil(t, currentConfigs.Password)
	assert.NotNil(t, newConfigs.Password)
	assert.Equal(t, currentConfigs.OAuth2.ClientSecret, newConfigs.OAuth2.ClientSecret)
	assert.Equal(t, currentConfigs.OAuth2.RefreshToken, newConfigs.OAuth2.RefreshToken)
	// Non-redacted OAuth2 secrets are left untouched.
	clientSecret := kms.NewPlainSecret("plain secret")
	refreshToken := kms.NewPlainSecret("plain token")
	newConfigs = &dataprovider.SMTPConfigs{
		Password: kms.NewPlainSecret("pwd"),
		OAuth2: dataprovider.SMTPOAuth2{
			ClientSecret: clientSecret,
			RefreshToken: refreshToken,
		},
	}
	updateSMTPSecrets(newConfigs, currentConfigs)
	assert.Equal(t, clientSecret, newConfigs.OAuth2.ClientSecret)
	assert.Equal(t, refreshToken, newConfigs.OAuth2.RefreshToken)
}

// TestOAuth2Redirect exercises handleOAuth2TokenRedirect: an unparsable state
// yields 400 with util.I18nOAuth2ErrorTitle, while a state token created by
// createOAuth2Token for the request's IP yields 500 with
// util.I18nOAuth2ErrorValidateState.
// NOTE(review): the 500 presumably comes from no pending OAuth2 request being
// registered for the random xid used as state — confirm in the handler.
func TestOAuth2Redirect(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state=invalid", nil)
	assert.NoError(t, err)
	server.handleOAuth2TokenRedirect(rr, req)
	assert.Equal(t, http.StatusBadRequest, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorTitle)

	ip := "127.1.1.4"
	tokenString := createOAuth2Token(server.csrfTokenAuth, xid.New().String(), ip)
	rr = httptest.NewRecorder()
	req, err = http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state="+tokenString, nil) //nolint:goconst
	assert.NoError(t, err)
	// the state token is bound to this IP, so the request must come from it
	req.RemoteAddr = ip
	server.handleOAuth2TokenRedirect(rr, req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorValidateState)
}

// TestOAuth2Token covers verifyOAuth2Token rejection cases (unparsable token,
// wrong audience, mismatched IP, missing jti) and the success case. It then
// makes token creation fail via failingJoseSigner and checks that
// handleSMTPOAuth2TokenRequestPost answers 500 "unable to create state token".
func TestOAuth2Token(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)
	// invalid token
	_, err = verifyOAuth2Token(server.csrfTokenAuth, "token", "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to verify OAuth2 state")
	}
	// bad audience
	claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI))
	tokenString, err := server.csrfTokenAuth.Sign(claims)
	assert.NoError(t, err)
	_, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid OAuth2 state")
	}
	// bad IP
	tokenString = createOAuth2Token(server.csrfTokenAuth, "state", "127.1.1.1")
	_, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.2")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid OAuth2 state")
	}
	// ok
	state := xid.New().String()
	tokenString = createOAuth2Token(server.csrfTokenAuth, state, "127.1.1.3")
	s, err := verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.3")
	assert.NoError(t, err)
	assert.Equal(t, state, s)
	// no jti: sign the claims directly instead of going through createOAuth2Token
	claims = jwt.NewClaims(tokenAudienceOAuth2, "127.1.1.4", getTokenDuration(tokenAudienceOAuth2))
	tokenString, err = josejwt.Signed(server.csrfTokenAuth.Signer()).Claims(claims).Serialize()
	assert.NoError(t, err)
	_, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.4")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid OAuth2 state")
	}
	// encode error
	server.csrfTokenAuth.SetSigner(&failingJoseSigner{})
	tokenString = createOAuth2Token(server.csrfTokenAuth, xid.New().String(), "")
	assert.Empty(t, tokenString)
	rr := httptest.NewRecorder()
	testReq := make(map[string]any)
	testReq["base_redirect_url"] = "http://localhost:8082"
	asJSON, err := json.Marshal(testReq)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	server.handleSMTPOAuth2TokenRequestPost(rr, req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code)
	assert.Contains(t, rr.Body.String(), "unable to create state token")
}

// TestCSRFToken covers verifyCSRFToken (form token) and the verifyCSRFHeader
// middleware (header token). A missing token, a wrong audience and a
// mismatched remote IP are all rejected. With a failing signer,
// createCSRFToken returns an empty string and createLoginCookie sets no cookie.
func TestCSRFToken(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)
	// invalid token
	req := &http.Request{}
	err = verifyCSRFToken(req, server.csrfTokenAuth)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to verify form token")
	}
	// bad audience
	claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI))
	tokenString, err := server.csrfTokenAuth.Sign(claims)
	assert.NoError(t, err)
	values := url.Values{}
	values.Set(csrfFormToken, tokenString)
	req.Form = values
	err = verifyCSRFToken(req, server.csrfTokenAuth)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "form token is not valid")
	}
	// bad IP
	req.RemoteAddr = "127.1.1.1"
	tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)
	values.Set(csrfFormToken, tokenString)
	req.Form = values
	req.RemoteAddr = "127.1.1.2"
	err = verifyCSRFToken(req, server.csrfTokenAuth)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "form token is not valid")
	}

	// header-based checks through the verifyCSRFHeader middleware
	claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI))
	tokenString, err = server.csrfTokenAuth.Sign(claims)
	assert.NoError(t, err)
	assert.NotEmpty(t, tokenString)

	r, err := GetHTTPRouter(Binding{
		Address:         "",
		Port:            8080,
		EnableWebAdmin:  true,
		EnableWebClient: true,
		EnableRESTAPI:   true,
		RenderOpenAPI:   true,
	})
	assert.NoError(t, err)
	fn := server.verifyCSRFHeader(r)
	// no CSRF header at all
	rr := httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil)
	fn.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	// invalid audience
	req.Header.Set(csrfHeaderToken, tokenString)
	rr = httptest.NewRecorder()
	fn.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), "the token is not valid")
	// invalid IP
	tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)
	req.Header.Set(csrfHeaderToken, tokenString)
	req.RemoteAddr = "172.16.1.2"
	rr = httptest.NewRecorder()
	fn.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), "the token is not valid")

	// signing failures: no token and no login cookie are produced
	csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32))
	require.NoError(t, err)
	csrfTokenAuth.SetSigner(&failingJoseSigner{})
	tokenString = createCSRFToken(httptest.NewRecorder(), req, csrfTokenAuth, "", webBaseAdminPath)
	assert.Empty(t, tokenString)
	rr = httptest.NewRecorder()
	createLoginCookie(rr, req, csrfTokenAuth, "", webBaseAdminPath, req.RemoteAddr)
	assert.Empty(t, rr.Header().Get("Set-Cookie"))
}

// TestCreateShareCookieError posts a share login form with the correct share
// password to a server whose tokenAuth signer fails (failingJoseSigner), and
// expects the page to be rendered (200) containing util.I18nError500Message.
func TestCreateShareCookieError(t *testing.T) {
	username := "share_user"
	pwd := util.GenerateUniqueID()
	user := &dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: pwd,
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	err := dataprovider.AddUser(user, "", "", "")
	assert.NoError(t, err)
	share := &dataprovider.Share{
		Name:     "test_share_cookie_error",
		ShareID:  util.GenerateUniqueID(),
		Scope:    dataprovider.ShareScopeRead,
		Password: pwd,
		Paths:    []string{"/"},
		Username: username,
	}
	err = dataprovider.AddShare(share, "", "", "")
	assert.NoError(t, err)

	// tokenAuth cannot sign; csrfTokenAuth works normally
	tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32))
	require.NoError(t, err)
	tokenAuth.SetSigner(&failingJoseSigner{})
	csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32))
	require.NoError(t, err)
	server := httpdServer{
		tokenAuth:     tokenAuth,
		csrfTokenAuth: csrfTokenAuth,
	}

	c := jwt.NewClaims(tokenAudienceWebLogin, "127.0.0.1", getTokenDuration(tokenAudienceWebLogin))
	token, err := server.csrfTokenAuth.Sign(c)
	assert.NoError(t, err)
	resp := c.BuildTokenResponse(token)
	parsedToken, err := jwt.VerifyToken(server.csrfTokenAuth, resp.Token)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, share.ShareID, "login"), nil)
	assert.NoError(t, err)
	req.RemoteAddr = "127.0.0.1:4567"
	ctx := req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)

	form := make(url.Values)
	form.Set("share_password", pwd)
	form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseClientPath))
	rctx := chi.NewRouteContext()
	rctx.URLParams.Add("id", share.ShareID)
	rr := httptest.NewRecorder()

	req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"), bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = "127.0.0.1:2345"
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req = req.WithContext(ctx)
	req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
	server.handleClientShareLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}

// TestCreateTokenError uses a tokenAuth whose signer fails (failingJoseSigner)
// to check that token generation for admins and users answers 500. It then
// drives several login/form handlers and form parsers with a working
// csrfTokenAuth, including requests whose form data cannot be parsed, and
// finally checks that API key authentication fails for a freshly created user
// and admin.
func TestCreateTokenError(t *testing.T) {
	tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32))
	require.NoError(t, err)
	tokenAuth.SetSigner(&failingJoseSigner{})
	csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32))
	require.NoError(t, err)
	server := httpdServer{
		tokenAuth:     tokenAuth,
		csrfTokenAuth: csrfTokenAuth,
	}
	rr := httptest.NewRecorder()
	admin := dataprovider.Admin{
		Username: defaultAdminUsername,
		Password: "password",
	}
	req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
server.generateAndSendToken(rr, req, admin, "")
	assert.Equal(t, http.StatusInternalServerError, rr.Code)

	rr = httptest.NewRecorder()
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "u",
			Password: util.GenerateUniqueID(),
		},
	}
	req, _ = http.NewRequest(http.MethodGet, userTokenPath, nil)
	server.generateAndSendUserToken(rr, req, "", user)
	assert.Equal(t, http.StatusInternalServerError, rr.Code)

	// Build a cookie token signed with the (working) csrfTokenAuth and attach
	// its parsed claims to the request context.
	c := &jwt.Claims{}
	c.ID = xid.New().String()
	c.SetExpiry(time.Now().Add(1 * time.Minute))
	tokenString, err := server.csrfTokenAuth.SignWithParams(c, tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI))
	assert.NoError(t, err)
	token := c.BuildTokenResponse(tokenString)

	req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
	assert.NoError(t, err)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token.Token))
	parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie)
	assert.NoError(t, err)
	ctx := req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)

	rr = httptest.NewRecorder()
	form := make(url.Values)
	form.Set("username", admin.Username)
	form.Set("password", admin.Password)
	form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, xid.New().String(), webBaseAdminPath))
	cookie := rr.Header().Get("Set-Cookie")
	assert.NotEmpty(t, cookie)
	req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Cookie", cookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie)
	assert.NoError(t, err)
	ctx = req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)
	server.handleWebAdminLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	// req with no content type
	req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, nil)
	rr = httptest.NewRecorder()
	server.handleWebAdminLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	req, _ = http.NewRequest(http.MethodPost, webAdminSetupPath, nil)
	rr = httptest.NewRecorder()
	// NOTE(review): the outcome of loginAdmin is not asserted; with the failing
	// tokenAuth signer this presumably only exercises the token-creation error
	// path — consider asserting rr.Code.
	server.loginAdmin(rr, req, &admin, false, nil, "")
	// Requests whose query string contains invalid percent-escapes (e.g. "%AO",
	// "%GG", "%JD"): form parsing fails for each of them, and the calls below
	// check how every handler/parser reacts to that failure.
	req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%AO%GG", nil)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A1%G2", nil)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminChangePwdPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A2%G3", nil)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	_, err = getAdminFromPostFields(req)
	assert.Error(t, err)

	req, _ = http.NewRequest(http.MethodPost, webAdminEventActionPath+"?a=a%C3%A2%GG", nil)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	_, err = getEventActionFromPostFields(req)
	assert.Error(t, err)

	req, _ = http.NewRequest(http.MethodPost, webAdminEventRulePath+"?a=a%C3%A3%GG", nil)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	_, err = getEventRuleFromPostFields(req)
	assert.Error(t, err)

	req, _ = http.NewRequest(http.MethodPost, webIPListPath+"/1?a=a%C3%AO%GG", nil)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	_, err = getIPListEntryFromPostFields(req, dataprovider.IPListTypeAllowList)
	assert.Error(t, err)

	req, _ = http.NewRequest(http.MethodPost, path.Join(webClientSharePath, "shareID", "login?a=a%C3%AO%GG"), bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleClientShareLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())

	req, _ = http.NewRequest(http.MethodPost, webClientLoginPath+"?a=a%C3%AO%GG", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())

	req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath+"?a=a%C3%AO%GA", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientChangePwdPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath+"?a=a%C3%AO%GB", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientProfilePost(rr, req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String())

	req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath+"?a=a%C3%AO%GB", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminProfilePost(rr, req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String())

	// the two-factor handlers additionally need (empty) claims in the context
	req, _ = http.NewRequest(http.MethodPost, webAdminTwoFactorPath+"?a=a%C3%AO%GC", bytes.NewBuffer([]byte(form.Encode())))
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminTwoFactorPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath+"?a=a%C3%AO%GD", bytes.NewBuffer([]byte(form.Encode())))
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminTwoFactorRecoveryPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webClientTwoFactorPath+"?a=a%C3%AO%GC", bytes.NewBuffer([]byte(form.Encode())))
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientTwoFactorPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath+"?a=a%C3%AO%GD", bytes.NewBuffer([]byte(form.Encode())))
	req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientTwoFactorRecoveryPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webAdminForgotPwdPath+"?a=a%C3%A1%GD", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminForgotPwdPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webClientForgotPwdPath+"?a=a%C2%A1%GD", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientForgotPwdPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webAdminResetPwdPath+"?a=a%C3%AO%JD", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAdminPasswordResetPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webAdminRolePath+"?a=a%C3%AO%JE", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebAddRolePost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webClientResetPwdPath+"?a=a%C3%AO%JD", bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = httptest.NewRecorder()
	server.handleWebClientPasswordResetPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath+"?a=a%K3%AO%GA", bytes.NewBuffer([]byte(form.Encode())))
	_, err = getShareFromPostFields(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid URL escape")
	}

	// Web client login for a real user, then API key authentication checks.
	username := "webclientuser"
	user = dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:    username,
			Password:    "clientpwd",
			HomeDir:     filepath.Join(os.TempDir(), username),
			Status:      1,
			Description: "test user",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	user.Filters.AllowAPIKeyAuth = true
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
	assert.NoError(t, err)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token.Token))
	parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie)
	assert.NoError(t, err)
	ctx = req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)

	rr = httptest.NewRecorder()
	form = make(url.Values)
	form.Set("username", user.Username)
	form.Set("password", "clientpwd")
	form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath))
	req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	server.handleWebClientLoginPost(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())

	// NOTE(review): the empty string argument presumably stands for a missing
	// API key ID — confirm against authenticateUserWithAPIKey.
	err = authenticateUserWithAPIKey(username, "", server.tokenAuth, req)
	assert.Error(t, err)

	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.HomeDir)
	assert.NoError(t, err)

	admin.Username += "1"
	admin.Status = 1
	admin.Filters.AllowAPIKeyAuth = true
	admin.Permissions = []string{dataprovider.PermAdminAny}
	err = dataprovider.AddAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	err = authenticateAdminWithAPIKey(admin.Username, "", server.tokenAuth, req)
	assert.Error(t, err)

	err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
	assert.NoError(t, err)
}

// TestAPIKeyAuthForbidden calls the forbidAPIKeyAuthentication middleware
// with a request that has no token claims in its context and expects
// 400 "Invalid token claims".
func TestAPIKeyAuthForbidden(t *testing.T) {
	r, err := GetHTTPRouter(Binding{
		Address:         "",
		Port:            8080,
		EnableWebAdmin:  true,
		EnableWebClient: true,
		EnableRESTAPI:   true,
		RenderOpenAPI:   true,
	})
	require.NoError(t, err)
	fn := forbidAPIKeyAuthentication(r)
	rr := httptest.NewRecorder()
	req, _ := http.NewRequest(http.MethodGet, versionPath, nil)
	fn.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusBadRequest, rr.Code)
	assert.Contains(t, rr.Body.String(), "Invalid token claims")
}

// TestJWTTokenValidation injects claims directly into the request context and
// checks the authentication/authorization middlewares.
//   - The API authenticator answers 401; the WebAdmin and WebClient
//     authenticators redirect (302) to their login pages.
//   - When the context carries a token error, checkPerms, checkHTTPUserPerm
//     and checkAuthRequirements answer 400.
//
// NOTE(review): the claims' expiry is set one hour in the past before
// SignWithParams is called; whether SignWithParams overrides that expiry is
// not visible here — confirm which claim property makes the authenticators
// reject them.
func TestJWTTokenValidation(t *testing.T) {
	tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32))
	require.NoError(t, err)
	claims := &jwt.Claims{
		Username: defaultAdminUsername,
	}
	claims.SetExpiry(time.Now().UTC().Add(-1 * time.Hour))
	_, err = tokenAuth.SignWithParams(claims, tokenAudienceWebAdmin, "", getTokenDuration(tokenAudienceWebAdmin))
	require.NoError(t, err)

	server := httpdServer{
		binding: Binding{
			Address:         "",
			Port:            8080,
			EnableWebAdmin:  true,
			EnableWebClient: true,
			EnableRESTAPI:   true,
			RenderOpenAPI:   true,
		},
	}
	err = server.initializeRouter()
	require.NoError(t, err)
	r := server.router

	fn := jwtAuthenticatorAPI(r)
	rr := httptest.NewRecorder()
	req, _ := http.NewRequest(http.MethodGet, userPath, nil)
	ctx := jwt.NewContext(req.Context(), claims, nil)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusUnauthorized, rr.Code)

	fn = jwtAuthenticatorWebAdmin(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webUserPath, nil)
	ctx = jwt.NewContext(req.Context(), claims, nil)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))

	fn = jwtAuthenticatorWebClient(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	ctx = jwt.NewContext(req.Context(), claims, nil)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))

	// from here on the context carries a token error
	errTest := errors.New("test error")
	permFn := server.checkPerms(dataprovider.PermAdminAny)
	fn = permFn(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, userPath, nil)
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)

	permFn = server.checkPerms(dataprovider.PermAdminAny)
	fn = permFn(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webUserPath, nil)
	req.RequestURI = webUserPath
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)

	permClientFn := server.checkHTTPUserPerm(sdk.WebClientPubKeyChangeDisabled)
	fn = permClientFn(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, nil)
	req.RequestURI = webClientProfilePath
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, userProfilePath, nil)
	req.RequestURI = userProfilePath
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)

	fn = server.checkAuthRequirements(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, nil)
	req.RequestURI = webClientProfilePath
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)

	fn = server.checkAuthRequirements(r)
	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, webGroupsPath, nil)
	req.RequestURI = webGroupsPath
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, userSharesPath, nil)
	req.RequestURI = userSharesPath
	ctx = jwt.NewContext(req.Context(), claims, errTest)
	fn.ServeHTTP(rr, req.WithContext(ctx))
	assert.Equal(t, http.StatusBadRequest, rr.Code)
}

func TestUpdateContextFromCookie(t *testing.T) { tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) server := httpdServer{ tokenAuth: tokenAuth, } req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) _, err = server.tokenAuth.Sign(claims)
assert.NoError(t, err) ctx := jwt.NewContext(req.Context(), claims, nil) req = server.updateContextFromCookie(req.WithContext(ctx)) token, err := jwt.FromContext(req.Context()) require.NoError(t, err) require.True(t, token.Audience.Contains(tokenAudienceWebClient)) require.NotEmpty(t, token.ID) } func TestCookieExpiration(t *testing.T) { tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) server := httpdServer{ tokenAuth: tokenAuth, } err = errors.New("test error") rr := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) ctx := jwt.NewContext(req.Context(), nil, err) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie := rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) admin := dataprovider.Admin{ Username: "newtestadmin", Password: "password", Permissions: []string{dataprovider.PermAdminAny}, } claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) claims.Username = admin.Username claims.Permissions = admin.Permissions claims.Subject = admin.GetSignature() claims.SetExpiry(time.Now().Add(1 * time.Minute)) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) admin.Status = 0 err = dataprovider.AddAdmin(&admin, "", "", "") assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) ctx = jwt.NewContext(req.Context(), claims, nil) 
server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) admin.Status = 1 admin.Filters.AllowList = []string{"172.16.1.0/24"} err = dataprovider.UpdateAdmin(&admin, "", "", "") assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) admin, err = dataprovider.AdminExists(admin.Username) assert.NoError(t, err) tokenID := xid.New().String() claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) claims.ID = tokenID claims.Username = admin.Username claims.Permissions = admin.Permissions claims.Subject = admin.GetSignature() claims.SetExpiry(time.Now().Add(1 * time.Minute)) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) req.RemoteAddr = "192.168.8.1:1234" ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) req.RemoteAddr = "172.16.1.12:4567" ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.True(t, strings.HasPrefix(cookie, "jwt=")) req.Header.Set("Cookie", cookie) c, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) if assert.NoError(t, err) { assert.Equal(t, tokenID, c.ID) } err = dataprovider.DeleteAdmin(admin.Username, "", "", "") assert.NoError(t, err) // now check client cookie expiration username := "client" user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, Password: "clientpwd", HomeDir: filepath.Join(os.TempDir(), username), Status: 1, Description: "test user", }, } user.Permissions = make(map[string][]string) 
user.Permissions["/"] = []string{"*"} claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) claims.ID = tokenID claims.Username = user.Username claims.Permissions = user.Filters.WebClient claims.Subject = user.GetSignature() claims.SetExpiry(time.Now().Add(1 * time.Minute)) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) // the password will be hashed and so the signature will change err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) user, err = dataprovider.UserExists(user.Username, "") assert.NoError(t, err) user.Filters.AllowedIP = []string{"172.16.4.0/24"} err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) user, err = dataprovider.UserExists(user.Username, "") assert.NoError(t, err) issuedAt := time.Now().Add(-1 * time.Minute) expiresAt := time.Now().Add(1 * time.Minute) claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) claims.ID = tokenID claims.Username = user.Username claims.Permissions = user.Filters.WebClient claims.Subject = user.GetSignature() claims.SetExpiry(expiresAt) claims.SetIssuedAt(issuedAt) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.3.12:4567" ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) req, _ = 
http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.4.16:4567" ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, cookie) req.Header.Set("Cookie", cookie) c, err = jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) if assert.NoError(t, err) { assert.Equal(t, tokenID, c.ID) assert.Equal(t, issuedAt.Unix(), c.IssuedAt.Time().Unix()) assert.NotEqual(t, expiresAt.Unix(), c.Expiry.Time().Unix()) } // test a cookie issued more that 12 hours ago claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) claims.ID = tokenID claims.Username = user.Username claims.Permissions = user.Filters.WebClient claims.Subject = user.GetSignature() claims.SetExpiry(expiresAt) claims.SetIssuedAt(time.Now().Add(-24 * time.Hour)) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.4.16:6789" ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) // test a disabled user user.Status = 0 err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) user, err = dataprovider.UserExists(user.Username, "") assert.NoError(t, err) claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) claims.ID = tokenID claims.Username = user.Username claims.Permissions = user.Filters.WebClient claims.Subject = user.GetSignature() claims.SetExpiry(time.Now().Add(1 * time.Minute)) claims.SetIssuedAt(issuedAt) _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, 
req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) } func TestGetURLParam(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, adminPwdPath, nil) rctx := chi.NewRouteContext() rctx.URLParams.Add("val", "testuser%C3%A0") rctx.URLParams.Add("inval", "testuser%C3%AO%GG") req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) escaped := getURLParam(req, "val") assert.Equal(t, "testuserà", escaped) escaped = getURLParam(req, "inval") assert.Equal(t, "testuser%C3%AO%GG", escaped) } func TestChangePwdValidationErrors(t *testing.T) { err := doChangeAdminPassword(nil, "", "", "") require.Error(t, err) err = doChangeAdminPassword(nil, "a", "b", "c") require.Error(t, err) err = doChangeAdminPassword(nil, "a", "a", "a") require.Error(t, err) req, _ := http.NewRequest(http.MethodPut, adminPwdPath, nil) req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) err = doChangeAdminPassword(req, "currentpwd", "newpwd", "newpwd") assert.Error(t, err) } func TestRenderUnexistingFolder(t *testing.T) { rr := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, folderPath, nil) renderFolder(rr, req, "path not mapped", &jwt.Claims{}, http.StatusOK) assert.Equal(t, http.StatusNotFound, rr.Code) } func TestCloseConnectionHandler(t *testing.T) { tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) claims.Username = defaultAdminUsername claims.SetExpiry(time.Now().UTC().Add(1 * time.Hour)) _, err = tokenAuth.Sign(claims) assert.NoError(t, err) req, err := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) assert.NoError(t, err) rctx := chi.NewRouteContext() rctx.URLParams.Add("connectionID", "") req = 
req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req = req.WithContext(context.WithValue(req.Context(), jwt.TokenCtxKey, claims)) rr := httptest.NewRecorder() handleCloseConnection(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "connectionID is mandatory") } func TestRenderInvalidTemplate(t *testing.T) { tmpl, err := template.New("test").Parse("{{.Count}}") if assert.NoError(t, err) { noMatchTmpl := "no_match" adminTemplates[noMatchTmpl] = tmpl rw := httptest.NewRecorder() renderAdminTemplate(rw, noMatchTmpl, map[string]string{}) assert.Equal(t, http.StatusInternalServerError, rw.Code) clientTemplates[noMatchTmpl] = tmpl renderClientTemplate(rw, noMatchTmpl, map[string]string{}) assert.Equal(t, http.StatusInternalServerError, rw.Code) } } func TestQuotaScanInvalidFs(t *testing.T) { user := &dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "test", HomeDir: os.TempDir(), }, FsConfig: vfs.Filesystem{ Provider: sdk.S3FilesystemProvider, }, } common.QuotaScans.AddUserQuotaScan(user.Username, "") err := doUserQuotaScan(user) assert.Error(t, err) } func TestVerifyTLSConnection(t *testing.T) { oldCertMgr := certMgr caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") certPath := filepath.Join(os.TempDir(), "testh.crt") keyPath := filepath.Join(os.TempDir(), "testh.key") err := os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(certPath, []byte(httpdCert), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(keyPath, []byte(httpdKey), os.ModePerm) assert.NoError(t, err) keyPairs := []common.TLSKeyPair{ { Cert: certPath, Key: keyPath, ID: common.DefaultTLSKeyPaidID, }, } certMgr, err = common.NewCertManager(keyPairs, "", "httpd_test") assert.NoError(t, err) certMgr.SetCARevocationLists([]string{caCrlPath}) err = certMgr.LoadCRLs() assert.NoError(t, err) crt, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) x509crt, 
err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) server := httpdServer{} state := tls.ConnectionState{ PeerCertificates: []*x509.Certificate{x509crt}, } err = server.verifyTLSConnection(state) assert.Error(t, err) // no verified certification chain crt, err = tls.X509KeyPair([]byte(caCRT), []byte(caKey)) assert.NoError(t, err) x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crt, x509CAcrt}) err = server.verifyTLSConnection(state) assert.NoError(t, err) crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) assert.NoError(t, err) x509crtRevoked, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crtRevoked, x509CAcrt}) state.PeerCertificates = []*x509.Certificate{x509crtRevoked} err = server.verifyTLSConnection(state) assert.EqualError(t, err, common.ErrCrtRevoked.Error()) err = os.Remove(caCrlPath) assert.NoError(t, err) err = os.Remove(certPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) certMgr = oldCertMgr } func TestGetFolderFromTemplate(t *testing.T) { folder := vfs.BaseVirtualFolder{ MappedPath: "Folder%name%", Description: "Folder %name% desc", } folderName := "folderTemplate" folderTemplate := getFolderFromTemplate(folder, folderName) require.Equal(t, folderName, folderTemplate.Name) require.Equal(t, fmt.Sprintf("Folder%v", folderName), folderTemplate.MappedPath) require.Equal(t, fmt.Sprintf("Folder %v desc", folderName), folderTemplate.Description) folder.FsConfig.Provider = sdk.CryptedFilesystemProvider folder.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%name%") folderTemplate = getFolderFromTemplate(folder, folderName) require.Equal(t, folderName, folderTemplate.FsConfig.CryptConfig.Passphrase.GetPayload()) folder.FsConfig.Provider = sdk.GCSFilesystemProvider 
folder.FsConfig.GCSConfig.KeyPrefix = "prefix%name%/" folderTemplate = getFolderFromTemplate(folder, folderName) require.Equal(t, fmt.Sprintf("prefix%v/", folderName), folderTemplate.FsConfig.GCSConfig.KeyPrefix) folder.FsConfig.Provider = sdk.AzureBlobFilesystemProvider folder.FsConfig.AzBlobConfig.KeyPrefix = "a%name%" folder.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%name%") folderTemplate = getFolderFromTemplate(folder, folderName) require.Equal(t, "a"+folderName, folderTemplate.FsConfig.AzBlobConfig.KeyPrefix) require.Equal(t, "pwd"+folderName, folderTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload()) folder.FsConfig.Provider = sdk.SFTPFilesystemProvider folder.FsConfig.SFTPConfig.Prefix = "%name%" folder.FsConfig.SFTPConfig.Username = "sftp_%name%" folder.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%name%") folderTemplate = getFolderFromTemplate(folder, folderName) require.Equal(t, folderName, folderTemplate.FsConfig.SFTPConfig.Prefix) require.Equal(t, "sftp_"+folderName, folderTemplate.FsConfig.SFTPConfig.Username) require.Equal(t, "sftp"+folderName, folderTemplate.FsConfig.SFTPConfig.Password.GetPayload()) } func TestGetUserFromTemplate(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Status: 1, }, } user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: "Folder%username%", }, }) username := "userTemplate" password := "pwdTemplate" templateFields := userTemplateFields{ Username: username, Password: password, } userTemplate := getUserFromTemplate(user, templateFields) require.Len(t, userTemplate.VirtualFolders, 1) require.Equal(t, "Folder"+username, userTemplate.VirtualFolders[0].Name) user.FsConfig.Provider = sdk.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%password%") userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, password, 
userTemplate.FsConfig.CryptConfig.Passphrase.GetPayload()) user.FsConfig.Provider = sdk.GCSFilesystemProvider user.FsConfig.GCSConfig.KeyPrefix = "%username%%password%" userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, username+password, userTemplate.FsConfig.GCSConfig.KeyPrefix) user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider user.FsConfig.AzBlobConfig.KeyPrefix = "a%username%" user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%password%%username%") userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, "a"+username, userTemplate.FsConfig.AzBlobConfig.KeyPrefix) require.Equal(t, "pwd"+password+username, userTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload()) user.FsConfig.Provider = sdk.SFTPFilesystemProvider user.FsConfig.SFTPConfig.Prefix = "%username%" user.FsConfig.SFTPConfig.Username = "sftp_%username%" user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%password%") userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, username, userTemplate.FsConfig.SFTPConfig.Prefix) require.Equal(t, "sftp_"+username, userTemplate.FsConfig.SFTPConfig.Username) require.Equal(t, "sftp"+password, userTemplate.FsConfig.SFTPConfig.Password.GetPayload()) } func TestJWTTokenCleanup(t *testing.T) { tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) server := httpdServer{ tokenAuth: tokenAuth, } admin := dataprovider.Admin{ Username: "newtestadmin", Password: "password", Permissions: []string{dataprovider.PermAdminAny}, } claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) claims.Username = admin.Username claims.Permissions = admin.Permissions claims.Subject = admin.GetSignature() claims.SetExpiry(time.Now().Add(1 * time.Minute)) token, err := server.tokenAuth.Sign(claims) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, versionPath, nil) assert.True(t, isTokenInvalidated(req)) fakeToken := 
"abc" invalidateTokenString(req, fakeToken, -100*time.Millisecond) assert.True(t, invalidatedJWTTokens.Get(fakeToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) invalidatedJWTTokens.Add(token, time.Now().Add(-getTokenDuration(tokenAudienceWebAdmin)).UTC()) require.True(t, isTokenInvalidated(req)) startCleanupTicker(100 * time.Millisecond) assert.Eventually(t, func() bool { return !isTokenInvalidated(req) }, 1*time.Second, 200*time.Millisecond) assert.False(t, invalidatedJWTTokens.Get(fakeToken)) stopCleanupTicker() } func TestDbTokenManager(t *testing.T) { if !isSharedProviderSupported() { t.Skip("this test it is not available with this provider") } mgr := newTokenManager(1) dbTokenManager := mgr.(*dbTokenManager) testToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsiV2ViQWRtaW4iLCI6OjEiXSwiZXhwIjoxNjk4NjYwMDM4LCJqdGkiOiJja3ZuazVrYjF1aHUzZXRmZmhyZyIsIm5iZiI6MTY5ODY1ODgwOCwicGVybWlzc2lvbnMiOlsiKiJdLCJzdWIiOiIxNjk3ODIwNDM3NTMyIiwidXNlcm5hbWUiOiJhZG1pbiJ9.LXuFFksvnSuzHqHat6r70yR0jEulNRju7m7SaWrOfy8; csrftoken=mP0C7DqjwpAXsptO2gGCaYBkYw3oNMWB" key := dbTokenManager.getKey(testToken) require.Len(t, key, 64) dbTokenManager.Add(testToken, time.Now().Add(-getTokenDuration(tokenAudienceWebClient)).UTC()) isInvalidated := dbTokenManager.Get(testToken) assert.True(t, isInvalidated) dbTokenManager.Cleanup() isInvalidated = dbTokenManager.Get(testToken) assert.False(t, isInvalidated) dbTokenManager.Add(testToken, time.Now().Add(getTokenDuration(tokenAudienceWebAdmin)).UTC()) isInvalidated = dbTokenManager.Get(testToken) assert.True(t, isInvalidated) dbTokenManager.Cleanup() isInvalidated = dbTokenManager.Get(testToken) assert.True(t, isInvalidated) err := dataprovider.DeleteSharedSession(key, dataprovider.SessionTypeInvalidToken) assert.NoError(t, err) } func TestDatabaseSharedSessions(t *testing.T) { if !isSharedProviderSupported() { t.Skip("this test it is not available with this provider") } session1 := dataprovider.Session{ Key: "1", Data: 
map[string]string{"a": "b"}, Type: dataprovider.SessionTypeOIDCAuth, Timestamp: 10, } err := dataprovider.AddSharedSession(session1) assert.NoError(t, err) // Adding another session with the same key but a different type should work session2 := session1 session2.Type = dataprovider.SessionTypeOIDCToken err = dataprovider.AddSharedSession(session2) assert.NoError(t, err) err = dataprovider.DeleteSharedSession(session1.Key, dataprovider.SessionTypeInvalidToken) assert.ErrorIs(t, err, util.ErrNotFound) _, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeResetCode) assert.ErrorIs(t, err, util.ErrNotFound) session1Get, err := dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) assert.NoError(t, err) assert.Equal(t, session1.Timestamp, session1Get.Timestamp) var stored map[string]string err = json.Unmarshal(session1Get.Data.([]byte), &stored) assert.NoError(t, err) assert.Equal(t, session1.Data, stored) session1.Timestamp = 20 session1.Data = map[string]string{"c": "d"} err = dataprovider.AddSharedSession(session1) assert.NoError(t, err) session1Get, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) assert.NoError(t, err) assert.Equal(t, session1.Timestamp, session1Get.Timestamp) stored = make(map[string]string) err = json.Unmarshal(session1Get.Data.([]byte), &stored) assert.NoError(t, err) assert.Equal(t, session1.Data, stored) err = dataprovider.DeleteSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) assert.NoError(t, err) err = dataprovider.DeleteSharedSession(session2.Key, dataprovider.SessionTypeOIDCToken) assert.NoError(t, err) _, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) assert.ErrorIs(t, err, util.ErrNotFound) _, err = dataprovider.GetSharedSession(session2.Key, dataprovider.SessionTypeOIDCToken) assert.ErrorIs(t, err, util.ErrNotFound) } func TestAllowedProxyUnixDomainSocket(t *testing.T) { b := Binding{ Address: 
filepath.Join(os.TempDir(), "sock"), ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"}, } err := b.parseAllowedProxy() assert.NoError(t, err) if assert.Len(t, b.allowHeadersFrom, 1) { assert.True(t, b.allowHeadersFrom[0](nil)) } } func TestProxyListenerWrapper(t *testing.T) { b := Binding{ ProxyMode: 0, } require.Nil(t, b.listenerWrapper()) b.ProxyMode = 1 require.NotNil(t, b.listenerWrapper()) } func TestProxyHeaders(t *testing.T) { username := "adminTest" password := "testPwd" admin := dataprovider.Admin{ Username: username, Password: password, Permissions: []string{dataprovider.PermAdminAny}, Status: 1, Filters: dataprovider.AdminFilters{ AllowList: []string{"172.19.2.0/24"}, }, } err := dataprovider.AddAdmin(&admin, "", "", "") assert.NoError(t, err) testIP := "10.29.1.9" validForwardedFor := "172.19.2.6" b := Binding{ Address: "", Port: 8080, EnableWebAdmin: true, EnableWebClient: false, EnableRESTAPI: true, ProxyAllowed: []string{testIP, "10.8.0.0/30"}, ClientIPProxyHeader: "x-forwarded-for", } err = b.parseAllowedProxy() assert.NoError(t, err) server := newHttpdServer(b, "", "", CorsConfig{Enabled: true}, "") err = server.initializeRouter() require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() req, err := http.NewRequest(http.MethodGet, tokenPath, nil) assert.NoError(t, err) req.Header.Set("X-Forwarded-For", validForwardedFor) req.Header.Set(xForwardedProto, "https") req.RemoteAddr = "127.0.0.1:123" req.SetBasicAuth(username, password) rr := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) assert.NotContains(t, rr.Body.String(), "login from IP 127.0.0.1 not allowed") req.RemoteAddr = testIP rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) req.RemoteAddr = "10.8.0.2" rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) req, err 
= http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) req.RemoteAddr = testIP rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) cookie := rr.Header().Get("Set-Cookie") assert.NotEmpty(t, cookie) req.Header.Set("Cookie", cookie) parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx := req.Context() ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form := make(url.Values) form.Set("username", username) form.Set("password", password) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP req.Header.Set("Cookie", cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) req.RemoteAddr = validForwardedFor rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) loginCookie := rr.Header().Get("Set-Cookie") assert.NotEmpty(t, loginCookie) req.Header.Set("Cookie", loginCookie) parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, 
err) req.RemoteAddr = testIP req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-Forwarded-For", validForwardedFor) rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) cookie = rr.Header().Get("Set-Cookie") assert.NotContains(t, cookie, "Secure") // The login cookie is invalidated after a successful login, the same request will fail req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-Forwarded-For", validForwardedFor) rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) req.RemoteAddr = validForwardedFor rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) loginCookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, loginCookie) req.Header.Set("Cookie", loginCookie) parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-Forwarded-For", validForwardedFor) 
req.Header.Set(xForwardedProto, "https")
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
	cookie = rr.Header().Get("Set-Cookie")
	assert.Contains(t, cookie, "Secure")
	req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
	assert.NoError(t, err)
	req.RemoteAddr = validForwardedFor
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
	loginCookie = rr.Header().Get("Set-Cookie")
	assert.NotEmpty(t, loginCookie)
	req.Header.Set("Cookie", loginCookie)
	parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie)
	assert.NoError(t, err)
	ctx = req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)
	form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
	req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = testIP
	req.Header.Set("Cookie", loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("X-Forwarded-For", validForwardedFor)
	req.Header.Set(xForwardedProto, "http")
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
	cookie = rr.Header().Get("Set-Cookie")
	assert.NotContains(t, cookie, "Secure")
	err = dataprovider.DeleteAdmin(username, "", "", "")
	assert.NoError(t, err)
}

// TestRecoverer checks that a handler that panics produces a 500 response, both
// when mounted on the server's own router and on a bare chi router that only uses
// middleware.Recoverer.
func TestRecoverer(t *testing.T) {
	recoveryPath := "/recovery"
	b := Binding{
		Address:         "",
		Port:            8080,
		EnableWebAdmin:  true,
		EnableWebClient: false,
		EnableRESTAPI:   true,
	}
	server := newHttpdServer(b, "../static", "", CorsConfig{}, "../openapi")
	err := server.initializeRouter()
	require.NoError(t, err)
	server.router.Get(recoveryPath, func(_ http.ResponseWriter, _ *http.Request) {
		panic("panic")
	})
	testServer := httptest.NewServer(server.router)
	defer testServer.Close()

	req, err := http.NewRequest(http.MethodGet, recoveryPath, nil)
	assert.NoError(t, err)
	rr := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String())

	server.router = chi.NewRouter()
	server.router.Use(middleware.Recoverer)
	server.router.Get(recoveryPath, func(_ http.ResponseWriter, _ *http.Request) {
		panic("panic")
	})
	testServer = httptest.NewServer(server.router)
	defer testServer.Close()

	req, err = http.NewRequest(http.MethodGet, recoveryPath, nil)
	assert.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String())
}

// TestStreamJSONArray checks that streamJSONArray writes `[]` when the getter
// returns no data, and that one-element chunks are merged into a single JSON array.
func TestStreamJSONArray(t *testing.T) {
	dataGetter := func(_, _ int) ([]byte, int, error) {
		return nil, 0, nil
	}
	rr := httptest.NewRecorder()
	streamJSONArray(rr, 10, dataGetter)
	assert.Equal(t, `[]`, rr.Body.String())

	data := []int{}
	for i := 0; i < 10; i++ {
		data = append(data, i)
	}
	dataGetter = func(_, offset int) ([]byte, int, error) {
		if offset >= len(data) {
			return nil, 0, nil
		}
		val := data[offset]
		// use a distinct name: the closure reads the captured `data` slice above,
		// shadowing it here with the marshaled bytes was confusing
		encoded, err := json.Marshal([]int{val})
		return encoded, 1, err
	}
	rr = httptest.NewRecorder()
	streamJSONArray(rr, 1, dataGetter)
	assert.Equal(t, `[0,1,2,3,4,5,6,7,8,9]`, rr.Body.String())
}

// TestCompressorAbortHandler checks that renderCompressedFiles panics with
// http.ErrAbortHandler when writing to the response fails.
func TestCompressorAbortHandler(t *testing.T) {
	defer func() {
		rcv := recover()
		assert.Equal(t, http.ErrAbortHandler, rcv)
	}()

	connection := newConnection(
		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", dataprovider.User{}),
		nil,
		nil,
	)
	share := &dataprovider.Share{}
	renderCompressedFiles(&failingWriter{}, connection, "", nil, share)
}

// TestStreamDataAbortHandler checks that streamData panics with
// http.ErrAbortHandler when writing to the response fails.
func TestStreamDataAbortHandler(t *testing.T) {
	defer func() {
		rcv := recover()
		assert.Equal(t, http.ErrAbortHandler, rcv)
	}()

	streamData(&failingWriter{}, []byte(`["a":"b"]`))
}

// TestZipErrors exercises addZipEntry error paths: write failures, the recursion
// limit, entries outside the base dir, missing permissions and a virtual folder
// mapped to a missing path.
func TestZipErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			HomeDir: filepath.Clean(os.TempDir()),
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	connection := newConnection(
		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
		nil,
		nil,
	)

	testDir := filepath.Join(os.TempDir(), "testDir")
	err := os.MkdirAll(testDir, os.ModePerm)
	assert.NoError(t, err)

	// a closed writer over a failing sink: every write must surface "write error"
	wr := zip.NewWriter(&failingWriter{})
	err = wr.Close()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "write error")
	}
	err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", nil, 0)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "write error")
	}
	err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", nil, 2000)
	assert.ErrorIs(t, err, util.ErrRecursionTooDeep)

	err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir"), nil, 0)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is outside base dir")
	}

	testFilePath := filepath.Join(testDir, "ziptest.zip")
	err = os.WriteFile(testFilePath, util.GenerateRandomBytes(65535), os.ModePerm)
	assert.NoError(t, err)
	err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)),
		"/"+filepath.Base(testDir), nil, 0)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "write error")
	}

	// the Permissions map is shared between `user` and `connection.User`,
	// so this also changes user.Permissions
	connection.User.Permissions["/"] = []string{dataprovider.PermListItems}
	err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)),
		"/"+filepath.Base(testDir), nil, 0)
	assert.ErrorIs(t, err, os.ErrPermission)

	// a virtual folder mapped to a missing path: stat is ok but readdir fails
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: filepath.Join(os.TempDir(), "mapped"),
		},
		VirtualPath: "/vpath",
	})
	connection.User = user
	wr = zip.NewWriter(bytes.NewBuffer(make([]byte, 0)))
	err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/", nil, 0)
	assert.Error(t, err)

	user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{
		Path:           "/",
		DeniedPatterns: []string{"*.zip"},
	})
	// NOTE(review): connection.User was copied from `user` above, so the denied
	// pattern appended here does not reach connection.User; the ErrPermission below
	// presumably comes from the PermListItems-only permission instead — confirm intent.
	err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", nil, 0)
	assert.ErrorIs(t, err, os.ErrPermission)

	err = os.RemoveAll(testDir)
	assert.NoError(t, err)
}

// TestWebAdminRedirect checks that, with only the web admin enabled, the web root
// and the web base path redirect to the admin login page.
func TestWebAdminRedirect(t *testing.T) {
	b := Binding{
		Address:         "",
		Port:            8080,
		EnableWebAdmin:  true,
		EnableWebClient: false,
		EnableRESTAPI:   true,
	}
	server := newHttpdServer(b, "../static", "", CorsConfig{}, "../openapi")
	err := server.initializeRouter()
	require.NoError(t, err)
	testServer := httptest.NewServer(server.router)
	defer testServer.Close()

	req, err := http.NewRequest(http.MethodGet, webRootPath, nil)
	assert.NoError(t, err)
	rr := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))

	req, err = http.NewRequest(http.MethodGet, webBasePath, nil)
	assert.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
}

// TestParseRangeRequests checks the offset and size returned by parseRangeRequest
// for a 169740-byte file. Each case mirrors the curl command in its comment, and
// the expected value is formatted like a Content-Range header ("bytes first-last/total").
// rangeHeader[6:] strips the "bytes=" prefix.
func TestParseRangeRequests(t *testing.T) {
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=24-24"
	fileSize := int64(169740)
	rangeHeader := "bytes=24-24"
	offset, size, err := parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp := fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 24-24/169740", resp)
	require.Equal(t, int64(1), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=24-"
	rangeHeader = "bytes=24-"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 24-169739/169740", resp)
	require.Equal(t, int64(169716), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=-1"
	rangeHeader = "bytes=-1"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 169739-169739/169740", resp)
	require.Equal(t, int64(1), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=-100"
	rangeHeader = "bytes=-100"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 169640-169739/169740", resp)
	require.Equal(t, int64(100), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-30"
	rangeHeader = "bytes=20-30"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 20-30/169740", resp)
	require.Equal(t, int64(11), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-169739"
	rangeHeader = "bytes=20-169739"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 20-169739/169740", resp)
	require.Equal(t, int64(169720), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-169740"
	rangeHeader = "bytes=20-169740"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 20-169739/169740", resp)
	require.Equal(t, int64(169720), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-169741"
	rangeHeader = "bytes=20-169741"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 20-169739/169740", resp)
	require.Equal(t, int64(169720), size)
	// curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=0-" > /dev/null
	rangeHeader = "bytes=0-"
	offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.NoError(t, err)
	resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize)
	assert.Equal(t, "bytes 0-169739/169740", resp)
	require.Equal(t, int64(169740), size)
	// now test errors
	rangeHeader = "bytes=0-a"
	_, _, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.Error(t, err)
	rangeHeader = "bytes="
	_, _, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.Error(t, err)
	rangeHeader = "bytes=-"
	_, _, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.Error(t, err)
	rangeHeader = "bytes=500-300"
	_, _, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.Error(t, err)
	rangeHeader = "bytes=5000000"
	_, _, err = parseRangeRequest(rangeHeader[6:], fileSize)
	require.Error(t, err)
}

// TestRequestHeaderErrors checks the conditional-request helpers when the headers
// are missing, unparsable, or the method is not GET.
func TestRequestHeaderErrors(t *testing.T) {
	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.Header.Set("If-Unmodified-Since", "not a date")
	res := checkIfUnmodifiedSince(req, time.Now())
	assert.Equal(t, condNone, res)

	req, _ = http.NewRequest(http.MethodPost, webClientFilesPath, nil)
	res = checkIfModifiedSince(req, time.Now())
	assert.Equal(t, condNone, res)

	req, _ = http.NewRequest(http.MethodPost, webClientFilesPath, nil)
	res = checkIfRange(req, time.Now())
	assert.Equal(t, condNone, res)

	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.Header.Set("If-Modified-Since", "not a date")
	res = checkIfModifiedSince(req, time.Now())
	assert.Equal(t, condNone, res)

	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.Header.Set("If-Range", time.Now().Format(http.TimeFormat))
	res = checkIfRange(req, time.Time{})
	assert.Equal(t, condFalse, res)

	req.Header.Set("If-Range", "invalid if range date")
	res = checkIfRange(req, time.Now())
	assert.Equal(t, condFalse, res)
	modTime := getFileObjectModTime(time.Time{})
	assert.Empty(t, modTime)
}

// TestConnection checks the empty connection metadata and getFileReader errors,
// first for a GCS filesystem with invalid credentials, then for a missing local file.
func TestConnection(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "test_httpd_user",
			HomeDir:  filepath.Clean(os.TempDir()),
		},
		FsConfig: vfs.Filesystem{
			Provider: sdk.GCSFilesystemProvider,
			GCSConfig: vfs.GCSFsConfig{
				BaseGCSFsConfig: sdk.BaseGCSFsConfig{
					Bucket: "test_bucket_name",
				},
				Credentials: kms.NewPlainSecret("invalid JSON payload"),
			},
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	connection := newConnection(
		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
		nil,
		nil,
	)
	assert.Empty(t, connection.GetClientVersion())
	assert.Empty(t, connection.GetRemoteAddress())
	assert.Empty(t, connection.GetCommand())
	name := "missing file name"
	_, err := connection.getFileReader(name, 0, http.MethodGet)
	assert.Error(t, err)
	connection.User.FsConfig.Provider = sdk.LocalFilesystemProvider
	_, err = connection.getFileReader(name, 0, http.MethodGet)
	assert.ErrorIs(t, err, os.ErrNotExist)
}

// TestGetFileWriterErrors checks that getFileWriter fails for an invalid local home
// dir and for an S3 filesystem configuration.
func TestGetFileWriterErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "test_httpd_user",
			HomeDir:  "invalid",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	connection := newConnection(
		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
		nil,
		nil,
	)
	_, err := connection.getFileWriter("name")
	assert.Error(t, err)

	user.FsConfig.Provider = sdk.S3FilesystemProvider
	user.FsConfig.S3Config = vfs.S3FsConfig{
		BaseS3FsConfig: sdk.BaseS3FsConfig{
			Bucket:    "b",
			Region:    "us-west-1",
			AccessKey: "key",
		},
		AccessSecret: kms.NewPlainSecret("secret"),
	}
	connection = newConnection(
		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
		nil,
		nil,
	)
	_, err = connection.getFileWriter("/path")
	assert.Error(t, err)
}

// TestThrottledHandler checks the trivial, unsupported and abort-related methods
// of throttledReader.
func TestThrottledHandler(t *testing.T) {
	tr := &throttledReader{
		r: io.NopCloser(bytes.NewBuffer(nil)),
	}
	assert.Equal(t, int64(0), tr.GetTruncatedSize())
	err := tr.Close()
	assert.NoError(t, err)
	assert.Empty(t, tr.GetRealFsPath("real path"))
	assert.False(t, tr.SetTimes("p", time.Now(), time.Now()))
	_, err = tr.Truncate("", 0)
	assert.ErrorIs(t, err, vfs.ErrVfsUnsupported)
	err = tr.GetAbortError()
	assert.ErrorIs(t, err, common.ErrTransferAborted)
}

// TestHTTPDFile checks httpdFile error handling when the underlying file is already
// closed, including double Close, closeIO on the writer, and Write after SignalClose.
func TestHTTPDFile(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "test_httpd_user",
			HomeDir:  filepath.Clean(os.TempDir()),
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	connection := newConnection(
		common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user),
		nil,
		nil,
	)

	fs, err := user.GetFilesystem("")
	assert.NoError(t, err)

	name := "fileName"
	p := filepath.Join(os.TempDir(), name)
	err = os.WriteFile(p, []byte("contents"), os.ModePerm)
	assert.NoError(t, err)
	file, err := os.Open(p)
	assert.NoError(t, err)
	err = file.Close()
	assert.NoError(t, err)

	baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload,
		0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
	httpdFile := newHTTPDFile(baseTransfer, nil, nil)
	// the file is closed, read should fail
	buf := make([]byte, 100)
	_, err = httpdFile.Read(buf)
	assert.Error(t, err)
	err = httpdFile.Close()
	assert.Error(t, err)
	err = httpdFile.Close()
	assert.ErrorIs(t, err, common.ErrTransferClosed)
	err = os.Remove(p)
	assert.NoError(t, err)

	httpdFile.writer = file
	httpdFile.File = nil
	httpdFile.ErrTransfer = nil
	err = httpdFile.closeIO()
	assert.Error(t, err)
	assert.Error(t, httpdFile.ErrTransfer)
	assert.Equal(t, err, httpdFile.ErrTransfer)
	httpdFile.SignalClose(nil)
	_, err = httpdFile.Write(nil)
	assert.ErrorIs(t, err, common.ErrQuotaExceeded)
}

// TestChangeUserPwd checks doChangeUserPassword input validation and the error
// returned when the request carries no usable token claims.
func TestChangeUserPwd(t *testing.T) {
	req, _ := http.NewRequest(http.MethodPost, webChangeClientPwdPath, nil)

	err := doChangeUserPassword(req, "", "", "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "please provide the current password and the new one two times")
	}
	err = doChangeUserPassword(req, "a", "b", "c")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "the two password fields do not match")
	}
	err = doChangeUserPassword(req, "a", "b", "b")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), errInvalidTokenClaims.Error())
	}
}

// TestWebUserInvalidClaims calls web client handlers directly with a jwt cookie
// signed from claims with an empty username and the API audience (no token is put
// in the request context) and expects 403 from each of them.
func TestWebUserInvalidClaims(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "",
			Password: "pwd",
		},
	}
	c := &jwt.Claims{
		Username:    user.Username,
		Permissions: nil,
	}
	c.Subject = user.GetSignature()
	c.SetExpiry(time.Now().Add(10 * time.Minute))
	c.Audience = []string{tokenAudienceAPI}
	token, err := server.tokenAuth.Sign(c)
	assert.NoError(t, err)

	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientGetFiles(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientGetDirContents(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorDirList403)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleWebClientDownloadZip(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientEditFilePath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientEditFile(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientSharePath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientAddShareGet(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientSharePath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientUpdateShareGet(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, webClientSharePath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientAddSharePost(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodPost, webClientSharePath+"/id", nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientUpdateSharePost(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	getAllShares(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)

	rr = httptest.NewRecorder()
	req, _ = http.NewRequest(http.MethodGet, webClientViewPDFPath, nil)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleClientGetPDF(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)
}

// TestInvalidClaims posts to the web client and web admin profile handlers with
// tokens whose username is empty and expects 403.
func TestInvalidClaims(t *testing.T) {
	server := httpdServer{}
	err := server.initializeRouter()
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "",
			Password: "pwd",
		},
	}
	c := &jwt.Claims{
		Username:    user.Username,
		Permissions: nil,
	}
	c.Subject = user.GetSignature()
	token, err := server.tokenAuth.SignWithParams(c, tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient))
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	assert.NoError(t, err)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	parsedToken, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie)
	assert.NoError(t, err)
	ctx := req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)

	form := make(url.Values)
	form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath))
	form.Set("public_keys", "")
	req, err = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req = req.WithContext(ctx)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	server.handleWebClientProfilePost(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)

	admin := dataprovider.Admin{
		Username: "",
		Password: user.Password,
	}
	c = &jwt.Claims{
		Username:    admin.Username,
		Permissions: nil,
	}
	c.Subject = admin.GetSignature()
	token, err = server.tokenAuth.SignWithParams(c, tokenAudienceWebAdmin, "", getTokenDuration(tokenAudienceWebAdmin))
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
	assert.NoError(t, err)
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	parsedToken, err = jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie)
	assert.NoError(t, err)
	ctx = req.Context()
	ctx = jwt.NewContext(ctx, parsedToken, err)
	req = req.WithContext(ctx)

	form = make(url.Values)
	form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath))
	form.Set("allow_api_key_auth", "")
	req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req = req.WithContext(ctx)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
	// NOTE(review): rr is the recorder already used for the client profile request.
	// httptest.ResponseRecorder keeps the first status code written and appends to
	// the body, so the two assertions below may be satisfied by the earlier
	// response; consider a fresh recorder here once the admin handler's status is confirmed.
	server.handleWebAdminProfilePost(rr, req)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)
}

// TestTLSReq checks isTLS for a request with a TLS connection state and for the
// forwarded proto stored in the request context.
func TestTLSReq(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil)
	assert.NoError(t, err)
	req.TLS = &tls.ConnectionState{}
	assert.True(t, isTLS(req))
	req.TLS = nil
	ctx := context.WithValue(req.Context(), forwardedProtoKey, "https")
	assert.True(t, isTLS(req.WithContext(ctx)))
	ctx = context.WithValue(req.Context(), forwardedProtoKey, "http")
	assert.False(t, isTLS(req.WithContext(ctx)))
	assert.Equal(t, "context value forwarded proto", forwardedProtoKey.String())
}

// TestSigningKey checks that a token signed by one server is accepted by another
// server configured with the same signing passphrase.
func TestSigningKey(t *testing.T) {
	signingPassphrase := "test"
	server1 := httpdServer{
		signingPassphrase: signingPassphrase,
	}
	err := server1.initializeRouter()
	require.NoError(t, err)

	server2 := httpdServer{
		signingPassphrase: signingPassphrase,
	}
	err = server2.initializeRouter()
	require.NoError(t, err)

	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "",
			Password: "pwd",
		},
	}
	c := &jwt.Claims{
		Username:    user.Username,
		Permissions: nil,
	}
	c.Subject = user.GetSignature()
	token, err := server1.tokenAuth.SignWithParams(c, tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient))
	assert.NoError(t, err)
	assert.NotEmpty(t, token)
	_, err = jwt.VerifyToken(server1.tokenAuth, token)
	assert.NoError(t, err)
	_, err = jwt.VerifyToken(server2.tokenAuth, token)
	assert.NoError(t, err)
}

// TestLoginLinks checks showAdminLoginURL/showClientLoginURL for the enabled UIs
// and for HideLoginURL values 1, 2 and 3.
func TestLoginLinks(t *testing.T) {
	b := Binding{
		EnableWebAdmin:  true,
		EnableWebClient: false,
		EnableRESTAPI:   true,
	}
	assert.False(t, b.showClientLoginURL())

	b = Binding{
		EnableWebAdmin:  false,
		EnableWebClient: true,
		EnableRESTAPI:   true,
	}
	assert.False(t, b.showAdminLoginURL())

	b = Binding{
		EnableWebAdmin:  true,
		EnableWebClient: true,
		EnableRESTAPI:   true,
	}
	assert.True(t, b.showAdminLoginURL())
	assert.True(t, b.showClientLoginURL())

	b.HideLoginURL = 3
	assert.False(t, b.showAdminLoginURL())
	assert.False(t, b.showClientLoginURL())
	b.HideLoginURL = 1
	assert.True(t, b.showAdminLoginURL())
	assert.False(t, b.showClientLoginURL())
	b.HideLoginURL = 2
	assert.False(t, b.showAdminLoginURL())
	assert.True(t, b.showClientLoginURL())
}

// TestResetCodesCleanup checks that Cleanup removes an already expired reset code.
func TestResetCodesCleanup(t *testing.T) {
	resetCode := newResetCode(util.GenerateUniqueID(), false)
	resetCode.ExpiresAt = time.Now().Add(-1 * time.Minute).UTC()
	err := resetCodesMgr.Add(resetCode)
	assert.NoError(t, err)
	resetCodesMgr.Cleanup()
	_, err = resetCodesMgr.Get(resetCode.Code)
	assert.Error(t, err)
}

// TestUserCanResetPassword checks each user filter that disables password reset:
// denied HTTP protocol, the web client option, denied password login and an IP
// allow list that excludes the request's remote address.
func TestUserCanResetPassword(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil)
	assert.NoError(t, err)
	req.RemoteAddr = "172.16.9.2:55080"

	u := dataprovider.User{}
	assert.True(t, isUserAllowedToResetPassword(req, &u))
	u.Filters.DeniedProtocols = []string{common.ProtocolHTTP}
	assert.False(t, isUserAllowedToResetPassword(req, &u))
	u.Filters.DeniedProtocols = nil
	u.Filters.WebClient = []string{sdk.WebClientPasswordResetDisabled}
	assert.False(t, isUserAllowedToResetPassword(req, &u))
	u.Filters.WebClient = nil
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword}
	assert.False(t, isUserAllowedToResetPassword(req, &u))
	u.Filters.DeniedLoginMethods = nil
	u.Filters.AllowedIP = []string{"127.0.0.1/8"}
	assert.False(t, isUserAllowedToResetPassword(req, &u))
}
func TestBrowsableSharePaths(t *testing.T) { share := dataprovider.Share{ Paths: []string{"/"}, Username: defaultAdminUsername, } _, err := getUserForShare(share) if assert.Error(t, err) { assert.ErrorIs(t, err, util.ErrNotFound) } req, err := http.NewRequest(http.MethodGet, "/share", nil) require.NoError(t, err) name, err := getBrowsableSharedPath(share.Paths[0], req) assert.NoError(t, err) assert.Equal(t, "/", name) req, err = http.NewRequest(http.MethodGet, "/share?path=abc", nil) require.NoError(t, err) name, err = getBrowsableSharedPath(share.Paths[0], req) assert.NoError(t, err) assert.Equal(t, "/abc", name) share.Paths = []string{"/a/b/c"} req, err = http.NewRequest(http.MethodGet, "/share?path=abc", nil) require.NoError(t, err) name, err = getBrowsableSharedPath(share.Paths[0], req) assert.NoError(t, err) assert.Equal(t, "/a/b/c/abc", name) req, err = http.NewRequest(http.MethodGet, "/share?path=%2Fabc/d", nil) require.NoError(t, err) name, err = getBrowsableSharedPath(share.Paths[0], req) assert.NoError(t, err) assert.Equal(t, "/a/b/c/abc/d", name) req, err = http.NewRequest(http.MethodGet, "/share?path=%2Fabc%2F..%2F..", nil) require.NoError(t, err) _, err = getBrowsableSharedPath(share.Paths[0], req) assert.Error(t, err) req, err = http.NewRequest(http.MethodGet, "/share?path=%2Fabc%2F..", nil) require.NoError(t, err) name, err = getBrowsableSharedPath(share.Paths[0], req) assert.NoError(t, err) assert.Equal(t, "/a/b/c", name) share = dataprovider.Share{ Paths: []string{"/a", "/b"}, } } func TestSecureMiddlewareIntegration(t *testing.T) { forwardedHostHeader := "X-Forwarded-Host" server := httpdServer{ binding: Binding{ ProxyAllowed: []string{"192.168.1.0/24"}, Security: SecurityConf{ Enabled: true, AllowedHosts: []string{"*.sftpgo.com"}, AllowedHostsAreRegex: true, HostsProxyHeaders: []string{forwardedHostHeader}, HTTPSProxyHeaders: []HTTPSProxyHeader{ { Key: xForwardedProto, Value: "https", }, }, STSSeconds: 31536000, STSIncludeSubdomains: true, 
STSPreload: true, ContentTypeNosniff: true, CacheControl: "private", CrossOriginOpenerPolicy: "same-origin", CrossOriginResourcePolicy: "same-site", CrossOriginEmbedderPolicy: "require-corp", ReferrerPolicy: "no-referrer", }, }, enableWebAdmin: true, enableWebClient: true, enableRESTAPI: true, } server.binding.Security.updateProxyHeaders() err := server.binding.parseAllowedProxy() assert.NoError(t, err) assert.Equal(t, []string{forwardedHostHeader, xForwardedProto}, server.binding.Security.proxyHeaders) assert.Equal(t, map[string]string{xForwardedProto: "https"}, server.binding.Security.getHTTPSProxyHeaders()) err = server.initializeRouter() require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil) assert.NoError(t, err) r.Host = "127.0.0.1" server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Equal(t, "no-cache, no-store, max-age=0, must-revalidate, private", rr.Header().Get("Cache-Control")) rr = httptest.NewRecorder() r.Header.Set(forwardedHostHeader, "www.sftpgo.com") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusForbidden, rr.Code) // the header should be removed assert.Empty(t, r.Header.Get(forwardedHostHeader)) rr = httptest.NewRecorder() r.Host = "test.sftpgo.com" r.Header.Set(forwardedHostHeader, "test.example.com") r.RemoteAddr = "192.168.1.1" server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusForbidden, rr.Code) assert.NotEmpty(t, r.Header.Get(forwardedHostHeader)) rr = httptest.NewRecorder() r.Header.Set(forwardedHostHeader, "www.sftpgo.com") r.RemoteAddr = "192.168.1.1" server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) assert.NotEmpty(t, r.Header.Get(forwardedHostHeader)) assert.Empty(t, rr.Header().Get("Strict-Transport-Security")) assert.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) // now set the X-Forwarded-Proto to https, we should get the Strict-Transport-Security header rr = httptest.NewRecorder() 
r.Host = "test.sftpgo.com" r.Header.Set(xForwardedProto, "https") r.RemoteAddr = "192.168.1.3" server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) assert.NotEmpty(t, r.Header.Get(forwardedHostHeader)) assert.Equal(t, "max-age=31536000; includeSubDomains; preload", rr.Header().Get("Strict-Transport-Security")) assert.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) assert.Equal(t, "require-corp", rr.Header().Get("Cross-Origin-Embedder-Policy")) assert.Equal(t, "same-origin", rr.Header().Get("Cross-Origin-Opener-Policy")) assert.Equal(t, "same-site", rr.Header().Get("Cross-Origin-Resource-Policy")) assert.Equal(t, "no-referrer", rr.Header().Get("Referrer-Policy")) server.binding.Security.Enabled = false server.binding.Security.updateProxyHeaders() assert.Len(t, server.binding.Security.proxyHeaders, 0) } func TestGetCompressedFileName(t *testing.T) { username := "test" res := getCompressedFileName(username, []string{"single dir"}) require.Equal(t, fmt.Sprintf("%s-single dir.zip", username), res) res = getCompressedFileName(username, []string{"file1", "file2"}) require.Equal(t, fmt.Sprintf("%s-download.zip", username), res) res = getCompressedFileName(username, []string{"file1.txt"}) require.Equal(t, fmt.Sprintf("%s-file1.zip", username), res) // now files with full paths res = getCompressedFileName(username, []string{"/dir/single dir"}) require.Equal(t, fmt.Sprintf("%s-single dir.zip", username), res) res = getCompressedFileName(username, []string{"/adir/file1", "/adir/file2"}) require.Equal(t, fmt.Sprintf("%s-download.zip", username), res) res = getCompressedFileName(username, []string{"/sub/dir/file1.txt"}) require.Equal(t, fmt.Sprintf("%s-file1.zip", username), res) } func TestRESTAPIDisabled(t *testing.T) { server := httpdServer{ enableWebAdmin: true, enableWebClient: true, enableRESTAPI: false, } err := server.initializeRouter() require.NoError(t, err) assert.False(t, server.enableRESTAPI) rr := httptest.NewRecorder() r, err := 
http.NewRequest(http.MethodGet, healthzPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, tokenPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusNotFound, rr.Code) } func TestWebAdminSetupWithInstallCode(t *testing.T) { installationCode = "1234" // delete all the admins admins, err := dataprovider.GetAdmins(100, 0, dataprovider.OrderASC) assert.NoError(t, err) for _, admin := range admins { err = dataprovider.DeleteAdmin(admin.Username, "", "", "") assert.NoError(t, err) } // close the provider and initializes it without creating the default admin providerConf := dataprovider.GetProviderConfig() providerConf.CreateDefaultAdmin = false err = dataprovider.Close() assert.NoError(t, err) err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) server := httpdServer{ enableWebAdmin: true, enableWebClient: true, enableRESTAPI: true, } err = server.initializeRouter() require.NoError(t, err) for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webURL, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) } rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) cookie := rr.Header().Get("Set-Cookie") r.Header.Set("Cookie", cookie) parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, r, jwt.TokenFromCookie) assert.NoError(t, err) ctx := r.Context() ctx = jwt.NewContext(ctx, parsedToken, err) r = r.WithContext(ctx) form := make(url.Values) csrfToken := createCSRFToken(rr, r, server.csrfTokenAuth, "", 
webBaseAdminPath) form.Set(csrfFormToken, csrfToken) form.Set("install_code", installationCode+"5") form.Set("username", defaultAdminUsername) form.Set("password", "password") form.Set("confirm_password", "password") rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) r = r.WithContext(ctx) r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorSetupInstallCode) _, err = dataprovider.AdminExists(defaultAdminUsername) assert.Error(t, err) form.Set("install_code", installationCode) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) r = r.WithContext(ctx) r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminMFAPath, rr.Header().Get("Location")) _, err = dataprovider.AdminExists(defaultAdminUsername) assert.NoError(t, err) // delete the admin and test the installation code resolver err = dataprovider.DeleteAdmin(defaultAdminUsername, "", "", "") assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) SetInstallationCodeResolver(func(_ string) string { return "5678" }) for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webURL, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) } rr = httptest.NewRecorder() r, err = 
http.NewRequest(http.MethodGet, webAdminSetupPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) cookie = rr.Header().Get("Set-Cookie") r.Header.Set("Cookie", cookie) parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, r, jwt.TokenFromCookie) assert.NoError(t, err) ctx = r.Context() ctx = jwt.NewContext(ctx, parsedToken, err) r = r.WithContext(ctx) form = make(url.Values) csrfToken = createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath) form.Set(csrfFormToken, csrfToken) form.Set("install_code", installationCode) form.Set("username", defaultAdminUsername) form.Set("password", "password") form.Set("confirm_password", "password") rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) r = r.WithContext(ctx) r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorSetupInstallCode) _, err = dataprovider.AdminExists(defaultAdminUsername) assert.Error(t, err) form.Set("install_code", "5678") rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) r = r.WithContext(ctx) r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminMFAPath, rr.Header().Get("Location")) _, err = dataprovider.AdminExists(defaultAdminUsername) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) providerConf.CreateDefaultAdmin = true err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) installationCode = "" SetInstallationCodeResolver(nil) } func TestDbResetCodeManager(t *testing.T) { if 
!isSharedProviderSupported() { t.Skip("this test it is not available with this provider") } mgr := newResetCodeManager(1) resetCode := newResetCode("admin", true) err := mgr.Add(resetCode) assert.NoError(t, err) codeGet, err := mgr.Get(resetCode.Code) assert.NoError(t, err) assert.Equal(t, resetCode, codeGet) err = mgr.Delete(resetCode.Code) assert.NoError(t, err) err = mgr.Delete(resetCode.Code) if assert.Error(t, err) { assert.ErrorIs(t, err, util.ErrNotFound) } _, err = mgr.Get(resetCode.Code) assert.ErrorIs(t, err, util.ErrNotFound) // add an expired reset code resetCode = newResetCode("user", false) resetCode.ExpiresAt = time.Now().Add(-24 * time.Hour) err = mgr.Add(resetCode) assert.NoError(t, err) _, err = mgr.Get(resetCode.Code) if assert.Error(t, err) { assert.Contains(t, err.Error(), "reset code expired") } mgr.Cleanup() _, err = mgr.Get(resetCode.Code) assert.ErrorIs(t, err, util.ErrNotFound) dbMgr, ok := mgr.(*dbResetCodeManager) if assert.True(t, ok) { _, err = dbMgr.decodeData("astring") assert.Error(t, err) } } func TestEventRoleFilter(t *testing.T) { defaultVal := "default" req, err := http.NewRequest(http.MethodGet, fsEventsPath+"?role=role1", nil) require.NoError(t, err) role := getRoleFilterForEventSearch(req, defaultVal) assert.Equal(t, defaultVal, role) role = getRoleFilterForEventSearch(req, "") assert.Equal(t, "role1", role) } func TestEventsCSV(t *testing.T) { e := fsEvent{ Status: 1, } data := e.getCSVData() assert.Equal(t, "OK", data[5]) e.Status = 2 data = e.getCSVData() assert.Equal(t, "KO", data[5]) e.Status = 3 data = e.getCSVData() assert.Equal(t, "Quota exceeded", data[5]) } func TestConfigsFromProvider(t *testing.T) { err := dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) c := Conf{ Bindings: []Binding{ { Port: 1234, }, { Port: 80, Security: SecurityConf{ Enabled: true, HTTPSRedirect: true, }, }, }, } err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) configs := dataprovider.Configs{ 
ACME: &dataprovider.ACMEConfigs{ Domain: "domain.com", Email: "info@domain.com", HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: 80}, Protocols: 1, }, } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) util.CertsBasePath = "" // crt and key empty err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) util.CertsBasePath = filepath.Clean(os.TempDir()) // crt not found err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs := c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) crtPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".crt") err = os.WriteFile(crtPath, nil, 0666) assert.NoError(t, err) // key not found err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) keyPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".key") err = os.WriteFile(keyPath, nil, 0666) assert.NoError(t, err) // acme cert used err = c.loadFromProvider() assert.NoError(t, err) assert.Equal(t, configs.ACME.Domain, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 1) assert.True(t, c.Bindings[0].EnableHTTPS) assert.False(t, c.Bindings[1].EnableHTTPS) // protocols does not match configs.ACME.Protocols = 6 err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) c.acmeDomain = "" err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) err = os.Remove(crtPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) util.CertsBasePath = "" err = dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) } func TestHTTPSRedirect(t *testing.T) { acmeWebRoot := filepath.Join(os.TempDir(), "acme") err := os.MkdirAll(acmeWebRoot, os.ModePerm) assert.NoError(t, err) tokenName := "token" err = 
os.WriteFile(filepath.Join(acmeWebRoot, tokenName), []byte("val"), 0666) assert.NoError(t, err) acmeConfig := acme.Configuration{ HTTP01Challenge: acme.HTTP01Challenge{WebRoot: acmeWebRoot}, } err = acme.Initialize(acmeConfig, configDir, true) require.NoError(t, err) forwardedHostHeader := "X-Forwarded-Host" server := httpdServer{ binding: Binding{ Security: SecurityConf{ Enabled: true, HTTPSRedirect: true, HostsProxyHeaders: []string{forwardedHostHeader}, }, }, } err = server.initializeRouter() require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, path.Join(acmeChallengeURI, tokenName), nil) assert.NoError(t, err) r.Host = "localhost" r.RequestURI = path.Join(acmeChallengeURI, tokenName) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) r.RequestURI = webAdminLoginPath server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code, rr.Body.String()) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) r.RequestURI = webAdminLoginPath r.Header.Set(forwardedHostHeader, "sftpgo.com") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code, rr.Body.String()) assert.Contains(t, rr.Body.String(), "https://sftpgo.com") server.binding.Security.HTTPSHost = "myhost:1044" rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) r.RequestURI = webAdminLoginPath server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code, rr.Body.String()) assert.Contains(t, rr.Body.String(), "https://myhost:1044") err = os.RemoveAll(acmeWebRoot) assert.NoError(t, err) } func TestDisabledAdminLoginMethods(t *testing.T) { server := httpdServer{ binding: Binding{ Address: "", Port: 8080, EnableWebAdmin: true, 
EnableWebClient: true, EnableRESTAPI: true, DisabledLoginMethods: 20, }, enableWebAdmin: true, enableWebClient: true, enableRESTAPI: true, } err := server.initializeRouter() require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, tokenPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, defaultAdminUsername, "forgot-password"), nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, defaultAdminUsername, "reset-password"), nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) } func TestDisabledUserLoginMethods(t *testing.T) { server := httpdServer{ binding: Binding{ Address: "", Port: 8080, EnableWebAdmin: true, EnableWebClient: true, EnableRESTAPI: true, DisabledLoginMethods: 40, }, enableWebAdmin: true, enableWebClient: true, enableRESTAPI: true, } err := server.initializeRouter() require.NoError(t, err) testServer := 
httptest.NewServer(server.router) defer testServer.Close() rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, userTokenPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, userPath+"/user/forgot-password", nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, userPath+"/user/reset-password", nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, webClientLoginPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, nil) require.NoError(t, err) testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) } func TestGetLogEventString(t *testing.T) { assert.Equal(t, "Login failed", getLogEventString(notifier.LogEventTypeLoginFailed)) assert.Equal(t, "Login with non-existent user", getLogEventString(notifier.LogEventTypeLoginNoUser)) assert.Equal(t, "No login tried", getLogEventString(notifier.LogEventTypeNoLoginTried)) assert.Equal(t, "Algorithm negotiation failed", getLogEventString(notifier.LogEventTypeNotNegotiated)) assert.Equal(t, "Login succeeded", getLogEventString(notifier.LogEventTypeLoginOK)) assert.Empty(t, getLogEventString(0)) } func TestUserQuotaUsage(t *testing.T) { usage := userQuotaUsage{ QuotaSize: 100, } 
require.True(t, usage.HasQuotaInfo()) require.NotEmpty(t, usage.GetQuotaSize()) providerConf := dataprovider.GetProviderConfig() quotaTracking := dataprovider.GetQuotaTracking() providerConf.TrackQuota = 0 err := dataprovider.Close() assert.NoError(t, err) err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) assert.False(t, usage.HasQuotaInfo()) providerConf.TrackQuota = quotaTracking err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) usage.QuotaSize = 0 assert.False(t, usage.HasQuotaInfo()) assert.Empty(t, usage.GetQuotaSize()) assert.Equal(t, 0, usage.GetQuotaSizePercentage()) assert.False(t, usage.IsQuotaSizeLow()) assert.False(t, usage.IsDiskQuotaLow()) assert.False(t, usage.IsQuotaLow()) usage.UsedQuotaSize = 9 assert.NotEmpty(t, usage.GetQuotaSize()) usage.QuotaSize = 10 assert.True(t, usage.IsQuotaSizeLow()) assert.True(t, usage.IsDiskQuotaLow()) assert.True(t, usage.IsQuotaLow()) usage.DownloadDataTransfer = 1 assert.True(t, usage.HasQuotaInfo()) assert.True(t, usage.HasTranferQuota()) assert.Empty(t, usage.GetQuotaFiles()) assert.Equal(t, 0, usage.GetQuotaFilesPercentage()) usage.QuotaFiles = 1 assert.NotEmpty(t, usage.GetQuotaFiles()) usage.QuotaFiles = 0 usage.UsedQuotaFiles = 9 assert.NotEmpty(t, usage.GetQuotaFiles()) usage.QuotaFiles = 10 usage.DownloadDataTransfer = 0 assert.True(t, usage.IsQuotaFilesLow()) assert.True(t, usage.IsDiskQuotaLow()) assert.False(t, usage.IsTotalTransferQuotaLow()) assert.False(t, usage.IsUploadTransferQuotaLow()) assert.False(t, usage.IsDownloadTransferQuotaLow()) assert.Equal(t, 0, usage.GetTotalTransferQuotaPercentage()) assert.Equal(t, 0, usage.GetUploadTransferQuotaPercentage()) assert.Equal(t, 0, usage.GetDownloadTransferQuotaPercentage()) assert.Empty(t, usage.GetTotalTransferQuota()) assert.Empty(t, usage.GetUploadTransferQuota()) assert.Empty(t, usage.GetDownloadTransferQuota()) 
usage.TotalDataTransfer = 3 usage.UsedUploadDataTransfer = 1 * 1048576 assert.NotEmpty(t, usage.GetTotalTransferQuota()) usage.TotalDataTransfer = 0 assert.NotEmpty(t, usage.GetTotalTransferQuota()) assert.NotEmpty(t, usage.GetUploadTransferQuota()) usage.UploadDataTransfer = 2 assert.NotEmpty(t, usage.GetUploadTransferQuota()) usage.UsedDownloadDataTransfer = 1 * 1048576 assert.NotEmpty(t, usage.GetDownloadTransferQuota()) usage.DownloadDataTransfer = 2 assert.NotEmpty(t, usage.GetDownloadTransferQuota()) assert.False(t, usage.IsTransferQuotaLow()) usage.UsedDownloadDataTransfer = 8 * 1048576 usage.TotalDataTransfer = 10 assert.True(t, usage.IsTotalTransferQuotaLow()) assert.True(t, usage.IsTransferQuotaLow()) usage.TotalDataTransfer = 0 usage.UploadDataTransfer = 0 usage.DownloadDataTransfer = 0 assert.False(t, usage.IsTransferQuotaLow()) usage.UploadDataTransfer = 10 usage.UsedUploadDataTransfer = 9 * 1048576 assert.True(t, usage.IsUploadTransferQuotaLow()) assert.True(t, usage.IsTransferQuotaLow()) usage.DownloadDataTransfer = 10 usage.UsedDownloadDataTransfer = 9 * 1048576 assert.True(t, usage.IsDownloadTransferQuotaLow()) assert.True(t, usage.IsTransferQuotaLow()) } func TestShareRedirectURL(t *testing.T) { shareID := util.GenerateUniqueID() base := path.Join(webClientPubSharesPath, shareID) next := path.Join(webClientPubSharesPath, shareID, "browse") ok, res := checkShareRedirectURL(next, base) assert.True(t, ok) assert.Equal(t, next, res) next = path.Join(webClientPubSharesPath, shareID, "browse") + "?a=b" ok, res = checkShareRedirectURL(next, base) assert.True(t, ok) assert.Equal(t, next, res) next = path.Join(webClientPubSharesPath, shareID) ok, res = checkShareRedirectURL(next, base) assert.True(t, ok) assert.Equal(t, path.Join(base, "download"), res) next = path.Join(webClientEditFilePath, shareID) ok, res = checkShareRedirectURL(next, base) assert.False(t, ok) assert.Empty(t, res) next = path.Join(webClientPubSharesPath, shareID) + 
"?compress=false&a=b" ok, res = checkShareRedirectURL(next, base) assert.True(t, ok) assert.Equal(t, path.Join(base, "download?compress=false&a=b"), res) next = path.Join(webClientPubSharesPath, shareID) + "?compress=true&b=c" ok, res = checkShareRedirectURL(next, base) assert.True(t, ok) assert.Equal(t, path.Join(base, "download?compress=true&b=c"), res) ok, res = checkShareRedirectURL("http://foo\x7f.com/ab", "http://foo\x7f.com/") assert.False(t, ok) assert.Empty(t, res) ok, res = checkShareRedirectURL("http://foo.com/?foo\nbar", "http://foo.com") assert.False(t, ok) assert.Empty(t, res) } func TestI18NMessages(t *testing.T) { msg := i18nListDirMsg(http.StatusForbidden) require.Equal(t, util.I18nErrorDirList403, msg) msg = i18nListDirMsg(http.StatusInternalServerError) require.Equal(t, util.I18nErrorDirListGeneric, msg) msg = i18nFsMsg(http.StatusForbidden) require.Equal(t, util.I18nError403Message, msg) msg = i18nFsMsg(http.StatusInternalServerError) require.Equal(t, util.I18nErrorFsGeneric, msg) } func TestI18NErrors(t *testing.T) { err := util.NewValidationError("error text") errI18n := util.NewI18nError(err, util.I18nError500Message) assert.ErrorIs(t, errI18n, util.ErrValidation) assert.Equal(t, err.Error(), errI18n.Error()) assert.Equal(t, util.I18nError500Message, getI18NErrorString(errI18n, "")) assert.Equal(t, util.I18nError500Message, errI18n.Message) assert.Equal(t, "{}", errI18n.Args()) var e1 *util.ValidationError assert.ErrorAs(t, errI18n, &e1) var e2 *util.I18nError assert.ErrorAs(t, errI18n, &e2) err2 := util.NewI18nError(fs.ErrNotExist, util.I18nError500Message) assert.ErrorIs(t, err2, &util.I18nError{}) assert.ErrorIs(t, err2, fs.ErrNotExist) assert.NotErrorIs(t, err2, fs.ErrExist) assert.Equal(t, util.I18nError403Message, getI18NErrorString(fs.ErrClosed, util.I18nError403Message)) errorString := getI18NErrorString(nil, util.I18nError500Message) assert.Equal(t, util.I18nError500Message, errorString) errI18nWrap := util.NewI18nError(errI18n, 
util.I18nError404Message) assert.Equal(t, util.I18nError500Message, errI18nWrap.Message) errI18n = util.NewI18nError(err, util.I18nError500Message, util.I18nErrorArgs(map[string]any{"a": "b"})) assert.Equal(t, util.I18nError500Message, errI18n.Message) assert.Equal(t, `{"a":"b"}`, errI18n.Args()) }

// TestConvertEnabledLoginMethods checks how convertLoginMethods derives
// DisabledLoginMethods from EnabledLoginMethods.
func TestConvertEnabledLoginMethods(t *testing.T) {
	b := Binding{
		EnabledLoginMethods:  0,
		DisabledLoginMethods: 1,
	}
	b.convertLoginMethods()
	// with no enabled mask set, the disabled mask is left as configured
	assert.Equal(t, 1, b.DisabledLoginMethods)

	// the same Binding is reused for every case, as the expected disabled
	// masks below are the complement of the enabled ones within 0-15
	expectations := []struct {
		enabled  int
		disabled int
	}{
		{enabled: 1, disabled: 14},
		{enabled: 2, disabled: 13},
		{enabled: 3, disabled: 12},
		{enabled: 4, disabled: 11},
		{enabled: 7, disabled: 8},
		{enabled: 15, disabled: 0},
	}
	for _, exp := range expectations {
		b.DisabledLoginMethods = 0
		b.EnabledLoginMethods = exp.enabled
		b.convertLoginMethods()
		assert.Equal(t, exp.disabled, b.DisabledLoginMethods)
	}
}

// TestValidateBaseURL checks that validateBaseURL accepts http/https URLs with
// a host, strips trailing slashes and rejects everything else.
func TestValidateBaseURL(t *testing.T) {
	type baseURLCase struct {
		name        string
		inputURL    string
		expectedURL string
		expectErr   bool
	}
	cases := []baseURLCase{
		{name: "Valid HTTPS URL", inputURL: "https://sftp.example.com", expectedURL: "https://sftp.example.com"},
		{name: "Remove trailing slash", inputURL: "https://sftp.example.com/", expectedURL: "https://sftp.example.com"},
		{name: "Remove multiple trailing slashes", inputURL: "http://192.168.1.100:8080///", expectedURL: "http://192.168.1.100:8080"},
		{name: "Empty BaseURL (optional case)", inputURL: "", expectedURL: ""},
		{name: "Unsupported scheme (FTP)", inputURL: "ftp://files.example.com", expectErr: true},
		{name: "Malformed URL string", inputURL: "not-a-url", expectErr: true},
		{name: "Missing Host", inputURL: "https://", expectErr: true},
		{name: "Preserve path without trailing slash", inputURL: "https://example.com/sftp/", expectedURL: "https://example.com/sftp"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			binding := &Binding{BaseURL: tc.inputURL}
			err := binding.validateBaseURL()
			gotErr := err != nil
			if gotErr != tc.expectErr {
				t.Errorf("validateBaseURL() error = %v, expectErr %v", err, tc.expectErr)
				return
			}
			if tc.expectErr {
				return
			}
			if binding.BaseURL != tc.expectedURL {
				t.Errorf("validateBaseURL() got = %v, want %v", binding.BaseURL, tc.expectedURL)
			}
		})
	}
}

// getCSRFTokenFromBody parses an HTML document and returns the value of the
// <input> element whose name is csrfFormToken. The tree is walked depth-first in
// document order, so if several such inputs exist the last one wins. An error is
// returned if the body cannot be parsed or no non-empty token is found.
func getCSRFTokenFromBody(body io.Reader) (string, error) {
	doc, err := html.Parse(body)
	if err != nil {
		return "", err
	}
	token := ""
	pending := []*html.Node{doc}
	for len(pending) > 0 {
		node := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		if node.Type == html.ElementNode && node.Data == "input" {
			var inputName, inputValue string
			for _, attr := range node.Attr {
				switch attr.Key {
				case "value":
					inputValue = attr.Val
				case "name":
					inputName = attr.Val
				}
			}
			if inputName == csrfFormToken {
				token = inputValue
				continue
			}
		}
		// push children in reverse order so the first child is popped first
		for child := node.LastChild; child != nil; child = child.PrevSibling {
			pending = append(pending, child)
		}
	}
	if token == "" {
		return "", errors.New("CSRF token not found")
	}
	return token, nil
}

// isSharedProviderSupported reports whether the configured data provider can be
// shared between instances. SQLite uses the same SQL implementation as the other
// SQL providers, so it is accepted here, but sharing it only makes sense in tests.
func isSharedProviderSupported() bool {
	driver := dataprovider.GetProviderStatus().Driver
	return driver == dataprovider.MySQLDataProviderName ||
		driver == dataprovider.PGSQLDataProviderName ||
		driver == dataprovider.CockroachDataProviderName ||
		driver == dataprovider.SQLiteDataProviderName
}

================================================ FILE: internal/httpd/middleware.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3.
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "errors" "fmt" "io/fs" "net/http" "net/url" "slices" "strings" "time" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( forwardedProtoKey = &contextKey{"forwarded proto"} errInvalidToken = errors.New("invalid JWT token") ) type contextKey struct { name string } func (k *contextKey) String() string { return "context value " + k.name } func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { token, err := jwt.FromContext(r.Context()) var redirectPath string if audience == tokenAudienceWebAdmin { redirectPath = webAdminLoginPath } else { redirectPath = webClientLoginPath if uri := r.RequestURI; strings.HasPrefix(uri, webClientFilesPath) { redirectPath += "?next=" + url.QueryEscape(uri) //nolint:goconst } } isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser) doRedirect := func(message string, err error) { if isAPIToken { sendAPIResponse(w, r, err, message, http.StatusUnauthorized) } else { http.Redirect(w, r, redirectPath, http.StatusFound) } } if err != nil { logger.Debug(logSender, "", "error getting jwt token: %v", err) doRedirect(http.StatusText(http.StatusUnauthorized), err) return errInvalidToken } if isTokenInvalidated(r) { logger.Debug(logSender, "", "the token has been invalidated") doRedirect("Your token is no longer valid", nil) return errInvalidToken } // a user with 
a partial token will be always redirected to the appropriate two factor auth page if err := checkPartialAuth(w, r, audience, token.Audience); err != nil { return err } if !token.Audience.Contains(audience) { logger.Debug(logSender, "", "the token is not valid for audience %q", audience) doRedirect("Your token audience is not valid", nil) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil { logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.ID, ipAddr) doRedirect("Your token is not valid", nil) return err } if err := checkTokenSignature(r, token); err != nil { doRedirect("Your token is no longer valid", nil) return err } return nil } func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { token, err := jwt.FromContext(r.Context()) var notFoundFunc func(w http.ResponseWriter, r *http.Request, err error) if audience == tokenAudienceWebAdminPartial { notFoundFunc = s.renderNotFoundPage } else { notFoundFunc = s.renderClientNotFoundPage } if err != nil { notFoundFunc(w, r, nil) return errInvalidToken } if isTokenInvalidated(r) { notFoundFunc(w, r, nil) return errInvalidToken } if !token.Audience.Contains(audience) { logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.ID, audience) notFoundFunc(w, r, nil) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil { logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.ID, ipAddr) notFoundFunc(w, r, nil) return err } return nil } func (s *httpdServer) jwtAuthenticatorPartial(audience tokenAudience) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := 
s.validateJWTPartialToken(w, r, audience); err != nil { return } // Token is authenticated, pass it through next.ServeHTTP(w, r) }) } }

// jwtAuthenticatorAPI forwards the request only when validateJWTToken accepts
// its token for tokenAudienceAPI. On failure, validateJWTToken handles the
// rejection response.
func jwtAuthenticatorAPI(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if validateJWTToken(w, r, tokenAudienceAPI) == nil {
			next.ServeHTTP(w, r)
		}
	})
}

// jwtAuthenticatorAPIUser forwards the request only when validateJWTToken
// accepts its token for tokenAudienceAPIUser.
func jwtAuthenticatorAPIUser(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if validateJWTToken(w, r, tokenAudienceAPIUser) == nil {
			next.ServeHTTP(w, r)
		}
	})
}

// jwtAuthenticatorWebAdmin forwards the request only when validateJWTToken
// accepts its token for tokenAudienceWebAdmin.
func jwtAuthenticatorWebAdmin(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if validateJWTToken(w, r, tokenAudienceWebAdmin) == nil {
			next.ServeHTTP(w, r)
		}
	})
}

// jwtAuthenticatorWebClient forwards the request only when validateJWTToken
// accepts its token for tokenAudienceWebClient.
func jwtAuthenticatorWebClient(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if validateJWTToken(w, r, tokenAudienceWebClient) == nil {
			next.ServeHTTP(w, r)
		}
	})
}

// checkHTTPUserPerm blocks web client/user API requests whose claims hold perm.
// User permissions in the claims are negated, so holding perm means the action
// is denied.
func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			claims, err := jwt.FromContext(r.Context())
			if err != nil {
				if !isWebRequest(r) {
					sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
					return
				}
				s.renderClientBadRequestPage(w, r, err)
				return
			}
			if !claims.HasPerm(perm) {
				next.ServeHTTP(w, r)
				return
			}
			if isWebRequest(r) {
				s.renderClientForbiddenPage(w, r, errors.New("you don't have permission for this action"))
			} else {
				sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			}
		})
	}
}

// checkAuthRequirements rejects with a forbidden response any request whose
// claims say the user must still set up two-factor authentication or change
// the password. Missing claims produce a bad request response. All other
// requests reach the wrapped handler.
func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		claims, err := jwt.FromContext(r.Context())
		if err != nil {
			switch {
			case !isWebRequest(r):
				sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			case isWebClientRequest(r):
				s.renderClientBadRequestPage(w, r, err)
			default:
				s.renderBadRequestPage(w, r, err)
			}
			return
		}

		// the two-factor requirement takes precedence over the password change
		var requirementErr error
		switch {
		case claims.MustSetTwoFactorAuth && len(claims.RequiredTwoFactorProtocols) > 0:
			protocols := strings.Join(claims.RequiredTwoFactorProtocols, ", ")
			requirementErr = util.NewI18nError(
				util.NewGenericError(
					fmt.Sprintf("Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols: %v", protocols)),
				util.I18nError2FARequired,
				util.I18nErrorArgs(map[string]any{
					"val": protocols,
				}),
			)
		case claims.MustSetTwoFactorAuth:
			requirementErr = util.NewI18nError(
				util.NewGenericError("Two-factor authentication requirements not met, please configure two-factor authentication"),
				util.I18nError2FARequiredGeneric,
			)
		case claims.MustChangePassword:
			requirementErr = util.NewI18nError(
				util.NewGenericError("Password change required. Please set a new password to continue to use your account"),
				util.I18nErrorChangePwdRequired,
			)
		default:
			next.ServeHTTP(w, r)
			return
		}

		switch {
		case !isWebRequest(r):
			sendAPIResponse(w, r, requirementErr, "", http.StatusForbidden)
		case isWebClientRequest(r):
			s.renderClientForbiddenPage(w, r, requirementErr)
		default:
			s.renderForbiddenPage(w, r, requirementErr)
		}
	})
}

// requireBuiltinLogin makes a feature unavailable to users who logged in with
// OpenID Connect. They get a forbidden page: the web client page or the admin
// page, depending on the request.
func (s *httpdServer) requireBuiltinLogin(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !isLoggedInWithOIDC(r) {
			next.ServeHTTP(w, r)
			return
		}
		oidcErr := util.NewI18nError(
			util.NewGenericError("This feature is not available if you are logged in with OpenID"),
			util.I18nErrorNoOIDCFeature,
		)
		if isWebClientRequest(r) {
			s.renderClientForbiddenPage(w, r, oidcErr)
			return
		}
		s.renderForbiddenPage(w, r, oidcErr)
	})
}

// checkPerms lets the request through only if the claims in the context grant
// every permission in perms. Checking stops at the first missing permission.
func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			claims, err := jwt.FromContext(r.Context())
			if err != nil {
				if !isWebRequest(r) {
					sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
					return
				}
				s.renderBadRequestPage(w, r, err)
				return
			}
			missingPerm := slices.ContainsFunc(perms, func(perm string) bool {
				return !claims.HasPerm(perm)
			})
			if !missingPerm {
				next.ServeHTTP(w, r)
				return
			}
			if isWebRequest(r) {
				s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message))
			} else {
				sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			}
		})
	}
}

func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString := r.Header.Get(csrfHeaderToken) token, err := jwt.VerifyToken(s.csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF header: %v", err) sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden) return } if
!token.Audience.Contains(tokenAudienceCSRF) { logger.Debug(logSender, "", "error validating CSRF header token audience") sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return } if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { logger.Debug(logSender, "", "error validating CSRF header IP audience") sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return } if err := checkCSRFTokenRef(r, token); err != nil { sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return } next.ServeHTTP(w, r) }) } func checkNodeToken(tokenAuth *jwt.Signer) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get(dataprovider.NodeTokenHeader) if bearer == "" { next.ServeHTTP(w, r) return } const prefix = "Bearer " if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) { bearer = bearer[len(prefix):] } if invalidatedJWTTokens.Get(bearer) { logger.Debug(logSender, "", "the node token has been invalidated") sendAPIResponse(w, r, fmt.Errorf("the provided token is not valid"), "", http.StatusUnauthorized) return } claims, err := dataprovider.AuthenticateNodeToken(bearer) if err != nil { logger.Debug(logSender, "", "unable to authenticate node token %q: %v", bearer, err) sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized) return } defer invalidatedJWTTokens.Add(bearer, time.Now().Add(2*time.Minute).UTC()) c := &jwt.Claims{ Username: claims.Username, Permissions: claims.Permissions, NodeID: dataprovider.GetNodeName(), Role: claims.Role, } token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr), getTokenDuration(tokenAudienceAPI)) if err != nil { sendAPIResponse(w, r, err, 
http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } resp := c.BuildTokenResponse(token) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) next.ServeHTTP(w, r) }) } } func checkAPIKeyAuth(tokenAuth *jwt.Signer, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiKey := r.Header.Get("X-SFTPGO-API-KEY") if apiKey == "" { next.ServeHTTP(w, r) return } keyParams := strings.SplitN(apiKey, ".", 3) if len(keyParams) < 2 { logger.Debug(logSender, "", "invalid api key %q", apiKey) sendAPIResponse(w, r, errors.New("the provided api key is not valid"), "", http.StatusBadRequest) return } keyID := keyParams[0] key := keyParams[1] apiUser := "" if len(keyParams) > 2 { apiUser = keyParams[2] } k, err := dataprovider.APIKeyExists(keyID) if err != nil { handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), util.NewRecordNotFoundError("invalid api key")) //nolint:errcheck logger.Debug(logSender, "", "invalid api key %q: %v", apiKey, err) sendAPIResponse(w, r, errors.New("the provided api key is not valid"), "", http.StatusBadRequest) return } if k.Scope != scope { handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), dataprovider.ErrInvalidCredentials) //nolint:errcheck logger.Debug(logSender, "", "unable to authenticate api key %q: invalid scope: got %d, wanted: %d", apiKey, k.Scope, scope) sendAPIResponse(w, r, fmt.Errorf("the provided api key is invalid for this request"), "", http.StatusForbidden) return } if err := k.Authenticate(key); err != nil { handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), dataprovider.ErrInvalidCredentials) //nolint:errcheck logger.Debug(logSender, "", "unable to authenticate api key %q: %v", apiKey, err) sendAPIResponse(w, r, fmt.Errorf("the provided api key cannot be authenticated"), "", 
http.StatusUnauthorized) return } if scope == dataprovider.APIKeyScopeAdmin { if k.Admin != "" { apiUser = k.Admin } if err := authenticateAdminWithAPIKey(apiUser, keyID, tokenAuth, r); err != nil { handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), err) //nolint:errcheck logger.Debug(logSender, "", "unable to authenticate admin %q associated with api key %q: %v", apiUser, apiKey, err) sendAPIResponse(w, r, fmt.Errorf("the admin associated with the provided api key cannot be authenticated"), "", http.StatusUnauthorized) return } common.DelayLogin(nil) } else { if k.User != "" { apiUser = k.User } if err := authenticateUserWithAPIKey(apiUser, keyID, tokenAuth, r); err != nil { logger.Debug(logSender, "", "unable to authenticate user %q associated with api key %q: %v", apiUser, apiKey, err) updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: apiUser}}, dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), err, r) code := http.StatusUnauthorized if errors.Is(err, common.ErrInternalFailure) { code = http.StatusInternalServerError } sendAPIResponse(w, r, errors.New("the user associated with the provided api key cannot be authenticated"), "", code) return } updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: apiUser}}, dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), nil, r) } dataprovider.UpdateAPIKeyLastUse(&k) //nolint:errcheck next.ServeHTTP(w, r) }) } } func forbidAPIKeyAuthentication(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } if claims.APIKeyID != "" { sendAPIResponse(w, r, nil, "API key authentication is not allowed", http.StatusForbidden) return } next.ServeHTTP(w, r) }) } func authenticateAdminWithAPIKey(username, keyID 
string, tokenAuth *jwt.Signer, r *http.Request) error { if username == "" { return errors.New("the provided key is not associated with any admin and no username was provided") } admin, err := dataprovider.AdminExists(username) if err != nil { return err } if !admin.Filters.AllowAPIKeyAuth { return fmt.Errorf("API key authentication disabled for admin %q", admin.Username) } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := admin.CanLogin(ipAddr); err != nil { return err } c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, Role: admin.Role, APIKeyID: keyID, } c.Subject = admin.GetSignature() token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, ipAddr, getTokenDuration(tokenAudienceAPI)) if err != nil { return err } resp := c.BuildTokenResponse(token) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) dataprovider.UpdateAdminLastLogin(&admin) common.DelayLogin(nil) return nil } func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) protocol := common.ProtocolHTTP if username == "" { err := errors.New("the provided key is not associated with any user and no username was provided") updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err, r) return err } if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { return err } user, err := dataprovider.GetUserWithGroupSettings(username, "") if err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err, r) return err } if !user.Filters.AllowAPIKeyAuth { err := fmt.Errorf("API key authentication disabled for user %q", user.Username) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) return err } if err := user.CheckLoginConditions(); err != nil { 
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) return err } connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) return err } defer user.CloseFs() //nolint:errcheck err = user.CheckFsRoot(connectionID) if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) return common.ErrInternalFailure } c := &jwt.Claims{ Username: user.Username, Permissions: user.Filters.WebClient, Role: user.Role, APIKeyID: keyID, } c.Subject = user.GetSignature() token, err := tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser)) if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) return err } resp := c.BuildTokenResponse(token) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) dataprovider.UpdateLastLogin(&user) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil, r) return nil } func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, tokenAudience []string) error { if audience == tokenAudienceWebAdmin && slices.Contains(tokenAudience, tokenAudienceWebAdminPartial) { http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound) return errInvalidToken } if audience == tokenAudienceWebClient && slices.Contains(tokenAudience, tokenAudienceWebClientPartial) { http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound) return errInvalidToken } return nil } func cacheControlMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, private") next.ServeHTTP(w, r) }) } func cleanCacheControlMiddleware(next http.Handler) http.Handler { return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Del("Cache-Control") next.ServeHTTP(w, r) }) } ================================================ FILE: internal/httpd/oauth2.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "encoding/json" "errors" "sync" "time" "golang.org/x/oauth2" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( oauth2Mgr oauth2Manager ) func newOAuth2Manager(isShared int) oauth2Manager { if isShared == 1 { logger.Info(logSender, "", "using provider OAuth2 manager") return &dbOAuth2Manager{} } logger.Info(logSender, "", "using memory OAuth2 manager") return &memoryOAuth2Manager{ pendingAuths: make(map[string]oauth2PendingAuth), } } type oauth2PendingAuth struct { State string `json:"state"` Provider int `json:"provider"` ClientID string `json:"client_id"` ClientSecret *kms.Secret `json:"client_secret"` RedirectURL string `json:"redirect_url"` IssuedAt int64 `json:"issued_at"` Verifier string `json:"verifier"` } func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth { return oauth2PendingAuth{ State: util.GenerateOpaqueString(), Provider: provider, ClientID: clientID, ClientSecret: clientSecret, 
RedirectURL: redirectURL, IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), Verifier: oauth2.GenerateVerifier(), } } type oauth2Manager interface { addPendingAuth(pendingAuth oauth2PendingAuth) removePendingAuth(state string) getPendingAuth(state string) (oauth2PendingAuth, error) cleanup() } type memoryOAuth2Manager struct { mu sync.RWMutex pendingAuths map[string]oauth2PendingAuth } func (o *memoryOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) { o.mu.Lock() defer o.mu.Unlock() o.pendingAuths[pendingAuth.State] = pendingAuth } func (o *memoryOAuth2Manager) removePendingAuth(state string) { o.mu.Lock() defer o.mu.Unlock() delete(o.pendingAuths, state) } func (o *memoryOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) { o.mu.RLock() defer o.mu.RUnlock() authReq, ok := o.pendingAuths[state] if !ok { return oauth2PendingAuth{}, errors.New("oauth2: no auth request found for the specified state") } diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt if diff > authStateValidity { return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old") } return authReq, nil } func (o *memoryOAuth2Manager) cleanup() { o.mu.Lock() defer o.mu.Unlock() for k, auth := range o.pendingAuths { diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt // remove old pending auth requests if diff < 0 || diff > authStateValidity { delete(o.pendingAuths, k) } } } type dbOAuth2Manager struct{} func (o *dbOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) { if err := pendingAuth.ClientSecret.Encrypt(); err != nil { logger.Error(logSender, "", "unable to encrypt oauth2 secret: %v", err) return } session := dataprovider.Session{ Key: pendingAuth.State, Data: pendingAuth, Type: dataprovider.SessionTypeOAuth2Auth, Timestamp: pendingAuth.IssuedAt + authStateValidity, } dataprovider.AddSharedSession(session) //nolint:errcheck } func (o *dbOAuth2Manager) removePendingAuth(state string) { dataprovider.DeleteSharedSession(state, 
dataprovider.SessionTypeOAuth2Auth) //nolint:errcheck } func (o *dbOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) { session, err := dataprovider.GetSharedSession(state, dataprovider.SessionTypeOAuth2Auth) if err != nil { return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state") } if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { // expired return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old") } return o.decodePendingAuthData(session.Data) } func (o *dbOAuth2Manager) decodePendingAuthData(data any) (oauth2PendingAuth, error) { if val, ok := data.([]byte); ok { authReq := oauth2PendingAuth{} err := json.Unmarshal(val, &authReq) if err != nil { return authReq, err } err = authReq.ClientSecret.TryDecrypt() return authReq, err } logger.Error(logSender, "", "invalid oauth2 auth request data type %T", data) return oauth2PendingAuth{}, errors.New("oauth2: invalid auth request data") } func (o *dbOAuth2Manager) cleanup() { dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()) //nolint:errcheck } ================================================ FILE: internal/httpd/oauth2_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd

import (
	"encoding/json"
	"testing"
	"time"

	"github.com/rs/xid"
	sdkkms "github.com/sftpgo/sdk/kms"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// TestMemoryOAuth2Manager exercises the in-memory OAuth2 pending auth manager:
// add/get/remove of pending requests, rejection of expired requests and
// removal of stale entries on cleanup.
func TestMemoryOAuth2Manager(t *testing.T) {
	mgr := newOAuth2Manager(0)
	m, ok := mgr.(*memoryOAuth2Manager)
	require.True(t, ok)
	require.Len(t, m.pendingAuths, 0)

	_, err := m.getPendingAuth(xid.New().String())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no auth request found")

	auth := newOAuth2PendingAuth(1, "https://...", "cid", kms.NewPlainSecret("mysecret"))
	m.addPendingAuth(auth)
	require.Len(t, m.pendingAuths, 1)
	a, err := m.getPendingAuth(auth.State)
	// require, not assert: "a" is dereferenced below and is zero-valued on error
	require.NoError(t, err)
	assert.Equal(t, auth.State, a.State)
	require.NotNil(t, a.ClientSecret)
	// the memory manager stores the pending auth as is, the secret is not encrypted
	assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus())

	m.removePendingAuth(auth.State)
	_, err = m.getPendingAuth(auth.State)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no auth request found")
	require.Len(t, m.pendingAuths, 0)

	// one fresh and one expired request
	state := xid.New().String()
	auth = oauth2PendingAuth{
		State:    state,
		Provider: 1,
		IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
	}
	m.addPendingAuth(auth)
	auth = oauth2PendingAuth{
		State:    xid.New().String(),
		Provider: 1,
		IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
	}
	m.addPendingAuth(auth)
	require.Len(t, m.pendingAuths, 2)
	_, err = m.getPendingAuth(auth.State)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "auth request is too old")

	// cleanup must drop only the expired request
	m.cleanup()
	require.Len(t, m.pendingAuths, 1)
	_, err = m.getPendingAuth(state)
	assert.NoError(t, err)
	_, ok = m.pendingAuths[auth.State]
	assert.False(t, ok)

	m.removePendingAuth(state)
	require.Len(t, m.pendingAuths, 0)
}

// TestDbOAuth2Manager exercises the data provider backed OAuth2 pending auth
// manager: secrets are stored encrypted and returned decrypted, expired
// requests are rejected and removed on cleanup, and malformed stored data is
// reported as an error.
func TestDbOAuth2Manager(t *testing.T) {
	if !isSharedProviderSupported() {
		t.Skip("this test it is not available with this provider")
	}
	mgr := newOAuth2Manager(1)
	m, ok := mgr.(*dbOAuth2Manager)
	require.True(t, ok)

	_, err := m.getPendingAuth(xid.New().String())
	require.Error(t, err)

	auth := newOAuth2PendingAuth(1, "https://...", "client_id", kms.NewPlainSecret("my db secret"))
	m.addPendingAuth(auth)
	a, err := m.getPendingAuth(auth.State)
	// require, not assert: "a" is dereferenced below and is zero-valued on error
	require.NoError(t, err)
	require.NotNil(t, a.ClientSecret)
	// getPendingAuth decrypts the secret before returning the request
	assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus())

	// the persisted session must contain the encrypted secret
	session, err := dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth)
	require.NoError(t, err)
	// checked assertion: an unexpected data type must fail the test, not panic
	data, ok := session.Data.([]byte)
	require.True(t, ok, "unexpected session data type %T", session.Data)
	authReq := oauth2PendingAuth{}
	err = json.Unmarshal(data, &authReq)
	require.NoError(t, err)
	require.NotNil(t, authReq.ClientSecret)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, authReq.ClientSecret.GetStatus())

	// cleanup must not remove a request that is still valid
	m.cleanup()
	_, err = m.getPendingAuth(auth.State)
	assert.NoError(t, err)
	m.removePendingAuth(auth.State)
	_, err = m.getPendingAuth(auth.State)
	assert.Error(t, err)

	// an expired request is stored, rejected by getPendingAuth and removed by cleanup
	auth = oauth2PendingAuth{
		State:        xid.New().String(),
		Provider:     1,
		IssuedAt:     util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
		ClientSecret: kms.NewPlainSecret("db secret"),
	}
	m.addPendingAuth(auth)
	_, err = m.getPendingAuth(auth.State)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "auth request is too old")
	_, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth)
	assert.NoError(t, err)
	m.cleanup()
	_, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth)
	assert.Error(t, err)

	// malformed stored data
	_, err = m.decodePendingAuthData("not a byte array")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid auth request data")
	_, err = m.decodePendingAuthData([]byte("{not a json"))
	require.Error(t, err)

	// adding a request with a non plain secret will fail: addPendingAuth
	// cannot encrypt it, so nothing is stored
	auth = oauth2PendingAuth{
		State:        xid.New().String(),
		Provider:     1,
		IssuedAt:     util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
		ClientSecret: kms.NewPlainSecret("db secret"),
	}
	auth.ClientSecret.SetStatus(sdkkms.SecretStatusSecretBox)
	m.addPendingAuth(auth)
	_, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth)
	assert.Error(t, err)
	// decoding is expected to fail as well: the secret claims to be encrypted
	// but it was never actually encrypted (presumably TryDecrypt rejects it)
	asJSON, err := json.Marshal(auth)
	require.NoError(t, err)
	_, err = m.decodePendingAuthData(asJSON)
	assert.Error(t, err)
}
================================================ FILE: internal/httpd/oidc.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "context" "errors" "fmt" "net/http" "net/url" "slices" "strings" "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/rs/xid" "golang.org/x/oauth2" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( oidcCookieKey = "oidc" adminRoleFieldValue = "admin" authStateValidity = 2 * 60 * 1000 // 2 minutes tokenUpdateInterval = 3 * 60 * 1000 // 3 minutes tokenDeleteInterval = 2 * 3600 * 1000 // 2 hours ) var ( oidcTokenKey = &contextKey{"OIDC token key"} oidcGeneratedToken = &contextKey{"OIDC generated token"} ) // OAuth2Config defines an interface for OAuth2 methods, so we can mock them type OAuth2Config interface { AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource } // OIDCTokenVerifier defines an interface for OpenID token verifier, so we can mock them type OIDCTokenVerifier interface { Verify(ctx 
context.Context, rawIDToken string) (*oidc.IDToken, error) } // OIDC defines the OpenID Connect configuration type OIDC struct { // ClientID is the application's ID ClientID string `json:"client_id" mapstructure:"client_id"` // ClientSecret is the application's secret ClientSecret string `json:"client_secret" mapstructure:"client_secret"` ClientSecretFile string `json:"client_secret_file" mapstructure:"client_secret_file"` // ConfigURL is the identifier for the service. // SFTPGo will try to retrieve the provider configuration on startup and then // will refuse to start if it fails to connect to the specified URL ConfigURL string `json:"config_url" mapstructure:"config_url"` // RedirectBaseURL is the base URL to redirect to after OpenID authentication. // The suffix "/web/oidc/redirect" will be added to this base URL, adding also the // "web_root" if configured RedirectBaseURL string `json:"redirect_base_url" mapstructure:"redirect_base_url"` // ID token claims field to map to the SFTPGo username UsernameField string `json:"username_field" mapstructure:"username_field"` // Optional ID token claims field to map to a SFTPGo role. // If the defined ID token claims field is set to "admin" the authenticated user // is mapped to an SFTPGo admin. // You don't need to specify this field if you want to use OpenID only for the // Web Client UI RoleField string `json:"role_field" mapstructure:"role_field"` // If set, the `RoleField` is ignored and the SFTPGo role is assumed based on // the login link used ImplicitRoles bool `json:"implicit_roles" mapstructure:"implicit_roles"` // Scopes required by the OAuth provider to retrieve information about the authenticated user. // The "openid" scope is required. 
// Refer to your OAuth provider documentation for more information about this Scopes []string `json:"scopes" mapstructure:"scopes"` // Custom token claims fields to pass to the pre-login hook CustomFields []string `json:"custom_fields" mapstructure:"custom_fields"` // InsecureSkipSignatureCheck causes SFTPGo to skip JWT signature validation. // It's intended for special cases where providers, such as Azure, use the "none" // algorithm. Skipping the signature validation can cause security issues InsecureSkipSignatureCheck bool `json:"insecure_skip_signature_check" mapstructure:"insecure_skip_signature_check"` // Debug enables the OIDC debug mode. In debug mode, the received id_token will be logged // at the debug level Debug bool `json:"debug" mapstructure:"debug"` provider *oidc.Provider verifier OIDCTokenVerifier providerLogoutURL string oauth2Config OAuth2Config } func (o *OIDC) isEnabled() bool { return o.provider != nil } func (o *OIDC) hasRoles() bool { return o.isEnabled() && (o.RoleField != "" || o.ImplicitRoles) } func (o *OIDC) getForcedRole(audience string) string { if !o.ImplicitRoles { return "" } if audience == tokenAudienceWebAdmin { return adminRoleFieldValue } return "" } func (o *OIDC) getRedirectURL() string { url := o.RedirectBaseURL if strings.HasSuffix(o.RedirectBaseURL, "/") { url = strings.TrimSuffix(o.RedirectBaseURL, "/") } url += webOIDCRedirectPath logger.Debug(logSender, "", "oidc redirect URL: %q", url) return url } func (o *OIDC) initialize() error { if o.ConfigURL == "" { return nil } if o.UsernameField == "" { return errors.New("oidc: username field cannot be empty") } if o.RedirectBaseURL == "" { return errors.New("oidc: redirect base URL cannot be empty") } if !slices.Contains(o.Scopes, oidc.ScopeOpenID) { return fmt.Errorf("oidc: required scope %q is not set", oidc.ScopeOpenID) } if o.ClientSecretFile != "" { secret, err := util.ReadConfigFromFile(o.ClientSecretFile, configurationDir) if err != nil { return err } o.ClientSecret = 
secret } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() provider, err := oidc.NewProvider(ctx, o.ConfigURL) if err != nil { return fmt.Errorf("oidc: unable to initialize provider for URL %q: %w", o.ConfigURL, err) } claims := make(map[string]any) // we cannot get an error here because the response body was already parsed as JSON // on provider creation provider.Claims(&claims) //nolint:errcheck endSessionEndPoint, ok := claims["end_session_endpoint"] if ok { if val, ok := endSessionEndPoint.(string); ok { o.providerLogoutURL = val logger.Debug(logSender, "", "oidc end session endpoint %q", o.providerLogoutURL) } } o.provider = provider o.verifier = nil o.oauth2Config = &oauth2.Config{ ClientID: o.ClientID, ClientSecret: o.ClientSecret, Endpoint: o.provider.Endpoint(), RedirectURL: o.getRedirectURL(), Scopes: o.Scopes, } return nil } func (o *OIDC) getVerifier(ctx context.Context) OIDCTokenVerifier { if o.verifier != nil { return o.verifier } return o.provider.VerifierContext(ctx, &oidc.Config{ ClientID: o.ClientID, InsecureSkipSignatureCheck: o.InsecureSkipSignatureCheck, }) } type oidcPendingAuth struct { State string `json:"state"` Nonce string `json:"nonce"` Audience tokenAudience `json:"audience"` IssuedAt int64 `json:"issued_at"` Verifier string `json:"verifier"` } func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth { return oidcPendingAuth{ State: util.GenerateOpaqueString(), Nonce: util.GenerateOpaqueString(), Audience: audience, IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), Verifier: oauth2.GenerateVerifier(), } } type oidcToken struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type,omitempty"` RefreshToken string `json:"refresh_token,omitempty"` ExpiresAt int64 `json:"expires_at,omitempty"` SessionID string `json:"session_id"` IDToken string `json:"id_token"` Nonce string `json:"nonce"` Username string `json:"username"` Permissions []string `json:"permissions"` 
HideUserPageSections int `json:"hide_user_page_sections,omitempty"` MustSetTwoFactorAuth bool `json:"must_set_2fa,omitempty"` MustChangePassword bool `json:"must_change_password,omitempty"` RequiredTwoFactorProtocols []string `json:"required_two_factor_protocols,omitempty"` TokenRole string `json:"token_role,omitempty"` // SFTPGo role name Role any `json:"role"` // oidc user role: SFTPGo user or admin CustomFields *map[string]any `json:"custom_fields,omitempty"` Cookie string `json:"cookie"` UsedAt int64 `json:"used_at"` } func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField string, customFields []string, forcedRole string, ) error { getClaimsFields := func() []string { keys := make([]string, 0, len(claims)) for k := range claims { keys = append(keys, k) } return keys } var username string val, ok := getOIDCFieldFromClaims(claims, usernameField) if ok { username, ok = val.(string) } if !ok || username == "" { logger.Warn(logSender, "", "username field %q not found, empty or not a string, claims fields: %+v", usernameField, getClaimsFields()) return errors.New("no username field") } t.Username = username if forcedRole != "" { t.Role = forcedRole } else { t.getRoleFromField(claims, roleField) } t.CustomFields = nil if len(customFields) > 0 { for _, field := range customFields { if val, ok := getOIDCFieldFromClaims(claims, field); ok { if t.CustomFields == nil { customFields := make(map[string]any) t.CustomFields = &customFields } logger.Debug(logSender, "", "custom field %q found in token claims", field) (*t.CustomFields)[field] = val } else { logger.Info(logSender, "", "custom field %q not found in token claims", field) } } } sid, ok := claims["sid"].(string) if ok { t.SessionID = sid } return nil } func (t *oidcToken) getRoleFromField(claims map[string]any, roleField string) { role, ok := getOIDCFieldFromClaims(claims, roleField) if ok { t.Role = role } } func (t *oidcToken) isAdmin() bool { switch v := t.Role.(type) { case string: return 
v == adminRoleFieldValue case []any: for _, s := range v { if val, ok := s.(string); ok && val == adminRoleFieldValue { return true } } return false default: return false } } func (t *oidcToken) isExpired() bool { if t.ExpiresAt == 0 { return false } return t.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) } func (t *oidcToken) refresh(ctx context.Context, config OAuth2Config, verifier OIDCTokenVerifier, r *http.Request) error { if t.RefreshToken == "" { logger.Debug(logSender, "", "refresh token not set, unable to refresh cookie %q", t.Cookie) return errors.New("refresh token not set") } oauth2Token := oauth2.Token{ AccessToken: t.AccessToken, TokenType: t.TokenType, RefreshToken: t.RefreshToken, } if t.ExpiresAt > 0 { oauth2Token.Expiry = util.GetTimeFromMsecSinceEpoch(t.ExpiresAt) } newToken, err := config.TokenSource(ctx, &oauth2Token).Token() if err != nil { logger.Debug(logSender, "", "unable to refresh token for cookie %q: %v", t.Cookie, err) return err } rawIDToken, ok := newToken.Extra("id_token").(string) if !ok { logger.Debug(logSender, "", "the refreshed token has no id token, cookie %q", t.Cookie) return errors.New("the refreshed token has no id token") } t.AccessToken = newToken.AccessToken t.TokenType = newToken.TokenType t.RefreshToken = newToken.RefreshToken t.IDToken = rawIDToken if !newToken.Expiry.IsZero() { t.ExpiresAt = util.GetTimeAsMsSinceEpoch(newToken.Expiry) } else { t.ExpiresAt = 0 } idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { logger.Debug(logSender, "", "unable to verify refreshed id token for cookie %q: %v", t.Cookie, err) return err } if idToken.Nonce != "" && idToken.Nonce != t.Nonce { logger.Warn(logSender, "", "unable to verify refreshed id token for cookie %q: nonce mismatch, expected: %q, actual: %q", t.Cookie, t.Nonce, idToken.Nonce) return errors.New("the refreshed token nonce mismatch") } claims := make(map[string]any) err = idToken.Claims(&claims) if err != nil { logger.Warn(logSender, "", "unable to 
get refreshed id token claims for cookie %q: %v", t.Cookie, err) return err } sid, ok := claims["sid"].(string) if ok { t.SessionID = sid } err = t.refreshUser(r) if err != nil { logger.Debug(logSender, "", "unable to refresh user after token refresh for cookie %q: %v", t.Cookie, err) return err } logger.Debug(logSender, "", "oidc token refreshed for user %q, cookie %q", t.Username, t.Cookie) oidcMgr.addToken(*t) return nil } func (t *oidcToken) refreshUser(r *http.Request) error { if t.isAdmin() { admin, err := dataprovider.AdminExists(t.Username) if err != nil { return err } if err := admin.CanLogin(util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { return err } t.Permissions = admin.Permissions t.TokenRole = admin.Role t.HideUserPageSections = admin.Filters.Preferences.HideUserPageSections return nil } user, err := dataprovider.GetUserWithGroupSettings(t.Username, "") if err != nil { return err } if err := user.CheckLoginConditions(); err != nil { return err } if err := checkHTTPClientUser(&user, r, xid.New().String(), true, false); err != nil { return err } t.Permissions = user.Filters.WebClient t.TokenRole = user.Role t.MustSetTwoFactorAuth = user.MustSetSecondFactor() t.MustChangePassword = user.MustChangePassword() t.RequiredTwoFactorProtocols = user.Filters.TwoFactorAuthProtocols return nil } func (t *oidcToken) getUser(r *http.Request) error { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) params := common.EventParams{ Name: t.Username, IP: ipAddr, Protocol: common.ProtocolOIDC, Timestamp: time.Now(), Status: 1, } if t.isAdmin() { params.Event = common.IDPLoginAdmin _, admin, err := common.HandleIDPLoginEvent(params, t.CustomFields) if err != nil { return err } if admin == nil { a, err := dataprovider.AdminExists(t.Username) if err != nil { return err } admin = &a } if err := admin.CanLogin(ipAddr); err != nil { return err } t.Permissions = admin.Permissions t.TokenRole = admin.Role t.HideUserPageSections = 
admin.Filters.Preferences.HideUserPageSections dataprovider.UpdateAdminLastLogin(admin) common.DelayLogin(nil) return nil } params.Event = common.IDPLoginUser user, _, err := common.HandleIDPLoginEvent(params, t.CustomFields) if err != nil { return err } if user == nil { u, err := dataprovider.GetUserAfterIDPAuth(t.Username, ipAddr, common.ProtocolOIDC, t.CustomFields) if err != nil { return err } user = &u } if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolOIDC); err != nil { updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r) return fmt.Errorf("access denied: %w", err) } if err := user.CheckLoginConditions(); err != nil { updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r) return err } connectionID := fmt.Sprintf("%s_%s", common.ProtocolOIDC, xid.New().String()) if err := checkHTTPClientUser(user, r, connectionID, true, true); err != nil { updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r) return err } defer user.CloseFs() //nolint:errcheck err = user.CheckFsRoot(connectionID) if err != nil { logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, common.ErrInternalFailure, r) return err } updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, nil, r) dataprovider.UpdateLastLogin(user) t.Permissions = user.Filters.WebClient t.TokenRole = user.Role t.MustSetTwoFactorAuth = user.MustSetSecondFactor() t.MustChangePassword = user.MustChangePassword() t.RequiredTwoFactorProtocols = user.Filters.TwoFactorAuthProtocols return nil } func (s *httpdServer) validateOIDCToken(w http.ResponseWriter, r *http.Request, isAdmin bool) (oidcToken, error) { doRedirect := func() { removeOIDCCookie(w, r) if isAdmin { http.Redirect(w, r, webAdminLoginPath, http.StatusFound) return } http.Redirect(w, r, webClientLoginPath, http.StatusFound) } cookie, err := r.Cookie(oidcCookieKey) if err != nil { logger.Debug(logSender, "", "no 
oidc cookie, redirecting to login page") doRedirect() return oidcToken{}, errInvalidToken } token, err := oidcMgr.getToken(cookie.Value) if err != nil { logger.Debug(logSender, "", "error getting oidc token associated with cookie %q: %v", cookie.Value, err) doRedirect() return oidcToken{}, errInvalidToken } if token.isExpired() { logger.Debug(logSender, "", "oidc token associated with cookie %q is expired", token.Cookie) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() if err = token.refresh(ctx, s.binding.OIDC.oauth2Config, s.binding.OIDC.getVerifier(ctx), r); err != nil { setFlashMessage(w, r, newFlashMessage("Your OpenID token is expired, please log-in again", util.I18nOIDCTokenExpired)) doRedirect() return oidcToken{}, errInvalidToken } } else { oidcMgr.updateTokenUsage(token) } if isAdmin { if !token.isAdmin() { logger.Debug(logSender, "", "oidc token associated with cookie %q is not valid for admin users", token.Cookie) setFlashMessage(w, r, newFlashMessage( "Your OpenID token is not valid for the SFTPGo Web Admin UI. Please logout from your OpenID server and log-in as an SFTPGo admin", util.I18nOIDCTokenInvalidAdmin, )) doRedirect() return oidcToken{}, errInvalidToken } return token, nil } if token.isAdmin() { logger.Debug(logSender, "", "oidc token associated with cookie %q is valid for admin users", token.Cookie) setFlashMessage(w, r, newFlashMessage( "Your OpenID token is not valid for the SFTPGo Web Client UI. 
Please logout from your OpenID server and log-in as an SFTPGo user", util.I18nOIDCTokenInvalidUser, )) doRedirect() return oidcToken{}, errInvalidToken } return token, nil } func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if canSkipOIDCValidation(r) { next.ServeHTTP(w, r) return } token, err := s.validateOIDCToken(w, r, audience == tokenAudienceWebAdmin) if err != nil { return } claims := jwt.Claims{ Username: dataprovider.ConvertName(token.Username), Permissions: token.Permissions, Role: token.TokenRole, HideUserPageSections: token.HideUserPageSections, } claims.ID = token.Cookie if audience == tokenAudienceWebClient { claims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth claims.MustChangePassword = token.MustChangePassword claims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols } tokenString, err := s.tokenAuth.SignWithParams(&claims, audience, util.GetIPFromRemoteAddress(r.RemoteAddr), getTokenDuration(audience)) if err != nil { setFlashMessage(w, r, newFlashMessage("Unable to create cookie", util.I18nError500Message)) if audience == tokenAudienceWebAdmin { http.Redirect(w, r, webAdminLoginPath, http.StatusFound) } else { http.Redirect(w, r, webClientLoginPath, http.StatusFound) } return } ctx := context.WithValue(r.Context(), oidcTokenKey, token.Cookie) ctx = context.WithValue(ctx, oidcGeneratedToken, tokenString) next.ServeHTTP(w, r.WithContext(ctx)) }) } } func (s *httpdServer) handleWebAdminOIDCLogin(w http.ResponseWriter, r *http.Request) { s.oidcLoginRedirect(w, r, tokenAudienceWebAdmin) } func (s *httpdServer) handleWebClientOIDCLogin(w http.ResponseWriter, r *http.Request) { s.oidcLoginRedirect(w, r, tokenAudienceWebClient) } func (s *httpdServer) oidcLoginRedirect(w http.ResponseWriter, r *http.Request, audience tokenAudience) { pendingAuth := 
newOIDCPendingAuth(audience) oidcMgr.addPendingAuth(pendingAuth) http.Redirect(w, r, s.binding.OIDC.oauth2Config.AuthCodeURL(pendingAuth.State, oidc.Nonce(pendingAuth.Nonce), oauth2.S256ChallengeOption(pendingAuth.Verifier)), http.StatusFound) } func (s *httpdServer) debugTokenClaims(claims map[string]any, rawIDToken string) { if s.binding.OIDC.Debug { if claims == nil { logger.Debug(logSender, "", "raw id token %q", rawIDToken) } else { logger.Debug(logSender, "", "raw id token %q, parsed claims %+v", rawIDToken, claims) } } } func (s *httpdServer) handleOIDCRedirect(w http.ResponseWriter, r *http.Request) { state := r.URL.Query().Get("state") authReq, err := oidcMgr.getPendingAuth(state) if err != nil { logger.Debug(logSender, "", "oidc authentication state did not match") oidcMgr.removePendingAuth(state) s.renderClientMessagePage(w, r, util.I18nInvalidAuthReqTitle, http.StatusBadRequest, util.NewI18nError(err, util.I18nInvalidAuth), "") return } oidcMgr.removePendingAuth(state) doRedirect := func() { if authReq.Audience == tokenAudienceWebAdmin { http.Redirect(w, r, webAdminLoginPath, http.StatusFound) return } http.Redirect(w, r, webClientLoginPath, http.StatusFound) } doLogout := func(rawIDToken string) { s.logoutFromOIDCOP(rawIDToken) } ctx, cancel := context.WithTimeout(r.Context(), 20*time.Second) defer cancel() oauth2Token, err := s.binding.OIDC.oauth2Config.Exchange(ctx, r.URL.Query().Get("code"), oauth2.VerifierOption(authReq.Verifier)) if err != nil { logger.Debug(logSender, "", "failed to exchange oidc token: %v", err) setFlashMessage(w, r, newFlashMessage("Failed to exchange OpenID token", util.I18nOIDCErrTokenExchange)) doRedirect() return } rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { logger.Debug(logSender, "", "no id_token field in OAuth2 OpenID token") setFlashMessage(w, r, newFlashMessage("No id_token field in OAuth2 OpenID token", util.I18nOIDCTokenInvalid)) doRedirect() return } s.debugTokenClaims(nil, rawIDToken) idToken, 
err := s.binding.OIDC.getVerifier(ctx).Verify(ctx, rawIDToken) if err != nil { logger.Debug(logSender, "", "failed to verify oidc token: %v", err) setFlashMessage(w, r, newFlashMessage("Failed to verify OpenID token", util.I18nOIDCTokenInvalid)) doRedirect() doLogout(rawIDToken) return } if idToken.Nonce != authReq.Nonce { logger.Debug(logSender, "", "oidc authentication nonce did not match") setFlashMessage(w, r, newFlashMessage("OpenID authentication nonce did not match", util.I18nOIDCTokenInvalid)) doRedirect() doLogout(rawIDToken) return } claims := make(map[string]any) err = idToken.Claims(&claims) if err != nil { logger.Debug(logSender, "", "unable to get oidc token claims: %v", err) setFlashMessage(w, r, newFlashMessage("Unable to get OpenID token claims", util.I18nOIDCTokenInvalid)) doRedirect() doLogout(rawIDToken) return } s.debugTokenClaims(claims, rawIDToken) token := oidcToken{ AccessToken: oauth2Token.AccessToken, TokenType: oauth2Token.TokenType, RefreshToken: oauth2Token.RefreshToken, IDToken: rawIDToken, Nonce: idToken.Nonce, Cookie: util.GenerateOpaqueString(), } if !oauth2Token.Expiry.IsZero() { token.ExpiresAt = util.GetTimeAsMsSinceEpoch(oauth2Token.Expiry) } err = token.parseClaims(claims, s.binding.OIDC.UsernameField, s.binding.OIDC.RoleField, s.binding.OIDC.CustomFields, s.binding.OIDC.getForcedRole(authReq.Audience)) if err != nil { logger.Debug(logSender, "", "unable to parse oidc token claims: %v", err) setFlashMessage(w, r, newFlashMessage(fmt.Sprintf("Unable to parse OpenID token claims: %v", err), util.I18nOIDCTokenInvalid)) doRedirect() doLogout(rawIDToken) return } switch authReq.Audience { case tokenAudienceWebAdmin: if !token.isAdmin() { logger.Debug(logSender, "", "wrong oidc token role, the mapped user is not an SFTPGo admin") setFlashMessage(w, r, newFlashMessage( "Wrong OpenID role, the logged in user is not an SFTPGo admin", util.I18nOIDCTokenInvalidRoleAdmin)) doRedirect() doLogout(rawIDToken) return } case 
tokenAudienceWebClient: if token.isAdmin() { logger.Debug(logSender, "", "wrong oidc token role, the mapped user is an SFTPGo admin") setFlashMessage(w, r, newFlashMessage( "Wrong OpenID role, the logged in user is an SFTPGo admin", util.I18nOIDCTokenInvalidRoleUser, )) doRedirect() doLogout(rawIDToken) return } } err = token.getUser(r) if err != nil { logger.Debug(logSender, "", "unable to get the sftpgo user associated with oidc token: %v", err) setFlashMessage(w, r, newFlashMessage("Unable to get the user associated with the OpenID token", util.I18nOIDCErrGetUser)) doRedirect() doLogout(rawIDToken) return } loginOIDCUser(w, r, token) } func loginOIDCUser(w http.ResponseWriter, r *http.Request, token oidcToken) { oidcMgr.addToken(token) cookie := http.Cookie{ Name: oidcCookieKey, Value: token.Cookie, Path: "/", HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteLaxMode, } // we don't set a cookie expiration so we can refresh the token without setting a new cookie // the cookie will be invalidated on browser close http.SetCookie(w, &cookie) w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) if token.isAdmin() { http.Redirect(w, r, webUsersPath, http.StatusFound) return } http.Redirect(w, r, webClientFilesPath, http.StatusFound) } func (s *httpdServer) logoutOIDCUser(w http.ResponseWriter, r *http.Request) { if oidcKey, ok := r.Context().Value(oidcTokenKey).(string); ok { removeOIDCCookie(w, r) token, err := oidcMgr.getToken(oidcKey) if err == nil { s.logoutFromOIDCOP(token.IDToken) } oidcMgr.removeToken(oidcKey) } } func (s *httpdServer) logoutFromOIDCOP(idToken string) { if s.binding.OIDC.providerLogoutURL == "" { logger.Debug(logSender, "", "oidc: provider logout URL not set, unable to logout from the OP") return } go s.doOIDCFromLogout(idToken) } func (s *httpdServer) doOIDCFromLogout(idToken string) { logoutURL, err := url.Parse(s.binding.OIDC.providerLogoutURL) if err != nil { logger.Warn(logSender, "", "oidc: unable to parse logout URL: %v", err) 
return } query := logoutURL.Query() if idToken != "" { query.Set("id_token_hint", idToken) } logoutURL.RawQuery = query.Encode() resp, err := httpclient.RetryableGet(logoutURL.String()) if err != nil { logger.Warn(logSender, "", "oidc: error calling logout URL %q: %v", logoutURL.String(), err) return } defer resp.Body.Close() logger.Debug(logSender, "", "oidc: logout url response code %v", resp.StatusCode) } func removeOIDCCookie(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: oidcCookieKey, Value: "", Path: "/", Expires: time.Unix(0, 0), MaxAge: -1, HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteLaxMode, }) } // canSkipOIDCValidation returns true if there is no OIDC cookie but a jwt cookie is set // and so we check if the user is logged in using a built-in user func canSkipOIDCValidation(r *http.Request) bool { _, err := r.Cookie(oidcCookieKey) if err != nil { _, err = r.Cookie(jwt.CookieKey) return err == nil } return false } func isLoggedInWithOIDC(r *http.Request) bool { _, ok := r.Context().Value(oidcTokenKey).(string) return ok } func getOIDCFieldFromClaims(claims map[string]any, fieldName string) (any, bool) { if fieldName == "" { return nil, false } val, ok := claims[fieldName] if ok { return val, true } if !strings.Contains(fieldName, ".") { return nil, false } getStructValue := func(outer any, field string) (any, bool) { switch v := outer.(type) { case map[string]any: res, ok := v[field] return res, ok } return nil, false } for idx, field := range strings.Split(fieldName, ".") { if idx == 0 { val, ok = getStructValue(claims, field) } else { val, ok = getStructValue(val, field) } if !ok { return nil, false } } return val, ok } ================================================ FILE: internal/httpd/oidc_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General 
Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "bytes" "context" "encoding/json" "fmt" "io/fs" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "reflect" "runtime" "testing" "time" "unsafe" "github.com/coreos/go-oidc/v3/oidc" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( oidcMockAddr = "127.0.0.1:11111" ) type mockTokenSource struct { token *oauth2.Token err error } func (t *mockTokenSource) Token() (*oauth2.Token, error) { return t.token, t.err } type mockOAuth2Config struct { tokenSource *mockTokenSource authCodeURL string token *oauth2.Token err error } func (c *mockOAuth2Config) AuthCodeURL(_ string, _ ...oauth2.AuthCodeOption) string { return c.authCodeURL } func (c *mockOAuth2Config) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { return c.token, c.err } func (c *mockOAuth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource { return c.tokenSource } type mockOIDCVerifier struct { token *oidc.IDToken err error } func (v *mockOIDCVerifier) Verify(_ context.Context, _ string) (*oidc.IDToken, error) { return v.token, v.err } // hack because the field is unexported func setIDTokenClaims(idToken 
*oidc.IDToken, claims []byte) { pointerVal := reflect.ValueOf(idToken) val := reflect.Indirect(pointerVal) member := val.FieldByName("claims") ptr := unsafe.Pointer(member.UnsafeAddr()) realPtr := (*[]byte)(ptr) *realPtr = claims } func TestOIDCInitialization(t *testing.T) { config := OIDC{} err := config.initialize() assert.NoError(t, err) secret := "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c" config = OIDC{ ClientID: "sftpgo-client", ClientSecret: util.GenerateUniqueID(), ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr), RedirectBaseURL: "http://127.0.0.1:8081/", UsernameField: "preferred_username", RoleField: "sftpgo_role", } err = config.initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "oidc: required scope \"openid\" is not set") } config.Scopes = []string{oidc.ScopeOpenID} config.ClientSecretFile = "missing file" err = config.initialize() assert.ErrorIs(t, err, fs.ErrNotExist) secretFile := filepath.Join(os.TempDir(), util.GenerateUniqueID()) defer os.Remove(secretFile) err = os.WriteFile(secretFile, []byte(secret), 0600) assert.NoError(t, err) config.ClientSecretFile = secretFile err = config.initialize() if assert.Error(t, err) { assert.Contains(t, err.Error(), "oidc: unable to initialize provider") } assert.Equal(t, secret, config.ClientSecret) config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr) err = config.initialize() assert.NoError(t, err) assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL()) } func TestOIDCLoginLogout(t *testing.T) { tokenValidationMode = 2 oidcMgr, ok := oidcMgr.(*memoryOIDCManager) require.True(t, ok) server := getTestOIDCServer() err := server.binding.OIDC.initialize() assert.NoError(t, err) err = server.initializeRouter() require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, 
rr.Body.String(), util.I18nInvalidAuth) expiredAuthReq := oidcPendingAuth{ State: util.GenerateOpaqueString(), Nonce: util.GenerateOpaqueString(), Audience: tokenAudienceWebClient, IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), } oidcMgr.addPendingAuth(expiredAuthReq) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth) oidcMgr.removePendingAuth(expiredAuthReq.State) server.binding.OIDC.oauth2Config = &mockOAuth2Config{ tokenSource: &mockTokenSource{}, authCodeURL: webOIDCRedirectPath, err: common.ErrGenericFailure, } server.binding.OIDC.verifier = &mockOIDCVerifier{ err: common.ErrGenericFailure, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 1) var state string for k := range oidcMgr.pendingAuths { state = k } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) // now the same for the web client rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webOIDCRedirectPath, 
rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 1) for k := range oidcMgr.pendingAuths { state = k } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) // now return an OAuth2 token without the id_token server.binding.OIDC.oauth2Config = &mockOAuth2Config{ tokenSource: &mockTokenSource{}, authCodeURL: webOIDCRedirectPath, token: &oauth2.Token{ AccessToken: "123", Expiry: time.Now().Add(5 * time.Minute), }, err: nil, } authReq := newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // now fail to verify the id token token := &oauth2.Token{ AccessToken: "123", Expiry: time.Now().Add(5 * time.Minute), } token = token.WithExtra(map[string]any{ "id_token": "id_token_val", }) server.binding.OIDC.oauth2Config = &mockOAuth2Config{ tokenSource: &mockTokenSource{}, authCodeURL: webOIDCRedirectPath, token: token, err: nil, } authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, 
rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // id token nonce does not match server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: &oidc.IDToken{}, } authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // null id token claims authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: &oidc.IDToken{ Nonce: authReq.Nonce, }, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // invalid id token claims: no username authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) idToken := &oidc.IDToken{ Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // invalid id token clamims: username not a string authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) idToken = &oidc.IDToken{ Nonce: authReq.Nonce, Expiry: 
time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // invalid audience authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) idToken = &oidc.IDToken{ Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // invalid audience authReq = newOIDCPendingAuth(tokenAudienceWebAdmin) oidcMgr.addPendingAuth(authReq) idToken = &oidc.IDToken{ Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // mapped user not found authReq = newOIDCPendingAuth(tokenAudienceWebAdmin) oidcMgr.addPendingAuth(authReq) idToken = &oidc.IDToken{ 
Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) // admin login ok authReq = newOIDCPendingAuth(tokenAudienceWebAdmin) oidcMgr.addPendingAuth(authReq) idToken = &oidc.IDToken{ Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webUsersPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) require.Len(t, oidcMgr.tokens, 1) // admin profile is not available var tokenCookie string for k := range oidcMgr.tokens { tokenCookie = k } oidcToken, err := oidcMgr.getToken(tokenCookie) assert.NoError(t, err) assert.Equal(t, "sid123", oidcToken.SessionID) assert.True(t, oidcToken.isAdmin()) assert.False(t, oidcToken.isExpired()) rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil) assert.NoError(t, err) r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusForbidden, rr.Code) // the admin can access the allowed pages rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webUsersPath, nil) assert.NoError(t, err) 
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) // try with an invalid cookie rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webUsersPath, nil) assert.NoError(t, err) r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String())) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) // Web Client is not available with an admin token rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) assert.NoError(t, err) r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) // logout the admin user rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil) assert.NoError(t, err) r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) require.Len(t, oidcMgr.tokens, 0) // now login and logout a user username := "test_oidc_user" user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, Password: "pwd", HomeDir: filepath.Join(os.TempDir(), username), Status: 1, Permissions: map[string][]string{ "/": {dataprovider.PermAny}, }, }, Filters: dataprovider.UserFilters{ BaseUserFilters: sdk.BaseUserFilters{ WebClient: []string{sdk.WebClientSharesDisabled}, }, }, } err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) authReq = newOIDCPendingAuth(tokenAudienceWebClient) oidcMgr.addPendingAuth(authReq) idToken = &oidc.IDToken{ Nonce: authReq.Nonce, Expiry: time.Now().Add(5 * time.Minute), } setIDTokenClaims(idToken, 
[]byte(`{"preferred_username":"test_oidc_user"}`)) server.binding.OIDC.verifier = &mockOIDCVerifier{ err: nil, token: idToken, } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) require.Len(t, oidcMgr.pendingAuths, 0) require.Len(t, oidcMgr.tokens, 1) // user profile is not available for k := range oidcMgr.tokens { tokenCookie = k } oidcToken, err = oidcMgr.getToken(tokenCookie) assert.NoError(t, err) assert.Empty(t, oidcToken.SessionID) assert.False(t, oidcToken.isAdmin()) assert.False(t, oidcToken.isExpired()) if assert.Len(t, oidcToken.Permissions, 1) { assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0]) } rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil) assert.NoError(t, err) r.RequestURI = webClientProfilePath r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) // the user can access the allowed pages rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) assert.NoError(t, err) r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) // try with an invalid cookie rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) assert.NoError(t, err) r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String())) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) // Web Admin is not available with a client cookie rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webUsersPath, nil) assert.NoError(t, err) 
r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	// logout the user
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	tokenValidationMode = 0
}

// TestOIDCRefreshToken exercises oidcToken.refresh against mocked OAuth2
// configs and ID token verifiers, walking through each failure path
// (missing refresh token, token source error, missing id_token, verifier
// error, nonce mismatch, missing claims) before a successful refresh, and
// finally a refresh that fails because the token no longer maps to an
// existing account.
func TestOIDCRefreshToken(t *testing.T) {
	// Shadow the package-level manager with its concrete in-memory type so
	// the test can inspect the stored tokens directly.
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	// An admin token that expired one minute ago.
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		TokenType:   "Bearer",
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
		Nonce:       xid.New().String(),
		Role:        adminRoleFieldValue,
		Username:    defaultAdminUsername,
	}
	config := mockOAuth2Config{
		tokenSource: &mockTokenSource{
			err: common.ErrGenericFailure,
		},
	}
	verifier := mockOIDCVerifier{
		err: common.ErrGenericFailure,
	}
	// Without a refresh token the refresh is rejected with a dedicated
	// error, not the token source failure.
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "refresh token not set")
	}
	// With a refresh token, the token source error is propagated.
	token.RefreshToken = xid.New().String()
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.ErrorIs(t, err, common.ErrGenericFailure)

	// The token source now succeeds, but the returned OAuth2 token has no
	// "id_token" extra field.
	newToken := &oauth2.Token{
		AccessToken:  xid.New().String(),
		RefreshToken: xid.New().String(),
		Expiry:       time.Now().Add(5 * time.Minute),
	}
	config = mockOAuth2Config{
		tokenSource: &mockTokenSource{
			token: newToken,
		},
	}
	verifier = mockOIDCVerifier{
		token: &oidc.IDToken{},
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "the refreshed token has no id token")
	}
	// id_token present (with a zero Expiry): the verifier error is propagated.
	newToken = newToken.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	newToken.Expiry = time.Time{}
	config = mockOAuth2Config{
		tokenSource: &mockTokenSource{
			token: newToken,
		},
	}
	verifier = mockOIDCVerifier{
		err: common.ErrGenericFailure,
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.ErrorIs(t, err, common.ErrGenericFailure)

	newToken = newToken.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	newToken.Expiry = time.Now().Add(5 * time.Minute)
	config = mockOAuth2Config{
		tokenSource: &mockTokenSource{
			token: newToken,
		},
	}
	verifier = mockOIDCVerifier{
		token: &oidc.IDToken{
			Nonce: xid.New().String(), // nonce is different from the expected one
		},
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
	}
	verifier = mockOIDCVerifier{
		token: &oidc.IDToken{
			Nonce: "", // empty token is fine on refresh but claims are not set
		},
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "oidc: claims not set")
	}
	// Matching nonce and valid claims: the refresh succeeds and the admin
	// token ends up with exactly one permission.
	idToken := &oidc.IDToken{
		Nonce: token.Nonce,
	}
	setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
	verifier = mockOIDCVerifier{
		token: idToken,
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.NoError(t, err)
	assert.Len(t, token.Permissions, 1)
	// Clearing the role turns this into a non-admin token, and no user
	// exists with this username.
	token.Role = nil // user does not exist
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.Error(t, err)
	// The successful refresh above is expected to have stored the token.
	require.Len(t, oidcMgr.tokens, 1)
	oidcMgr.removeToken(token.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
}

// TestOIDCRefreshUser checks that oidcToken.refreshUser reloads the account
// from the data provider. It must reject missing and disabled accounts, and
// users that deny the HTTP protocol. It must copy the admin permissions and
// HideUserPageSections preference, or the user's WebClient restrictions,
// into the token.
func TestOIDCRefreshUser(t *testing.T) {
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		TokenType:   "Bearer",
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute)),
		Nonce:       xid.New().String(),
		Role:        adminRoleFieldValue,
		Username:    "missing username",
	}
	r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	err = token.refreshUser(r)
	assert.Error(t, err)

	// A disabled admin (Status 0) with some hidden user page sections.
	admin := dataprovider.Admin{
		Username:    "test_oidc_admin_refresh",
		Password:    "p",
		Permissions: []string{dataprovider.PermAdminAny},
		Status:      0,
		Filters: dataprovider.AdminFilters{
			Preferences: dataprovider.AdminPreferences{
				HideUserPageSections: 1 + 2 + 4,
			},
		},
	}
	err = dataprovider.AddAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	token.Username = admin.Username
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	// Once enabled, permissions and preferences are copied into the token.
	admin.Status = 1
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	err = token.refreshUser(r)
	assert.NoError(t, err)
	assert.Equal(t, admin.Permissions, token.Permissions)
	assert.Equal(t, admin.Filters.Preferences.HideUserPageSections, token.HideUserPageSections)

	err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
	assert.NoError(t, err)

	// A disabled user that also denies the HTTP protocol.
	username := "test_oidc_user_refresh_token"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "p",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   0,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				DeniedProtocols: []string{common.ProtocolHTTP},
				WebClient:       []string{sdk.WebClientSharesDisabled, sdk.WebClientWriteDisabled},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	// A nil role makes this a user (non-admin) token.
	token.Role = nil
	token.Username = username
	assert.False(t, token.isAdmin())
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	user, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	user.Status = 1
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	// Enabled but HTTP is still denied.
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
	}
	// Deny a different protocol: the refresh succeeds and the WebClient
	// restrictions become the token permissions.
	user.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.refreshUser(r)
	assert.NoError(t, err)
	assert.Equal(t, user.Filters.WebClient, token.Permissions)

	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}

// TestValidateOIDCToken covers the following cases:
//   - validateOIDCToken returns errInvalidToken when no OIDC cookie is sent.
//   - It also returns errInvalidToken when an expired token cannot be refreshed.
//   - With a JWT signer that (by its name) always fails, logout requests for
//     both web client and web admin tokens are redirected to the matching
//     login page.
func TestValidateOIDCToken(t *testing.T) {
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	server := getTestOIDCServer()
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)

	// no OIDC cookie at all
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	_, err = server.validateOIDCToken(rr, r, false)
	assert.ErrorIs(t, err, errInvalidToken)
	// expired token and refresh error
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{
			err: common.ErrGenericFailure,
		},
	}
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
	}
	oidcMgr.addToken(token)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
	_, err = server.validateOIDCToken(rr, r, false)
	assert.ErrorIs(t, err, errInvalidToken)
	oidcMgr.removeToken(token.Cookie)
	assert.Len(t, oidcMgr.tokens, 0)

	// From here on, JWT signing fails.
	server.tokenAuth.SetSigner(&failingJoseSigner{})
	token = oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: util.GenerateUniqueID(),
	}
	oidcMgr.addToken(token)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	oidcMgr.removeToken(token.Cookie)
	assert.Len(t, oidcMgr.tokens, 0)

	// same scenario with an admin token
	token = oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		Role:        "admin",
	}
	oidcMgr.addToken(token)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	oidcMgr.removeToken(token.Cookie)
	assert.Len(t, oidcMgr.tokens, 0)
}

// TestSkipOIDCAuth sends a web client logout request carrying a regular JWT
// cookie (not an OIDC cookie) to an OIDC-enabled server and expects a
// redirect to the web client login page.
func TestSkipOIDCAuth(t *testing.T) {
	server := getTestOIDCServer()
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)
	claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient))
	claims.Username = "user"
	tokenString, err := server.tokenAuth.Sign(claims)
	assert.NoError(t, err)
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwt.CookieKey, tokenString))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
}

// TestOIDCLogoutErrors calls the provider logout helpers in three cases: no
// logout URL, an unparsable URL (it contains a 0x7f control character), and
// an address presumably without a listener. It makes no assertions after
// these calls; it only requires that they return.
func TestOIDCLogoutErrors(t *testing.T) {
	server := getTestOIDCServer()
	assert.Empty(t, server.binding.OIDC.providerLogoutURL)
	server.logoutFromOIDCOP("")
	server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
	server.doOIDCFromLogout("")
	server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
	server.doOIDCFromLogout("")
}

// TestOIDCToken checks that oidcToken.getUser rejects each of these cases:
//   - missing accounts
//   - disabled admins and users
//   - users that deny HTTP
//   - users whose SFTP backend points back to this SFTPGo instance (the
//     "SFTP loop" case; assumes the test SFTP server listens on
//     127.0.0.1:8022)
//   - logins refused by a failing post-connect hook
func TestOIDCToken(t *testing.T) {
	admin := dataprovider.Admin{
		Username:    "test_oidc_admin",
		Password:    "p",
		Permissions: []string{dataprovider.PermAdminAny},
		Status:      0,
	}
	err := dataprovider.AddAdmin(&admin, "", "", "")
	assert.NoError(t, err)

	token := oidcToken{
		Username: admin.Username,
	}
	// role not initialized, user with the specified username does not exist
	req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	err = token.getUser(req)
	assert.ErrorIs(t, err, util.ErrNotFound)
	// with the admin role the (disabled) admin is found
	token.Role = "admin"
	req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
	assert.NoError(t, err)

	username := "test_oidc_user"
	token.Username = username
	token.Role = ""
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.ErrorIs(t, err, util.ErrNotFound)
	}

	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "p",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   0,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				DeniedProtocols:    []string{common.ProtocolHTTP},
				DeniedLoginMethods: []string{dataprovider.LoginMethodPassword},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	user, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	user.Status = 1
	user.Password = "np"
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
	}

	// Point the user's SFTP backend at the local test SFTP endpoint with the
	// same username, which is expected to be detected as a loop.
	user.Filters.DeniedProtocols = nil
	user.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: "127.0.0.1:8022",
			Username: username,
		},
		Password: kms.NewPlainSecret("np"),
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "SFTP loop")
	}

	// A post-connect hook pointing to a 404 endpoint of the mock server
	// denies the login.
	common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)

	err = token.getUser(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "access denied")
	}

	// restore the global configuration for the other tests
	common.Config.PostConnectHook = ""

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}

// TestOIDCImplicitRoles runs the full login/logout flow with ImplicitRoles
// enabled. An admin login is redirected to the users page, and its token
// cannot access the web client. A user login is redirected to the files
// page. Each logout removes the stored token.
func TestOIDCImplicitRoles(t *testing.T) {
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)

	server := getTestOIDCServer()
	server.binding.OIDC.ImplicitRoles = true
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)

	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	token := &oauth2.Token{
		AccessToken: "1234",
		Expiry:      time.Now().Add(5 * time.Minute),
	}
	token = token.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token:       token,
	}
	idToken := &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 1)
	// exactly one token is stored: grab its cookie
	var tokenCookie string
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}
	// Web Client is not available with an admin token
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// logout the admin user
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)
	// now login and logout a user
	username := "test_oidc_implicit_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "pwd",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				WebClient: []string{sdk.WebClientSharesDisabled},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_implicit_user"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 1)
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}

	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}

// TestMemoryOIDCManager unit-tests the in-memory manager. For pending auth
// requests it covers add, get, remove, expiry and cleanup. For tokens it
// covers add, get, throttled usage updates, expiry, remove and cleanup.
func TestMemoryOIDCManager(t *testing.T) {
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	require.Len(t, oidcMgr.pendingAuths, 0)
	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	require.Len(t, oidcMgr.pendingAuths, 1)
	_, err := oidcMgr.getPendingAuth(authReq.State)
	assert.NoError(t, err)
	oidcMgr.removePendingAuth(authReq.State)
	require.Len(t, oidcMgr.pendingAuths, 0)
	// a pending auth issued 10 minutes ago must be reported as too old and
	// removed by cleanup
	authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-600 * time.Second))
	oidcMgr.addPendingAuth(authReq)
	require.Len(t, oidcMgr.pendingAuths, 1)
	_, err = oidcMgr.getPendingAuth(authReq.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "too old")
	}
	oidcMgr.cleanup()
	require.Len(t, oidcMgr.pendingAuths, 0)

	token := oidcToken{
		AccessToken: xid.New().String(),
		Nonce:       xid.New().String(),
		SessionID:   xid.New().String(),
		Cookie:      util.GenerateOpaqueString(),
		Username:    xid.New().String(),
		Role:        "admin",
		Permissions: []string{dataprovider.PermAdminAny},
	}
	require.Len(t, oidcMgr.tokens, 0)
	oidcMgr.addToken(token)
	require.Len(t, oidcMgr.tokens, 1)
	_, err = oidcMgr.getToken(xid.New().String())
	assert.Error(t, err)
	storedToken, err := oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	token.UsedAt = 0 // ensure we don't modify the stored token
	assert.Greater(t, storedToken.UsedAt, int64(0))
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// the usage will not be updated, it is recent
	oidcMgr.updateTokenUsage(storedToken)
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, token, storedToken)
	// backdate the usage by 5 minutes: getToken still returns it and
	// updateTokenUsage now refreshes UsedAt
	usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
	storedToken.UsedAt = usedAt
	oidcMgr.tokens[token.Cookie] = storedToken
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, usedAt, storedToken.UsedAt)
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	oidcMgr.updateTokenUsage(storedToken)
	storedToken, err = oidcMgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Greater(t, storedToken.UsedAt, usedAt)
	token.UsedAt = storedToken.UsedAt
	assert.Equal(t, token, storedToken)
	// one millisecond past tokenDeleteInterval the token is rejected
	storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1
	oidcMgr.tokens[token.Cookie] = storedToken
	storedToken, err = oidcMgr.getToken(token.Cookie)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "token is too old")
	}
	// removing an unknown cookie is a no-op
	oidcMgr.removeToken(xid.New().String())
	require.Len(t, oidcMgr.tokens, 1)
	oidcMgr.removeToken(token.Cookie)
	require.Len(t, oidcMgr.tokens, 0)

	// cleanup removes only the token unused for 6 hours
	oidcMgr.addToken(token)
	usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
	token.UsedAt = usedAt
	oidcMgr.tokens[token.Cookie] = token
	newToken := oidcToken{
		Cookie: util.GenerateOpaqueString(),
	}
	oidcMgr.addToken(newToken)
	oidcMgr.cleanup()
	require.Len(t, oidcMgr.tokens, 1)
	_, err = oidcMgr.getToken(token.Cookie)
	assert.Error(t, err)
	_, err = oidcMgr.getToken(newToken.Cookie)
	assert.NoError(t, err)
	oidcMgr.removeToken(newToken.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
}

// TestOIDCEvMgrIntegration configures an IDP account check event action
// whose user/admin templates use the {{.Name}} and {{.IDPField...}}
// placeholders. It verifies that an OIDC login creates the user or admin
// from those templates, with custom claim values substituted. With invalid
// (empty) templates, the test only asserts that the redirect status is
// still returned. The data provider is re-initialized with
// NamingRules = 5 and restored at the end.
func TestOIDCEvMgrIntegration(t *testing.T) {
	providerConf := dataprovider.GetProviderConfig()
	err := dataprovider.Close()
	assert.NoError(t, err)
	newProviderConf := providerConf
	newProviderConf.NamingRules = 5
	err = dataprovider.Initialize(newProviderConf, configDir, true)
	assert.NoError(t, err)
	// add special characters to check the JSON replacer
	username := `test_'oidc_eventmanager`
	u := map[string]any{
		"username": "{{.Name}}",
		"status":   1,
		"home_dir": filepath.Join(os.TempDir(), "{{.IDPFieldcustom1.sub}}"),
		"permissions": map[string][]string{
			"/": {dataprovider.PermAny},
		},
		"description": "{{.IDPFieldcustom2}}",
	}
	userTmpl, err := json.Marshal(u)
	require.NoError(t, err)
	a := map[string]any{
		"username":    "{{.Name}}",
		"status":      1,
		"permissions": []string{dataprovider.PermAdminAny},
	}
	adminTmpl, err := json.Marshal(a)
	require.NoError(t, err)

	action := &dataprovider.BaseEventAction{
		Name: "a",
		Type: dataprovider.ActionTypeIDPAccountCheck,
		Options: dataprovider.BaseEventActionOptions{
			IDPConfig: dataprovider.EventActionIDPAccountCheck{
				Mode:          0,
				TemplateUser:  string(userTmpl),
				TemplateAdmin: string(adminTmpl),
			},
		},
	}
	err = dataprovider.AddEventAction(action, "", "", "")
	assert.NoError(t, err)
	// the rule runs the action synchronously on IDP logins
	rule := &dataprovider.EventRule{
		Name:    "r",
		Status:  1,
		Trigger: dataprovider.EventTriggerIDPLogin,
		Conditions: dataprovider.EventConditions{
			IDPLoginEvent: 0,
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	err = dataprovider.AddEventRule(rule, "", "", "")
	assert.NoError(t, err)

	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	server := getTestOIDCServer()
	server.binding.OIDC.ImplicitRoles = true
	server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
	err = server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)
	// login a user with OIDC
	_, err = dataprovider.UserExists(username, "")
	assert.ErrorIs(t, err, util.ErrNotFound)
	authReq := newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	token := &oauth2.Token{
		AccessToken: "1234",
		Expiry:      time.Now().Add(5 * time.Minute),
	}
	token = token.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token:       token,
	}
	idToken := &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`)) //nolint:goconst
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	// the user was created from the template with the custom claims applied
	user, err := dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir())
	assert.Equal(t, "desc", user.Description)

	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// login an admin with OIDC
	_, err = dataprovider.AdminExists(username)
	assert.ErrorIs(t, err, util.ErrNotFound)
	authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	_, err = dataprovider.AdminExists(username)
	assert.NoError(t, err)
	err = dataprovider.DeleteAdmin(username, "", "", "")
	assert.NoError(t, err)
	// set invalid templates and try again
	action.Options.IDPConfig.TemplateUser = `{}`
	action.Options.IDPConfig.TemplateAdmin = `{}`
	err = dataprovider.UpdateEventAction(action, "", "", "")
	assert.NoError(t, err)

	for _, audience := range []string{tokenAudienceWebAdmin, tokenAudienceWebClient} {
		authReq = newOIDCPendingAuth(audience)
		oidcMgr.addPendingAuth(authReq)
		idToken = &oidc.IDToken{
			Nonce:  authReq.Nonce,
			Expiry: time.Now().Add(5 * time.Minute),
		}
		setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`))
		server.binding.OIDC.verifier = &mockOIDCVerifier{
			err:   nil,
			token: idToken,
		}
		rr = httptest.NewRecorder()
		r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
		assert.NoError(t, err)
		server.router.ServeHTTP(rr, r)
		assert.Equal(t, http.StatusFound, rr.Code)
	}
	// drop any token left behind by the logins above
	for k := range oidcMgr.tokens {
		oidcMgr.removeToken(k)
	}

	err = dataprovider.DeleteEventRule(rule.Name, "", "", "")
	assert.NoError(t, err)
	err = dataprovider.DeleteEventAction(action.Name, "", "", "")
	assert.NoError(t, err)
	// restore the original data provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestOIDCPreLoginHook installs a shell pre-login hook. When the hook
// prints the user as JSON, an OIDC login creates the user. When it prints a
// non-JSON response, the login is redirected back to the login page and no
// user is created. The original provider configuration is restored at the
// end.
func TestOIDCPreLoginHook(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	username := "test_oidc_user_prelogin"
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
	providerConf := dataprovider.GetProviderConfig()
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
	assert.NoError(t, err)
	newProviderConf := providerConf
	newProviderConf.PreLoginHook = preLoginPath
	err = dataprovider.Initialize(newProviderConf, configDir, true)
	assert.NoError(t, err)
	server := getTestOIDCServer()
	server.binding.OIDC.CustomFields = []string{"field1", "field2"}
	err = server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)

	_, err = dataprovider.UserExists(username, "")
	assert.ErrorIs(t, err, util.ErrNotFound)
	// now login with OIDC
	authReq := newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	token := &oauth2.Token{
		AccessToken: "1234",
		Expiry:      time.Now().Add(5 * time.Minute),
	}
	token = token.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token:       token,
	}
	idToken := &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	_, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)

	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(u.HomeDir)
	assert.NoError(t, err)

	// switch the hook to a non-JSON response
	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
	assert.NoError(t, err)

	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`","field1":"value1","field2":"value2","field3":"value3"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	_, err = dataprovider.UserExists(username, "")
	assert.ErrorIs(t, err, util.ErrNotFound)
	// the test expects exactly one token to remain stored: remove it
	if assert.Len(t, oidcMgr.tokens, 1) {
		for k := range oidcMgr.tokens {
			oidcMgr.removeToken(k)
		}
	}
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)

	err = dataprovider.Close()
	assert.NoError(t, err)
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	err = os.Remove(preLoginPath)
	assert.NoError(t, err)
}

// TestOIDCIsAdmin is a table test for oidcToken.isAdmin. A Role that is the
// string "admin", or a []any containing it, is admin. Any other value
// (other strings, empty slices, non-string elements, numbers, nil, maps)
// is not.
func TestOIDCIsAdmin(t *testing.T) {
	type test struct {
		input any
		want  bool
	}

	// zero capacity, so each append below allocates a distinct slice
	emptySlice := make([]any, 0)

	tests := []test{
		{input: "admin", want: true},
		{input: append(emptySlice, "admin"), want: true},
		{input: append(emptySlice, "user", "admin"), want: true},
		{input: "user", want: false},
		{input: emptySlice, want: false},
		{input: append(emptySlice, 1), want: false},
		{input: 1, want: false},
		{input: nil, want: false},
		{input: map[string]string{"admin": "admin"}, want: false},
	}
	for _, tc := range tests {
		token := oidcToken{
			Role: tc.input,
		}
		assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
	}
}

// TestParseAdminRole checks getRoleFromField with dot-separated paths into
// nested claims. Only paths that resolve to an "admin" value (or to a list
// containing it) make the token an admin, and in that case the resolved
// value is stored in token.Role.
func TestParseAdminRole(t *testing.T) {
	claims := make(map[string]any)
	rawClaims := []byte(`{ "sub": "35666371", "email": "example@example.com", "preferred_username": "Sally", "name": "Sally Tyler", "updated_at": "2018-04-13T22:08:45Z", "given_name": "Sally", "family_name": "Tyler", "params": { "sftpgo_role": "admin", "subparams": { "sftpgo_role": "admin", "inner": { "sftpgo_role": ["user","admin"] } } }, "at_hash": "lPLhxI2wjEndc-WfyroDZA", "rt_hash": "mCmxPtA04N-55AxlEUbq-A", "aud": "78d1d040-20c9-0136-5146-067351775fae92920", "exp": 1523664997, "iat": 1523657797 }`)
	err := json.Unmarshal(rawClaims, &claims)
	assert.NoError(t, err)

	type test struct {
		input string
		want  bool
		val   any
	}

	tests := []test{
		{input: "", want: false},
		{input: "sftpgo_role", want: false},
		{input: "params.sftpgo_role", want: true, val: "admin"},
		{input: "params.subparams.sftpgo_role", want: true, val: "admin"},
		{input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}},
		{input: "email", want: false},
		{input: "missing", want: false},
		{input: "params.email", want: false},
		{input: "missing.sftpgo_role", want: false},
		{input: "params", want: false},
		{input: "params.subparams.inner.sftpgo_role.missing", want: false},
	}

	for _, tc := range tests {
		token := oidcToken{}
		token.getRoleFromField(claims, tc.input)
		assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want)
		if tc.want {
			assert.Equal(t, tc.val, token.Role)
		}
	}
}

// TestOIDCWithLoginFormsDisabled sets DisabledLoginMethods to 12. Per the
// assertions below, this disables the web admin and web client login and
// password related routes. It verifies two things:
//   - An OIDC admin can still log in and create an admin without a password.
//   - Those routes answer 405 or 404.
func TestOIDCWithLoginFormsDisabled(t *testing.T) {
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)

	server := getTestOIDCServer()
	server.binding.OIDC.ImplicitRoles = true
	server.binding.DisabledLoginMethods = 12
	server.binding.EnableWebAdmin = true
	server.binding.EnableWebClient = true
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)
	// login with an admin user
	authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	token := &oauth2.Token{
		AccessToken: "1234",
		Expiry:      time.Now().Add(5 * time.Minute),
	}
	token = token.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token:       token,
	}
	idToken := &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	var tokenCookie string
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}
	// we should be able to create admins without setting a password
	adminUsername := "testAdmin"
	form := make(url.Values)
	form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath))
	form.Set("username", adminUsername)
	form.Set("password", "")
	form.Set("status", "1")
	form.Set("permissions", "*")
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusSeeOther, rr.Code)
	_, err = dataprovider.AdminExists(adminUsername)
	assert.NoError(t, err)
	err = dataprovider.DeleteAdmin(adminUsername, "", "", "")
	assert.NoError(t, err)
	// login and password related routes are disabled
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodPost, webAdminLoginPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusNotFound, rr.Code)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodPost, webClientLoginPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusNotFound, rr.Code)
}

// TestDbOIDCManager tests the data provider backed manager (newOIDCManager(1))
// when the configured provider supports shared sessions. It covers pending
// auth expiry, removal and cleanup, and token usage updates. It also covers
// shared session validation (empty key, invalid type) and the decode
// helpers' rejection of unexpected data types.
func TestDbOIDCManager(t *testing.T) {
	if !isSharedProviderSupported() {
		t.Skip("this test it is not available with this provider")
	}
	mgr := newOIDCManager(1)
	pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin)
	mgr.addPendingAuth(pendingAuth)
	authReq, err := mgr.getPendingAuth(pendingAuth.State)
	assert.NoError(t, err)
	assert.Equal(t, pendingAuth, authReq)
	pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
	mgr.addPendingAuth(pendingAuth)
	_, err = mgr.getPendingAuth(pendingAuth.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "auth request is too old")
	}
	mgr.removePendingAuth(pendingAuth.State)
	_, err = mgr.getPendingAuth(pendingAuth.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
	}
	mgr.addPendingAuth(pendingAuth)
	_, err = mgr.getPendingAuth(pendingAuth.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "auth request is too old")
	}
	mgr.cleanup()
	_, err = mgr.getPendingAuth(pendingAuth.State)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
	}

	token := oidcToken{
		Cookie:       util.GenerateOpaqueString(),
		AccessToken:  xid.New().String(),
		TokenType:    "Bearer",
		RefreshToken: xid.New().String(),
		ExpiresAt:    util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
		SessionID:    xid.New().String(),
		IDToken:      xid.New().String(),
		Nonce:        xid.New().String(),
		Username:     xid.New().String(),
		Permissions:  []string{dataprovider.PermAdminAny},
		Role:         "admin",
	}
	mgr.addToken(token)
	tokenGet, err := mgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Greater(t, tokenGet.UsedAt, int64(0))
	token.UsedAt = tokenGet.UsedAt
	assert.Equal(t, token, tokenGet)
	time.Sleep(100 * time.Millisecond)
	mgr.updateTokenUsage(token)
	// the last usage is recent: the stored UsedAt must not change
	tokenGet, err = mgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.Equal(t, token.UsedAt, tokenGet.UsedAt)
	// an old usage triggers an update that also persists the new refresh token
	tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
	tokenGet.RefreshToken = xid.New().String()
	mgr.updateTokenUsage(tokenGet)
	tokenGet, err = mgr.getToken(token.Cookie)
	assert.NoError(t, err)
	assert.NotEmpty(t, tokenGet.RefreshToken)
	assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken)
	assert.Greater(t, tokenGet.UsedAt, token.UsedAt)
	mgr.removeToken(token.Cookie)
	tokenGet, err = mgr.getToken(token.Cookie)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to get the token for the specified session")
	}
	// add an expired token
	token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
	session := dataprovider.Session{
		Key:       token.Cookie,
		Data:      token,
		Type:      dataprovider.SessionTypeOIDCToken,
		Timestamp: token.UsedAt + tokenDeleteInterval,
	}
	err = dataprovider.AddSharedSession(session)
	assert.NoError(t, err)
	_, err = mgr.getToken(token.Cookie)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "token is too old")
	}
	mgr.cleanup()
	_, err = mgr.getToken(token.Cookie)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to get the token for the specified session")
	}
	// adding a session without a key should fail
	session.Key = ""
	err = dataprovider.AddSharedSession(session)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to save a session with an empty key")
	}
	session.Key = xid.New().String()
	session.Type = 1000
	err = dataprovider.AddSharedSession(session)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid session type")
	}

	// the decode helpers must reject data of an unexpected type
	dbMgr, ok := mgr.(*dbOIDCManager)
	if assert.True(t, ok) {
		_, err = dbMgr.decodePendingAuthData(2)
		assert.Error(t, err)
		_, err = dbMgr.decodeTokenData(true)
		assert.Error(t, err)
	}
}

// getTestOIDCServer returns an httpdServer with web admin and web client
// enabled. Its OIDC binding points at the mock provider listening on
// oidcMockAddr, with debug enabled and implicit roles disabled. Callers must
// still call OIDC.initialize and initializeRouter.
func getTestOIDCServer() *httpdServer {
	return &httpdServer{
		binding: Binding{
			OIDC: OIDC{
				ClientID:        "sftpgo-client",
				ClientSecret:    "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
				ConfigURL:       fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
				RedirectBaseURL: "http://127.0.0.1:8081/",
				UsernameField:   "preferred_username",
				RoleField:       "sftpgo_role",
				ImplicitRoles:   false,
				Scopes:          []string{oidc.ScopeOpenID, "profile", "email"},
				CustomFields:    nil,
				Debug:           true,
			},
		},
		enableWebAdmin:  true,
		enableWebClient: true,
	}
}

// getPreLoginScriptContent returns the body of a /bin/sh pre-login hook.
// If nonJSONResponse is true, the script prints a plain-text line. Otherwise
// it prints the JSON-encoded user, or nothing when user.Username is empty.
//
// NOTE(review): the JSON is wrapped in single quotes, so a single quote inside
// any user field would break the generated script.
func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
	content := []byte("#!/bin/sh\n\n")
	if nonJSONResponse {
		content = append(content, []byte("echo 'text response'\n")...)
		return content
	}
	if len(user.Username) > 0 {
		u, _ := json.Marshal(user)
		content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
	}
	return content
}

================================================
FILE: internal/httpd/oidcmanager.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd import ( "encoding/json" "errors" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( oidcMgr oidcManager ) func newOIDCManager(isShared int) oidcManager { if isShared == 1 { logger.Info(logSender, "", "using provider OIDC manager") return &dbOIDCManager{} } logger.Info(logSender, "", "using memory OIDC manager") return &memoryOIDCManager{ pendingAuths: make(map[string]oidcPendingAuth), tokens: make(map[string]oidcToken), } } type oidcManager interface { addPendingAuth(pendingAuth oidcPendingAuth) removePendingAuth(state string) getPendingAuth(state string) (oidcPendingAuth, error) addToken(token oidcToken) getToken(cookie string) (oidcToken, error) removeToken(cookie string) updateTokenUsage(token oidcToken) cleanup() } type memoryOIDCManager struct { authMutex sync.RWMutex pendingAuths map[string]oidcPendingAuth tokenMutex sync.RWMutex tokens map[string]oidcToken } func (o *memoryOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { o.authMutex.Lock() o.pendingAuths[pendingAuth.State] = pendingAuth o.authMutex.Unlock() } func (o *memoryOIDCManager) removePendingAuth(state string) { o.authMutex.Lock() defer o.authMutex.Unlock() delete(o.pendingAuths, state) } func (o *memoryOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { o.authMutex.RLock() defer o.authMutex.RUnlock() authReq, ok := o.pendingAuths[state] if !ok { return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state") } diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt if diff > authStateValidity { return oidcPendingAuth{}, errors.New("oidc: auth request is too old") } return authReq, nil } func (o *memoryOIDCManager) addToken(token oidcToken) { o.tokenMutex.Lock() token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) o.tokens[token.Cookie] = token o.tokenMutex.Unlock() } func (o *memoryOIDCManager) 
getToken(cookie string) (oidcToken, error) { o.tokenMutex.RLock() defer o.tokenMutex.RUnlock() token, ok := o.tokens[cookie] if !ok { return oidcToken{}, errors.New("oidc: no token found for the specified session") } diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt if diff > tokenDeleteInterval { return oidcToken{}, errors.New("oidc: token is too old") } return token, nil } func (o *memoryOIDCManager) removeToken(cookie string) { o.tokenMutex.Lock() defer o.tokenMutex.Unlock() delete(o.tokens, cookie) } func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) { diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt if diff > tokenUpdateInterval { o.addToken(token) } } func (o *memoryOIDCManager) cleanup() { o.cleanupAuthRequests() o.cleanupTokens() } func (o *memoryOIDCManager) cleanupAuthRequests() { o.authMutex.Lock() defer o.authMutex.Unlock() for k, auth := range o.pendingAuths { diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt // remove old pending auth requests if diff < 0 || diff > authStateValidity { delete(o.pendingAuths, k) } } } func (o *memoryOIDCManager) cleanupTokens() { o.tokenMutex.Lock() defer o.tokenMutex.Unlock() for k, token := range o.tokens { diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt // remove tokens unused from more than tokenDeleteInterval if diff > tokenDeleteInterval { delete(o.tokens, k) } } } type dbOIDCManager struct{} func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { session := dataprovider.Session{ Key: pendingAuth.State, Data: pendingAuth, Type: dataprovider.SessionTypeOIDCAuth, Timestamp: pendingAuth.IssuedAt + authStateValidity, } dataprovider.AddSharedSession(session) //nolint:errcheck } func (o *dbOIDCManager) removePendingAuth(state string) { dataprovider.DeleteSharedSession(state, dataprovider.SessionTypeOIDCAuth) //nolint:errcheck } func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { session, err := 
dataprovider.GetSharedSession(state, dataprovider.SessionTypeOIDCAuth) if err != nil { return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state") } if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { // expired return oidcPendingAuth{}, errors.New("oidc: auth request is too old") } return o.decodePendingAuthData(session.Data) } func (o *dbOIDCManager) decodePendingAuthData(data any) (oidcPendingAuth, error) { if val, ok := data.([]byte); ok { authReq := oidcPendingAuth{} err := json.Unmarshal(val, &authReq) return authReq, err } logger.Error(logSender, "", "invalid oidc auth request data type %T", data) return oidcPendingAuth{}, errors.New("oidc: invalid auth request data") } func (o *dbOIDCManager) addToken(token oidcToken) { token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) session := dataprovider.Session{ Key: token.Cookie, Data: token, Type: dataprovider.SessionTypeOIDCToken, Timestamp: token.UsedAt + tokenDeleteInterval, } dataprovider.AddSharedSession(session) //nolint:errcheck } func (o *dbOIDCManager) removeToken(cookie string) { dataprovider.DeleteSharedSession(cookie, dataprovider.SessionTypeOIDCToken) //nolint:errcheck } func (o *dbOIDCManager) updateTokenUsage(token oidcToken) { diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt if diff > tokenUpdateInterval { o.addToken(token) } } func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) { session, err := dataprovider.GetSharedSession(cookie, dataprovider.SessionTypeOIDCToken) if err != nil { return oidcToken{}, errors.New("oidc: unable to get the token for the specified session") } if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { // expired return oidcToken{}, errors.New("oidc: token is too old") } return o.decodeTokenData(session.Data) } func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) { if val, ok := data.([]byte); ok { token := oidcToken{} err := json.Unmarshal(val, &token) return token, 
err } logger.Error(logSender, "", "invalid oidc token data type %T", data) return oidcToken{}, errors.New("oidc: invalid token data") } func (o *dbOIDCManager) cleanup() { dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck } ================================================ FILE: internal/httpd/resetcode.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "encoding/json" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( resetCodeLifespan = 10 * time.Minute resetCodesMgr resetCodeManager ) type resetCodeManager interface { Add(code *resetCode) error Get(code string) (*resetCode, error) Delete(code string) error Cleanup() } func newResetCodeManager(isShared int) resetCodeManager { if isShared == 1 { logger.Info(logSender, "", "using provider reset code manager") return &dbResetCodeManager{} } logger.Info(logSender, "", "using memory reset code manager") return &memoryResetCodeManager{} } type resetCode struct { Code string `json:"code"` Username string `json:"username"` IsAdmin bool `json:"is_admin"` ExpiresAt time.Time `json:"expires_at"` } func newResetCode(username string, isAdmin bool) *resetCode { return &resetCode{ Code: util.GenerateUniqueID(), Username: username, IsAdmin: isAdmin, ExpiresAt: time.Now().Add(resetCodeLifespan).UTC(), } } func (c *resetCode) isExpired() bool { return c.ExpiresAt.Before(time.Now().UTC()) } type memoryResetCodeManager struct { resetCodes sync.Map } func (m *memoryResetCodeManager) Add(code *resetCode) error { m.resetCodes.Store(code.Code, code) return nil } func (m *memoryResetCodeManager) Get(code string) (*resetCode, error) { c, ok := m.resetCodes.Load(code) if !ok { return nil, util.NewRecordNotFoundError("reset code not found") } return c.(*resetCode), nil } func (m *memoryResetCodeManager) Delete(code string) error { m.resetCodes.Delete(code) return nil } func (m *memoryResetCodeManager) Cleanup() { m.resetCodes.Range(func(key, value any) bool { c, ok := value.(*resetCode) if !ok || c.isExpired() { m.resetCodes.Delete(key) } return true }) } type dbResetCodeManager struct{} func (m *dbResetCodeManager) Add(code *resetCode) error { session := dataprovider.Session{ Key: code.Code, Data: code, Type: dataprovider.SessionTypeResetCode, 
Timestamp: util.GetTimeAsMsSinceEpoch(code.ExpiresAt), } return dataprovider.AddSharedSession(session) } func (m *dbResetCodeManager) Get(code string) (*resetCode, error) { session, err := dataprovider.GetSharedSession(code, dataprovider.SessionTypeResetCode) if err != nil { return nil, err } if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { // expired return nil, util.NewRecordNotFoundError("reset code expired") } return m.decodeData(session.Data) } func (m *dbResetCodeManager) decodeData(data any) (*resetCode, error) { if val, ok := data.([]byte); ok { c := &resetCode{} err := json.Unmarshal(val, c) return c, err } logger.Error(logSender, "", "invalid reset code data type %T", data) return nil, util.NewRecordNotFoundError("invalid reset code") } func (m *dbResetCodeManager) Delete(code string) error { return dataprovider.DeleteSharedSession(code, dataprovider.SessionTypeResetCode) } func (m *dbResetCodeManager) Cleanup() { dataprovider.CleanupSharedSessions(dataprovider.SessionTypeResetCode, time.Now()) //nolint:errcheck } ================================================ FILE: internal/httpd/resources.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !bundle package httpd import ( "net/http" "github.com/go-chi/chi/v5" ) func serveStaticDir(router chi.Router, path, fsDirPath string, disableDirectoryIndex bool) { fileServer(router, path, http.Dir(fsDirPath), disableDirectoryIndex) } ================================================ FILE: internal/httpd/resources_embedded.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build bundle package httpd import ( "net/http" "github.com/go-chi/chi/v5" "github.com/drakkan/sftpgo/v2/internal/bundle" ) func serveStaticDir(router chi.Router, path, fsDirPath string, disableDirectoryIndex bool) { switch path { case webStaticFilesPath: fileServer(router, path, bundle.GetStaticFs(), disableDirectoryIndex) case webOpenAPIPath: fileServer(router, path, bundle.GetOpenAPIFs(), disableDirectoryIndex) default: fileServer(router, path, http.Dir(fsDirPath), disableDirectoryIndex) } } ================================================ FILE: internal/httpd/server.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "context" "crypto/rand" "crypto/tls" "crypto/x509" "errors" "fmt" "log" "net" "net/http" "net/url" "path" "path/filepath" "slices" "strings" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/go-jose/go-jose/v4" "github.com/rs/cors" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/unrolled/secure" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( jsonAPISuffix = "/json" ) var ( compressor = middleware.NewCompressor(5) xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") ) type httpdServer struct { binding Binding staticFilesPath string openAPIPath string enableWebAdmin bool enableWebClient bool enableRESTAPI bool renderOpenAPI bool isShared int router *chi.Mux tokenAuth *jwt.Signer csrfTokenAuth *jwt.Signer signingPassphrase string cors CorsConfig } func newHttpdServer(b Binding, staticFilesPath, signingPassphrase string, cors CorsConfig, openAPIPath string, ) *httpdServer { if openAPIPath == "" { b.RenderOpenAPI = false } return &httpdServer{ binding: b, staticFilesPath: staticFilesPath, openAPIPath: openAPIPath, enableWebAdmin: b.EnableWebAdmin, enableWebClient: b.EnableWebClient, enableRESTAPI: b.EnableRESTAPI, 
renderOpenAPI: b.RenderOpenAPI, signingPassphrase: signingPassphrase, cors: cors, } } func (s *httpdServer) setShared(value int) { s.isShared = value } func (s *httpdServer) listenAndServe() error { if err := s.initializeRouter(); err != nil { return err } httpServer := &http.Server{ Handler: s.router, ReadHeaderTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second, MaxHeaderBytes: 1 << 16, // 64KB ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), } if certMgr != nil && s.binding.EnableHTTPS { certID := common.DefaultTLSKeyPaidID if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" { certID = s.binding.GetAddress() } config := &tls.Config{ GetCertificate: certMgr.GetCertificateFunc(certID), MinVersion: util.GetTLSVersion(s.binding.MinTLSVersion), NextProtos: util.GetALPNProtocols(s.binding.Protocols), CipherSuites: util.GetTLSCiphersFromNames(s.binding.TLSCipherSuites), } httpServer.TLSConfig = config logger.Debug(logSender, "", "configured TLS cipher suites for binding %q: %v, certID: %v", s.binding.GetAddress(), httpServer.TLSConfig.CipherSuites, certID) if s.binding.isMutualTLSEnabled() { httpServer.TLSConfig.ClientCAs = certMgr.GetRootCAs() httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert httpServer.TLSConfig.VerifyConnection = s.verifyTLSConnection } return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, true, s.binding.listenerWrapper(), logSender) } return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, false, s.binding.listenerWrapper(), logSender) } func (s *httpdServer) verifyTLSConnection(state tls.ConnectionState) error { if certMgr != nil { var clientCrt *x509.Certificate var clientCrtName string if len(state.PeerCertificates) > 0 { clientCrt = state.PeerCertificates[0] clientCrtName = clientCrt.Subject.String() } if len(state.VerifiedChains) == 0 { logger.Warn(logSender, "", "TLS connection cannot be verified: 
unable to get verification chain") return errors.New("TLS connection cannot be verified: unable to get verification chain") } for _, verifiedChain := range state.VerifiedChains { var caCrt *x509.Certificate if len(verifiedChain) > 0 { caCrt = verifiedChain[len(verifiedChain)-1] } if certMgr.IsRevoked(clientCrt, caCrt) { logger.Debug(logSender, "", "tls handshake error, client certificate %q has been revoked", clientCrtName) return common.ErrCrtRevoked } } } return nil } func (s *httpdServer) refreshCookie(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.checkCookieExpiration(w, r) next.ServeHTTP(w, r) }) } func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := loginPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nLoginTitle, CurrentURL: webClientLoginPath, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), FormDisabled: s.binding.isWebClientLoginFormDisabled(), CheckRedirect: true, } if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { data.CurrentURL += "?next=" + url.QueryEscape(next) } if s.binding.showAdminLoginURL() { data.AltLoginURL = webAdminLoginPath data.AltLoginName = s.binding.webAdminBranding().ShortName } if smtp.IsEnabled() && !data.FormDisabled { data.ForgotPwdURL = webClientForgotPwdPath } if s.binding.OIDC.isEnabled() && !s.binding.isWebClientOIDCLoginDisabled() { data.OpenIDLoginURL = webClientOIDCLoginPath } renderClientTemplate(w, templateCommonLogin, data) } func (s *httpdServer) handleWebClientLogout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) removeCookie(w, r, webBaseClientPath) s.logoutOIDCUser(w, r) http.Redirect(w, r, webClientLoginPath, http.StatusFound) } func (s *httpdServer) handleWebClientChangePwdPost(w 
http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) if err := r.ParseForm(); err != nil { s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } err := doChangeUserPassword(r, strings.TrimSpace(r.Form.Get("current_password")), strings.TrimSpace(r.Form.Get("new_password1")), strings.TrimSpace(r.Form.Get("new_password2"))) if err != nil { s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) return } s.handleWebClientLogout(w, r) } func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) if !dataprovider.HasAdmin() { http.Redirect(w, r, webAdminSetupPath, http.StatusFound) return } msg := getFlashMessage(w, r) s.renderClientLoginPage(w, r, msg.getI18nError()) } func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } protocol := common.ProtocolHTTP username := strings.TrimSpace(r.Form.Get("username")) password := r.Form.Get("password") if username == "" || password == "" { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r) s.renderClientLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, 
ipAddr, err, r) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) } if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err, r) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message)) return } user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol) if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) s.renderClientLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message)) return } defer user.CloseFs() //nolint:errcheck err = user.CheckFsRoot(connectionID) if err != nil { logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric)) return } s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage) } func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } newPassword := 
strings.TrimSpace(r.Form.Get("password")) confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password")) _, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), newPassword, confirmPassword, false) if err != nil { s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) return } connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) if err := checkHTTPClientUser(user, r, connectionID, true, false); err != nil { s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset)) return } defer user.CloseFs() //nolint:errcheck err = user.CheckFsRoot(connectionID) if err != nil { logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset)) return } s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage) } func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) claims, err := jwt.FromContext(r.Context()) if err != nil { s.renderNotFoundPage(w, r, nil) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := claims.Username recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code")) if username == "" || recoveryCode == "" { s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } user, userMerged, err := dataprovider.GetUserVariants(username, "") if err != nil { if errors.Is(err, util.ErrNotFound) { handleDefenderEventLoginFailed(ipAddr, err) 
//nolint:errcheck } s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } if !userMerged.Filters.TOTPConfig.Enabled || !slices.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { s.renderClientTwoFactorPage(w, r, util.NewI18nError( util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled)) return } for idx, code := range user.Filters.RecoveryCodes { if err := code.Secret.Decrypt(); err != nil { s.renderClientInternalServerErrorPage(w, r, fmt.Errorf("unable to decrypt recovery code: %w", err)) return } if code.Secret.GetPayload() == recoveryCode { if code.Used { s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } user.Filters.RecoveryCodes[idx].Used = true err = dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, ipAddr, user.Role) if err != nil { logger.Warn(logSender, "", "unable to set the recovery code %q as used: %v", recoveryCode, err) s.renderClientInternalServerErrorPage(w, r, errors.New("unable to set the recovery code as used")) return } connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) s.loginUser(w, r, &userMerged, connectionID, ipAddr, true, s.renderClientTwoFactorRecoveryPage) return } } handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) } func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) claims, err := jwt.FromContext(r.Context()) if err != nil { s.renderNotFoundPage(w, r, nil) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { s.renderClientTwoFactorPage(w, r, 
util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := claims.Username passcode := strings.TrimSpace(r.Form.Get("passcode")) if username == "" || passcode == "" { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r) s.renderClientTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err, r) s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } user, err := dataprovider.GetUserWithGroupSettings(username, "") if err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err, r) s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials)) return } if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled)) return } err = user.Filters.TOTPConfig.Secret.Decrypt() if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) s.renderClientInternalServerErrorPage(w, r, err) return } match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, passcode, user.Filters.TOTPConfig.Secret.GetPayload()) if !match || err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials, r) s.renderClientTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) 
return } connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String()) s.loginUser(w, r, &user, connectionID, ipAddr, true, s.renderClientTwoFactorPage) }

// handleWebAdminTwoFactorRecoveryPost completes a partial (first factor only)
// web admin login by using one of the admin's recovery codes instead of a
// TOTP passcode. The username is taken from the JWT claims already in the
// request context. A matching, unused recovery code is marked as used and
// persisted before the full login cookie is issued via loginAdmin. Unknown
// admins and codes that match nothing are reported to the defender.
func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		s.renderNotFoundPage(w, r, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	username := claims.Username
	recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
	if username == "" || recoveryCode == "" {
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		if errors.Is(err, util.ErrNotFound) {
			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		}
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if !admin.Filters.TOTPConfig.Enabled {
		// user-visible message: "two factor", not "two factory"
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factor authentication is not enabled"), util.I18n2FADisabled))
		return
	}
	for idx, code := range admin.Filters.RecoveryCodes {
		if err := code.Secret.Decrypt(); err != nil {
			s.renderInternalServerErrorPage(w, r, fmt.Errorf("unable to decrypt recovery code: %w", err))
			return
		}
		if code.Secret.GetPayload() == recoveryCode {
			// a recovery code can be used only once
			if code.Used {
				s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
				return
			}
			// mark the code as used through the slice index: code is a copy
			admin.Filters.RecoveryCodes[idx].Used = true
			err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role)
			if err != nil {
				logger.Warn(logSender, "", "unable to set the recovery code %q as used: %v", recoveryCode, err)
				s.renderInternalServerErrorPage(w, r, errors.New("unable to set the recovery code as used"))
				return
			}
			s.loginAdmin(w, r, &admin, true, s.renderTwoFactorRecoveryPage, ipAddr)
			return
		}
	}
	handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
	s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
}

// handleWebAdminTwoFactorPost completes a partial web admin login by
// validating the submitted TOTP passcode against the admin's TOTP secret.
// The username comes from the JWT claims in the request context. CSRF
// failures, unknown admins and wrong passcodes are reported to the defender.
func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		s.renderNotFoundPage(w, r, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	username := claims.Username
	passcode := strings.TrimSpace(r.Form.Get("passcode"))
	if username == "" || passcode == "" {
		s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		if errors.Is(err, util.ErrNotFound) {
			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		}
		s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
		return
	}
	if !admin.Filters.TOTPConfig.Enabled {
		s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
		return
	}
	err = admin.Filters.TOTPConfig.Secret.Decrypt()
	if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	match, err := mfa.ValidateTOTPPasscode(admin.Filters.TOTPConfig.ConfigName, passcode, admin.Filters.TOTPConfig.Secret.GetPayload())
	if !match || err != nil {
		handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
		s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	s.loginAdmin(w, r, &admin, true, s.renderTwoFactorPage, ipAddr)
}

// handleWebAdminLoginPost processes the web admin login form: it parses the
// form, rejects empty credentials, verifies the login cookie and CSRF token,
// checks username/password (failures are reported to the defender) and then
// hands over to loginAdmin, which may still require a second factor.
func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	username := strings.TrimSpace(r.Form.Get("username"))
	password := strings.TrimSpace(r.Form.Get("password"))
	if username == "" || password == "" {
		s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr)
	if err != nil {
		handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	s.loginAdmin(w, r, &admin, false, s.renderAdminLoginPage, ipAddr)
}

// renderAdminLoginPage renders the shared login template for the web admin with a fresh CSRF token and, when applicable, links to the web client login, password recovery (SMTP enabled and form not disabled) and OIDC login.
func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := loginPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nLoginTitle, CurrentURL: webAdminLoginPath, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath), Branding: s.binding.webAdminBranding(), Languages: s.binding.languages(), FormDisabled:
s.binding.isWebAdminLoginFormDisabled(), CheckRedirect: false, } if s.binding.showClientLoginURL() { data.AltLoginURL = webClientLoginPath data.AltLoginName = s.binding.webClientBranding().ShortName } if smtp.IsEnabled() && !data.FormDisabled { data.ForgotPwdURL = webAdminForgotPwdPath } if s.binding.OIDC.hasRoles() && !s.binding.isWebAdminOIDCLoginDisabled() { data.OpenIDLoginURL = webAdminOIDCLoginPath } renderAdminTemplate(w, templateCommonLogin, data) } /* handleWebAdminLogin renders the admin login page (including any pending flash message); if no admin account exists yet it redirects to the initial setup page instead. */ func (s *httpdServer) handleWebAdminLogin(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) if !dataprovider.HasAdmin() { http.Redirect(w, r, webAdminSetupPath, http.StatusFound) return } msg := getFlashMessage(w, r) s.renderAdminLoginPage(w, r, msg.getI18nError()) } /* handleWebAdminLogout removes the web admin cookie, calls logoutOIDCUser and redirects to the admin login page. */ func (s *httpdServer) handleWebAdminLogout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) removeCookie(w, r, webBaseAdminPath) s.logoutOIDCUser(w, r) http.Redirect(w, r, webAdminLoginPath, http.StatusFound) } /* handleWebAdminChangePwdPost processes the admin change-password form: it parses the form, verifies the CSRF token (a failure renders the forbidden page), changes the password through doChangeAdminPassword and, on success, logs the admin out so a new login is required. */ func (s *httpdServer) handleWebAdminChangePwdPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) err := r.ParseForm() if err != nil { s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } err = doChangeAdminPassword(r, strings.TrimSpace(r.Form.Get("current_password")), strings.TrimSpace(r.Form.Get("new_password1")), strings.TrimSpace(r.Form.Get("new_password2"))) if err != nil { s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) return } s.handleWebAdminLogout(w, r) } /* handleWebAdminPasswordResetPost processes the admin reset-password form: it verifies the login cookie and CSRF token, applies the reset code and new password through handleResetPassword and logs in the admin it returns. */ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err
:= r.ParseForm() if err != nil { s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } newPassword := strings.TrimSpace(r.Form.Get("password")) confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password")) admin, _, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), newPassword, confirmPassword, true) if err != nil { s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) return } s.loginAdmin(w, r, admin, false, s.renderResetPwdPage, ipAddr) } /* handleWebAdminSetupPost creates the first admin account. It is refused once any admin exists; when an installation code is configured the submitted install_code must match resolveInstallationCode(). The new admin gets PermAdminAny and is logged in with a nil errorFunc, which makes loginAdmin redirect to the MFA page on success. */ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) if dataprovider.HasAdmin() { s.renderBadRequestPage(w, r, errors.New("an admin user already exists")) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { s.renderAdminSetupPage(w, r, "", util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } username := strings.TrimSpace(r.Form.Get("username")) password := strings.TrimSpace(r.Form.Get("password")) confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password")) installCode := strings.TrimSpace(r.Form.Get("install_code")) if installationCode != "" && installCode != resolveInstallationCode() { s.renderAdminSetupPage(w, r, username, util.NewI18nError( util.NewValidationError(fmt.Sprintf("%v mismatch", installationCodeHint)), util.I18nErrorSetupInstallCode), ) return } if username == "" { s.renderAdminSetupPage(w, r, username, util.NewI18nError(util.NewValidationError("please set a username"), util.I18nError500Message)) return } if password == "" { s.renderAdminSetupPage(w, r,
username, util.NewI18nError(util.NewValidationError("please set a password"), util.I18nError500Message)) return } if password != confirmPassword { s.renderAdminSetupPage(w, r, username, util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)) return } admin := dataprovider.Admin{ Username: username, Password: password, Status: 1, Permissions: []string{dataprovider.PermAdminAny}, } err = dataprovider.AddAdmin(&admin, username, ipAddr, "") if err != nil { s.renderAdminSetupPage(w, r, username, util.NewI18nError(err, util.I18nError500Message)) return } s.loginAdmin(w, r, &admin, false, nil, ipAddr) } /* loginUser sets the web client JWT cookie for an authenticated user and redirects. When TOTP is enabled for HTTP, the user can manage MFA and this is not the second-factor step, only a partial-audience token is issued and the user is sent to the two-factor page. A "next" query value is forwarded only if it starts with webClientFilesPath (presumably to avoid open redirects). errorFunc renders cookie creation failures. */ func (s *httpdServer) loginUser( w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string, isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError), ) { c := &jwt.Claims{ Username: user.Username, Permissions: user.Filters.WebClient, Role: user.Role, MustSetTwoFactorAuth: user.MustSetSecondFactor(), MustChangePassword: user.MustChangePassword(), RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols, } c.Subject = user.GetSignature() audience := tokenAudienceWebClient if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) && user.CanManageMFA() && !isSecondFactorAuth { audience = tokenAudienceWebClientPartial } err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr) if err != nil { logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err) updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message)) return } invalidateToken(r) if audience == tokenAudienceWebClientPartial { redirectPath := webClientTwoFactorPath if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { redirectPath += "?next=" + url.QueryEscape(next) }
http.Redirect(w, r, redirectPath, http.StatusFound) return } updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, err, r) dataprovider.UpdateLastLogin(user) if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { http.Redirect(w, r, next, http.StatusFound) return } http.Redirect(w, r, webClientFilesPath, http.StatusFound) } /* loginAdmin sets the web admin JWT cookie and redirects. With TOTP enabled, MFA manageable and isSecondFactorAuth false, only a partial token is issued and the admin is sent to the two-factor page. A nil errorFunc marks the initial setup flow: cookie errors then render the setup page and a successful login redirects to the MFA page instead of the users page. */ func (s *httpdServer) loginAdmin( w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin, isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError), ipAddr string, ) { c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, Role: admin.Role, HideUserPageSections: admin.Filters.Preferences.HideUserPageSections, MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled, MustChangePassword: admin.Filters.RequirePasswordChange, } c.Subject = admin.GetSignature() audience := tokenAudienceWebAdmin if admin.Filters.TOTPConfig.Enabled && admin.CanManageMFA() && !isSecondFactorAuth { audience = tokenAudienceWebAdminPartial } err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr) if err != nil { logger.Warn(logSender, "", "unable to set admin login cookie %v", err) if errorFunc == nil { s.renderAdminSetupPage(w, r, admin.Username, util.NewI18nError(err, util.I18nError500Message)) return } errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message)) return } invalidateToken(r) if audience == tokenAudienceWebAdminPartial { http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound) return } dataprovider.UpdateAdminLastLogin(admin) common.DelayLogin(nil) redirectURL := webUsersPath if errorFunc == nil { redirectURL = webAdminMFAPath } http.Redirect(w, r, redirectURL, http.StatusFound) } /* logout invalidates the token used for the current REST API request and replies with a confirmation message. */ func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) invalidateToken(r) sendAPIResponse(w, r, nil, "Your token has been invalidated",
http.StatusOK) }

// getUserToken authenticates a user via HTTP basic auth for the REST API and
// returns a signed API user token. When TOTP is enabled for the HTTP protocol
// the passcode must be supplied in the otpHeaderCode request header. Every
// outcome is recorded through updateLoginMetrics. The post-connect hook runs
// before the credentials are checked; the user's filesystem root is checked
// (and the filesystem closed on return) before the token is issued.
func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	username, password, ok := r.BasicAuth()
	protocol := common.ProtocolHTTP
	if !ok {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r)
		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}
	if username == "" || strings.TrimSpace(password) == "" {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r)
		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}
	if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err, r)
		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return
	}
	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol)
	if err != nil {
		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
		sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}
	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
	if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil {
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return
	}
	if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
		passcode := r.Header.Get(otpHeaderCode)
		if passcode == "" {
			logger.Debug(logSender, "", "TOTP enabled for user %q and no passcode provided, authentication refused", user.Username)
			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials, r)
			sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
		err = user.Filters.TOTPConfig.Secret.Decrypt()
		if err != nil {
			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
			sendAPIResponse(w, r, fmt.Errorf("unable to decrypt TOTP secret: %w", err), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, passcode, user.Filters.TOTPConfig.Secret.GetPayload())
		if !match || err != nil {
			// logger.Debug takes (sender, requestID, format, args...): the
			// request ID argument must not be omitted, otherwise the format
			// string is logged as the request ID
			logger.Debug(logSender, "", "invalid passcode for user %q, match? %v, err: %v", user.Username, match, err)
			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
			updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials, r)
			sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
	}

	defer user.CloseFs() //nolint:errcheck

	err = user.CheckFsRoot(connectionID)
	if err != nil {
		logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	s.generateAndSendUserToken(w, r, ipAddr, user)
}

// generateAndSendUserToken signs an API user token bound to ipAddr, carrying
// the user's web client permissions, role and 2FA / password-change
// requirements. It records the login metrics, updates the last login and
// writes the token response as JSON.
func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Request, ipAddr string, user dataprovider.User) {
	c := &jwt.Claims{
		Username:                   user.Username,
		Permissions:                user.Filters.WebClient,
		Role:                       user.Role,
		MustSetTwoFactorAuth:       user.MustSetSecondFactor(),
		MustChangePassword:         user.MustChangePassword(),
		RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols,
	}
	c.Subject = user.GetSignature()

	token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser))
	if err != nil {
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
	dataprovider.UpdateLastLogin(&user)
	render.JSON(w, r, c.BuildTokenResponse(token))
}

// getToken authenticates an admin via HTTP basic auth for the REST API and
// returns a signed admin API token. If the admin has TOTP enabled, the
// passcode must be supplied in the otpHeaderCode request header. Failed
// logins are reported to the defender; in the TOTP failure paths the error
// returned by handleDefenderEventLoginFailed is sent in the API response.
func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) {
	username, password, ok := r.BasicAuth()
	if !ok {
		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr)
	if err != nil {
		handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
		sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}
	if admin.Filters.TOTPConfig.Enabled {
		passcode := r.Header.Get(otpHeaderCode)
		if passcode == "" {
			logger.Debug(logSender, "", "TOTP enabled for admin %q and no passcode provided, authentication refused", admin.Username)
			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
			err = handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials)
			sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
		err = admin.Filters.TOTPConfig.Secret.Decrypt()
		if err != nil {
			sendAPIResponse(w, r, fmt.Errorf("unable to decrypt TOTP secret: %w", err), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		match, err := mfa.ValidateTOTPPasscode(admin.Filters.TOTPConfig.ConfigName, passcode, admin.Filters.TOTPConfig.Secret.GetPayload())
		if !match || err != nil {
			// pass the empty request ID explicitly, see logger.Debug signature
			logger.Debug(logSender, "", "invalid passcode for admin %q, match? %v, err: %v", admin.Username, match, err)
			w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
			err = handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials)
			sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
	}
	s.generateAndSendToken(w, r, admin, ipAddr)
}

// generateAndSendToken signs an admin API token bound to ip, carrying the
// admin's permissions, role and 2FA / password-change requirements, updates
// the admin's last login, calls common.DelayLogin(nil) and writes the token
// response as JSON.
func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin, ip string) {
	c := &jwt.Claims{
		Username:             admin.Username,
		Permissions:          admin.Permissions,
		Role:                 admin.Role,
		MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled,
		MustChangePassword:   admin.Filters.RequirePasswordChange,
	}
	c.Subject = admin.GetSignature()

	token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPI, ip, getTokenDuration(tokenAudienceAPI))
	if err != nil {
		sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	dataprovider.UpdateAdminLastLogin(&admin)
	common.DelayLogin(nil)
	render.JSON(w, r, c.BuildTokenResponse(token))
}

// checkCookieExpiration refreshes the web login cookie when it expires within
// cookieRefreshThreshold. Nothing is done for OIDC sessions, for claims
// without username/subject, or when another cookieTokenDuration would extend
// the session beyond maxTokenDuration since the token was issued.
func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Request) {
	if _, ok := r.Context().Value(oidcTokenKey).(string); ok {
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		return
	}
	if claims.Username == "" || claims.Subject == "" {
		return
	}
	if time.Until(claims.Expiry.Time()) > cookieRefreshThreshold {
		return
	}
	if (time.Since(claims.IssuedAt.Time()) + cookieTokenDuration) > maxTokenDuration {
		return
	}
	if claims.Audience.Contains(tokenAudienceWebClient) {
		s.refreshClientToken(w, r, claims)
	} else {
		s.refreshAdminToken(w, r, claims)
	}
}

// refreshClientToken re-issues the web client cookie with the user's current permissions and role, provided the user still exists, its signature matches the token subject and the login checks still pass; cookie errors are ignored.
func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) { user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "") if err != nil { return } if user.GetSignature() != tokenClaims.Subject { logger.Debug(logSender, "", "signature mismatch for user %q, unable to
refresh cookie", user.Username) return } if err := user.CheckLoginConditions(); err != nil { logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err) return } if err := checkHTTPClientUser(&user, r, xid.New().String(), true, false); err != nil { logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err) return } tokenClaims.Permissions = user.Filters.WebClient tokenClaims.Role = user.Role logger.Debug(logSender, "", "cookie refreshed for user %q", user.Username) createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck } /* refreshAdminToken re-issues the web admin cookie with the admin's current permissions, role and hidden user page sections, provided the admin still exists, its signature matches the token subject and CanLogin accepts the client IP; cookie errors are ignored (best effort). */ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) { admin, err := dataprovider.AdminExists(tokenClaims.Username) if err != nil { return } if admin.GetSignature() != tokenClaims.Subject { logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := admin.CanLogin(ipAddr); err != nil { logger.Debug(logSender, "", "unable to refresh cookie for admin %q, err: %v", admin.Username, err) return } tokenClaims.Permissions = admin.Permissions tokenClaims.Role = admin.Role tokenClaims.HideUserPageSections = admin.Filters.Preferences.HideUserPageSections logger.Debug(logSender, "", "cookie refreshed for admin %q", admin.Username) createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebAdmin, ipAddr) //nolint:errcheck } /* updateContextFromCookie returns r unchanged if a JWT is already in its context or no JWT cookie is present; otherwise it verifies the cookie token and returns a copy of r whose context carries the token and the verification error. */ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request { _, err := jwt.FromContext(r.Context()) if err != nil { _, err = r.Cookie(jwt.CookieKey) if err != nil { return r } token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie) ctx := jwt.NewContext(r.Context(), token, err) return r.WithContext(ctx) } return r } /* parseHeaders is a middleware that applies 60 second deadlines through responseControllerDeadlines, sets the Server header and trusts proxy headers (client IP header and the xForwardedProto header) only if an allowHeadersFrom entry accepts the peer IP (nil for Unix socket bindings); otherwise the configured proxy headers are deleted before the next handler runs. */ func (s *httpdServer) parseHeaders(next http.Handler) http.Handler { return
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { responseControllerDeadlines( http.NewResponseController(w), time.Now().Add(60*time.Second), time.Now().Add(60*time.Second), ) w.Header().Set("Server", version.GetServerVersion("/", false)) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) var ip net.IP isUnixSocket := filepath.IsAbs(s.binding.Address) if !isUnixSocket { ip = net.ParseIP(ipAddr) } areHeadersAllowed := false if isUnixSocket || ip != nil { for _, allow := range s.binding.allowHeadersFrom { if allow(ip) { parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth) if parsedIP != "" { ipAddr = parsedIP r.RemoteAddr = ipAddr } if forwardedProto := r.Header.Get(xForwardedProto); forwardedProto != "" { ctx := context.WithValue(r.Context(), forwardedProtoKey, forwardedProto) r = r.WithContext(ctx) } areHeadersAllowed = true break } } } if !areHeadersAllowed { for idx := range s.binding.Security.proxyHeaders { r.Header.Del(s.binding.Security.proxyHeaders[idx]) } } next.ServeHTTP(w, r) }) } /* checkConnection is a middleware that tracks the client connection for the remote IP while the request is served and rejects the request with 403 if new connections from the IP are not allowed or the IP is banned, or with 429 if the rate limiter refuses it. Retry-After is the delay plus just under half a second formatted with %.0f, i.e. effectively rounded up to whole seconds. */ func (s *httpdServer) checkConnection(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) common.Connections.AddClientConnection(ipAddr) defer common.Connections.RemoveClientConnection(ipAddr) if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolHTTP); err != nil { logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection not allowed from ip %q: %v", ipAddr, err) s.sendForbiddenResponse(w, r, util.NewI18nError(err, util.I18nErrorConnectionForbidden)) return } if common.IsBanned(ipAddr, common.ProtocolHTTP) { s.sendForbiddenResponse(w, r, util.NewI18nError( util.NewGenericError("your IP address is blocked"), util.I18nErrorIPForbidden), ) return } if delay, err := common.LimitRate(common.ProtocolHTTP, ipAddr); err != nil { delay += 499999999 * time.Nanosecond w.Header().Set("Retry-After",
fmt.Sprintf("%.0f", delay.Seconds())) w.Header().Set("X-Retry-In", delay.String()) s.sendTooManyRequestResponse(w, r, err) return } next.ServeHTTP(w, r) }) } /* sendTooManyRequestResponse replies with 429: web requests get the web client page when the web client is enabled and the request targets it (or the web admin is disabled), the web admin page otherwise; non-web requests get an API response via sendAPIResponse. */ func (s *httpdServer) sendTooManyRequestResponse(w http.ResponseWriter, r *http.Request, err error) { if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { r = s.updateContextFromCookie(r) if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(errors.New(http.StatusText(http.StatusTooManyRequests)), util.I18nError429Message), "") return } s.renderMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(errors.New(http.StatusText(http.StatusTooManyRequests)), util.I18nError429Message), "") return } sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) } /* sendForbiddenResponse replies with 403 using the same web client / web admin / API dispatch as sendTooManyRequestResponse. */ func (s *httpdServer) sendForbiddenResponse(w http.ResponseWriter, r *http.Request, err error) { if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { r = s.updateContextFromCookie(r) if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { s.renderClientForbiddenPage(w, r, err) return } s.renderForbiddenPage(w, r, err) return } sendAPIResponse(w, r, err, "", http.StatusForbidden) } /* badHostHandler is installed as the secure middleware bad host handler: it logs the host (from the first non-empty HostsProxyHeaders header, else r.Host) and sends a forbidden response. */ func (s *httpdServer) badHostHandler(w http.ResponseWriter, r *http.Request) { host := r.Host for _, header := range s.binding.Security.HostsProxyHeaders { if h := r.Header.Get(header); h != "" { host = h break } } logger.Debug(logSender, "", "the host %q is not allowed", host) s.sendForbiddenResponse(w, r, util.NewI18nError( util.NewGenericError(http.StatusText(http.StatusForbidden)), util.I18nErrorConnectionForbidden, )) } /* notFoundHandler replies with 404 using the same web client / web admin / API dispatch as sendForbiddenResponse. */ func (s *httpdServer) notFoundHandler(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { r = s.updateContextFromCookie(r) if
s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { s.renderClientNotFoundPage(w, r, nil) return } s.renderNotFoundPage(w, r, nil) return } sendAPIResponse(w, r, nil, http.StatusText(http.StatusNotFound), http.StatusNotFound) } /* redirectToWebPath redirects to webPath once at least one admin exists; otherwise it redirects to the admin setup page if the web admin is enabled, and writes nothing if it is not. */ func (s *httpdServer) redirectToWebPath(w http.ResponseWriter, r *http.Request, webPath string) { if dataprovider.HasAdmin() { http.Redirect(w, r, webPath, http.StatusFound) return } if s.enableWebAdmin { http.Redirect(w, r, webAdminSetupPath, http.StatusFound) } } // The StripSlashes causes infinite redirects at the root path if used with http.FileServer. // We also don't strip paths with more than one trailing slash, see #1434 /* mustStripSlash reports whether the StripSlashes middleware should run: not for paths ending with two slashes nor under the OpenAPI, static files or ACME challenge prefixes. */ func (s *httpdServer) mustStripSlash(r *http.Request) bool { urlPath := getURLPath(r) return !strings.HasSuffix(urlPath, "//") && !strings.HasPrefix(urlPath, webOpenAPIPath) && !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI) } /* mustCheckPath reports whether checkConnection should run: static files and ACME challenge requests skip it. */ func (s *httpdServer) mustCheckPath(r *http.Request) bool { urlPath := getURLPath(r) return !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI) } /* initializeRouter creates the JWT and CSRF signers (both HS256, keyed from the signing passphrase) and builds the chi router. It first registers the global middlewares in order: request ID, parseHeaders, structured logging, panic recovery, optional security headers / HTTPS redirect, optional CORS, checkConnection, GetHead and slash stripping. It then registers the health check, the ACME HTTP-01 web root (only with HTTPS redirect), the REST API, static files, the OIDC redirect, the root redirects and the web UI routes. The registration order is significant. */ func (s *httpdServer) initializeRouter() error { signer, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase)) if err != nil { return err } csrfSigner, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase)) if err != nil { return err } var hasHTTPSRedirect bool s.tokenAuth = signer s.csrfTokenAuth = csrfSigner s.router = chi.NewRouter() s.router.Use(middleware.RequestID) s.router.Use(s.parseHeaders) s.router.Use(logger.NewStructuredLogger(logger.GetLogger())) s.router.Use(middleware.Recoverer) if s.binding.Security.Enabled { secureMiddleware := secure.New(secure.Options{ AllowedHosts: s.binding.Security.AllowedHosts, AllowedHostsAreRegex: s.binding.Security.AllowedHostsAreRegex, HostsProxyHeaders: s.binding.Security.HostsProxyHeaders, SSLProxyHeaders: s.binding.Security.getHTTPSProxyHeaders(), STSSeconds:
s.binding.Security.STSSeconds, STSIncludeSubdomains: s.binding.Security.STSIncludeSubdomains, STSPreload: s.binding.Security.STSPreload, ContentTypeNosniff: s.binding.Security.ContentTypeNosniff, ContentSecurityPolicy: s.binding.Security.ContentSecurityPolicy, PermissionsPolicy: s.binding.Security.PermissionsPolicy, CrossOriginOpenerPolicy: s.binding.Security.CrossOriginOpenerPolicy, CrossOriginResourcePolicy: s.binding.Security.CrossOriginResourcePolicy, CrossOriginEmbedderPolicy: s.binding.Security.CrossOriginEmbedderPolicy, ReferrerPolicy: s.binding.Security.ReferrerPolicy, }) secureMiddleware.SetBadHostHandler(http.HandlerFunc(s.badHostHandler)) if s.binding.Security.CacheControl == "private" { s.router.Use(cacheControlMiddleware) } s.router.Use(secureMiddleware.Handler) if s.binding.Security.HTTPSRedirect { s.router.Use(s.binding.Security.redirectHandler) hasHTTPSRedirect = true } } if s.cors.Enabled { c := cors.New(cors.Options{ AllowedOrigins: util.RemoveDuplicates(s.cors.AllowedOrigins, true), AllowedMethods: util.RemoveDuplicates(s.cors.AllowedMethods, true), AllowedHeaders: util.RemoveDuplicates(s.cors.AllowedHeaders, true), ExposedHeaders: util.RemoveDuplicates(s.cors.ExposedHeaders, true), MaxAge: s.cors.MaxAge, AllowCredentials: s.cors.AllowCredentials, OptionsPassthrough: s.cors.OptionsPassthrough, OptionsSuccessStatus: s.cors.OptionsSuccessStatus, AllowPrivateNetwork: s.cors.AllowPrivateNetwork, }) s.router.Use(c.Handler) } /* connection limits, bans and rate limiting are skipped for static files and ACME challenges, see mustCheckPath */ s.router.Use(middleware.Maybe(s.checkConnection, s.mustCheckPath)) s.router.Use(middleware.GetHead) s.router.Use(middleware.Maybe(middleware.StripSlashes, s.mustStripSlash)) s.router.NotFound(s.notFoundHandler) s.router.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) { render.PlainText(w, r, "ok") }) if hasHTTPSRedirect { if p := acme.GetHTTP01WebRoot(); p != "" { serveStaticDir(s.router, acmeChallengeURI, p, true) } } s.setupRESTAPIRoutes() if s.enableWebAdmin || s.enableWebClient { s.router.Group(func(router
chi.Router) { router.Use(cleanCacheControlMiddleware) router.Use(compressor.Handler) serveStaticDir(router, webStaticFilesPath, s.staticFilesPath, true) }) if s.binding.OIDC.isEnabled() { s.router.Get(webOIDCRedirectPath, s.handleOIDCRedirect) } /* the root and base paths redirect to the web client login if the web client is enabled, to the web admin login otherwise */ if s.enableWebClient { s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.redirectToWebPath(w, r, webClientLoginPath) }) s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.redirectToWebPath(w, r, webClientLoginPath) }) } else { s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.redirectToWebPath(w, r, webAdminLoginPath) }) s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.redirectToWebPath(w, r, webAdminLoginPath) }) } } s.setupWebClientRoutes() s.setupWebAdminRoutes() return nil } func (s *httpdServer) setupRESTAPIRoutes() { if s.enableRESTAPI { if !s.binding.isAdminTokenEndpointDisabled() { s.router.Get(tokenPath, s.getToken) s.router.Post(adminPath+"/{username}/forgot-password", forgotAdminPassword) s.router.Post(adminPath+"/{username}/reset-password", resetAdminPassword) } s.router.Group(func(router chi.Router) { router.Use(checkNodeToken(s.tokenAuth)) if !s.binding.isAdminAPIKeyAuthDisabled() { router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeAdmin)) } router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader)) router.Use(jwtAuthenticatorAPI) router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) render.JSON(w, r, version.Get()) }) router.With(forbidAPIKeyAuthentication).Get(logoutPath, s.logout) router.With(forbidAPIKeyAuthentication).Get(adminProfilePath, getAdminProfile) router.With(forbidAPIKeyAuthentication,
s.checkAuthRequirements).Put(adminProfilePath, updateAdminProfile) router.With(forbidAPIKeyAuthentication).Put(adminPwdPath, changeAdminPassword) // admin TOTP APIs router.With(forbidAPIKeyAuthentication).Get(adminTOTPConfigsPath, getTOTPConfigs) router.With(forbidAPIKeyAuthentication).Post(adminTOTPGeneratePath, generateTOTPSecret) router.With(forbidAPIKeyAuthentication).Post(adminTOTPValidatePath, validateTOTPPasscode) router.With(forbidAPIKeyAuthentication).Post(adminTOTPSavePath, saveTOTPConfig) router.With(forbidAPIKeyAuthentication).Get(admin2FARecoveryCodesPath, getRecoveryCodes) router.With(forbidAPIKeyAuthentication).Post(admin2FARecoveryCodesPath, generateRecoveryCodes) router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)). Get(apiKeysPath, getAPIKeys) router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)). Post(apiKeysPath, addAPIKey) router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)). Get(apiKeysPath+"/{id}", getAPIKeyByID) router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)). Put(apiKeysPath+"/{id}", updateAPIKey) router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)). Delete(apiKeysPath+"/{id}", deleteAPIKey) router.Group(func(router chi.Router) { router.Use(s.checkAuthRequirements) router.With(s.checkPerms(dataprovider.PermAdminViewServerStatus)). Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) render.JSON(w, r, getServicesStatus()) }) router.With(s.checkPerms(dataprovider.PermAdminViewConnections)).Get(activeConnectionsPath, getActiveConnections) router.With(s.checkPerms(dataprovider.PermAdminCloseConnections)). 
Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection) router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Get(quotasBasePath+"/users/scans", getUsersQuotaScans) router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Post(quotasBasePath+"/users/{username}/scan", startUserQuotaScan) router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Get(quotasBasePath+"/folders/scans", getFoldersQuotaScans) router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Post(quotasBasePath+"/folders/{name}/scan", startFolderQuotaScan) router.With(s.checkPerms(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers) router.With(s.checkPerms(dataprovider.PermAdminAddUsers)).Post(userPath, addUser) router.With(s.checkPerms(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername) //nolint:goconst router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser) router.With(s.checkPerms(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser) router.With(s.checkPerms(dataprovider.PermAdminDisableMFA)).Put(userPath+"/{username}/2fa/disable", disableUser2FA) //nolint:goconst router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Get(folderPath, getFolders) router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Get(folderPath+"/{name}", getFolderByName) //nolint:goconst router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(folderPath, addFolder) router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Put(folderPath+"/{name}", updateFolder) router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Delete(folderPath+"/{name}", deleteFolder) router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Get(groupPath, getGroups) router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Get(groupPath+"/{name}", getGroupByName) router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Post(groupPath, addGroup) 
router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Put(groupPath+"/{name}", updateGroup) router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Delete(groupPath+"/{name}", deleteGroup) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(dumpDataPath, dumpData) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(loadDataPath, loadData) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(loadDataPath, loadDataFromRequest) router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/usage", updateUserQuotaUsage) router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/transfer-usage", updateUserTransferQuotaUsage) router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/folders/{name}/usage", updateFolderQuotaUsage) router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(defenderHosts, getDefenderHosts) router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(defenderHosts+"/{id}", getDefenderHostByID) router.With(s.checkPerms(dataprovider.PermAdminManageDefender)).Delete(defenderHosts+"/{id}", deleteDefenderHostByID) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(adminPath, getAdmins) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(adminPath, addAdmin) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(adminPath+"/{username}", getAdminByUsername) router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(adminPath+"/{username}", updateAdmin) router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(adminPath+"/{username}", deleteAdmin) router.With(s.checkPerms(dataprovider.PermAdminDisableMFA)).Put(adminPath+"/{username}/2fa/disable", disableAdmin2FA) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(retentionChecksPath, getRetentionChecks) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler). 
Get(fsEventsPath, searchFsEvents) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler). Get(providerEventsPath, searchProviderEvents) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler). Get(logEventsPath, searchLogEvents) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventActionsPath, getEventActions) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventActionsPath+"/{name}", getEventActionByName) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(eventActionsPath, addEventAction) router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(eventActionsPath+"/{name}", updateEventAction) router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(eventActionsPath+"/{name}", deleteEventAction) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventRulesPath, getEventRules) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventRulesPath+"/{name}", getEventRuleByName) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(eventRulesPath, addEventRule) router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(eventRulesPath+"/{name}", updateEventRule) router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(eventRulesPath+"/{name}", deleteEventRule) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(eventRulesPath+"/run/{name}", runOnDemandRule) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(rolesPath, getRoles) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(rolesPath, addRole) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(rolesPath+"/{name}", getRoleByName) router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(rolesPath+"/{name}", updateRole) router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(rolesPath+"/{name}", deleteRole) router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler).Get(ipListsPath+"/{type}", getIPListEntries) //nolint:goconst 
router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(ipListsPath+"/{type}", addIPListEntry) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(ipListsPath+"/{type}/{ipornet}", getIPListEntry) //nolint:goconst router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(ipListsPath+"/{type}/{ipornet}", updateIPListEntry) router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(ipListsPath+"/{type}/{ipornet}", deleteIPListEntry) }) }) // share API available to external users s.router.Get(sharesPath+"/{id}", s.downloadFromShare) s.router.Post(sharesPath+"/{id}", s.uploadFilesToShare) s.router.Post(sharesPath+"/{id}/{name}", s.uploadFileToShare) s.router.With(compressor.Handler).Get(sharesPath+"/{id}/dirs", s.readBrowsableShareContents) s.router.Get(sharesPath+"/{id}/files", s.downloadBrowsableSharedFile) if !s.binding.isUserTokenEndpointDisabled() { s.router.Get(userTokenPath, s.getUserToken) s.router.Post(userPath+"/{username}/forgot-password", forgotUserPassword) s.router.Post(userPath+"/{username}/reset-password", resetUserPassword) } s.router.Group(func(router chi.Router) { if !s.binding.isUserAPIKeyAuthDisabled() { router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeUser)) } router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader)) router.Use(jwtAuthenticatorAPIUser) router.With(forbidAPIKeyAuthentication).Get(userLogoutPath, s.logout) router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientPasswordChangeDisabled)). Put(userPwdPath, changeUserPassword) router.With(forbidAPIKeyAuthentication).Get(userProfilePath, getUserProfile) router.With(forbidAPIKeyAuthentication, s.checkAuthRequirements).Put(userProfilePath, updateUserProfile) // user TOTP APIs router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)). Get(userTOTPConfigsPath, getTOTPConfigs) router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)). 
Post(userTOTPGeneratePath, generateTOTPSecret) router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)). Post(userTOTPValidatePath, validateTOTPPasscode) router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)). Post(userTOTPSavePath, saveTOTPConfig) router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)). Get(user2FARecoveryCodesPath, getRecoveryCodes) router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)). Post(user2FARecoveryCodesPath, generateRecoveryCodes) router.With(s.checkAuthRequirements, compressor.Handler).Get(userDirsPath, readUserFolder) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Post(userDirsPath, createUserDir) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Patch(userDirsPath, renameUserFsEntry) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Delete(userDirsPath, deleteUserDir) router.With(s.checkAuthRequirements).Get(userFilesPath, getUserFile) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Post(userFilesPath, uploadUserFiles) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Patch(userFilesPath, renameUserFsEntry) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Delete(userFilesPath, deleteUserFile) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Post(userFileActionsPath+"/move", renameUserFsEntry) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Post(userFileActionsPath+"/copy", copyUserFsEntry) router.With(s.checkAuthRequirements).Post(userStreamZipPath, getUserFilesAsZipStream) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). 
Get(userSharesPath, getShares) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Post(userSharesPath, addShare) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Get(userSharesPath+"/{id}", getShareByID) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Put(userSharesPath+"/{id}", updateShare) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Delete(userSharesPath+"/{id}", deleteShare) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Post(userUploadFilePath, uploadUserFile) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)). Patch(userFilesDirsMetadataPath, setFileDirMetadata) }) if s.renderOpenAPI { s.router.Group(func(router chi.Router) { router.Use(cleanCacheControlMiddleware) router.Use(compressor.Handler) serveStaticDir(router, webOpenAPIPath, s.openAPIPath, false) }) } } } func (s *httpdServer) setupWebClientRoutes() { if s.enableWebClient { s.router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) http.Redirect(w, r, webClientLoginPath, http.StatusFound) }) s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webclient/logo.png"), func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) renderPNGImage(w, r, dbBrandingConfig.getWebClientLogo()) }) s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webclient/favicon.png"), func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) renderPNGImage(w, r, dbBrandingConfig.getWebClientFavicon()) }) s.router.Get(webClientLoginPath, s.handleClientWebLogin) if s.binding.OIDC.isEnabled() && !s.binding.isWebClientOIDCLoginDisabled() { 
s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin) } if !s.binding.isWebClientLoginFormDisabled() { s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientLoginPath, s.handleWebClientLoginPost) s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Get(webClientResetPwdPath, s.handleWebClientPasswordReset) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Get(webClientTwoFactorPath, s.handleWebClientTwoFactor) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Post(webClientTwoFactorPath, s.handleWebClientTwoFactorPost) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Get(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecovery) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Post(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecoveryPost) } // share routes available to external users s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). 
Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost) s.router.Get(webClientPubSharesPath+"/{id}/logout", s.handleClientShareLogout) s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare) s.router.Post(webClientPubSharesPath+"/{id}/partial", s.handleClientSharePartialDownload) s.router.Get(webClientPubSharesPath+"/{id}/browse", s.handleShareGetFiles) s.router.Post(webClientPubSharesPath+"/{id}/browse/exist", s.handleClientShareCheckExist) s.router.Get(webClientPubSharesPath+"/{id}/download", s.handleClientSharedFile) s.router.Get(webClientPubSharesPath+"/{id}/upload", s.handleClientUploadToShare) s.router.With(compressor.Handler).Get(webClientPubSharesPath+"/{id}/dirs", s.handleShareGetDirContents) s.router.Post(webClientPubSharesPath+"/{id}", s.uploadFilesToShare) s.router.Post(webClientPubSharesPath+"/{id}/{name}", s.uploadFileToShare) s.router.Get(webClientPubSharesPath+"/{id}/viewpdf", s.handleShareViewPDF) s.router.Get(webClientPubSharesPath+"/{id}/getpdf", s.handleShareGetPDF) s.router.Group(func(router chi.Router) { if s.binding.OIDC.isEnabled() { router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient)) } router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie)) router.Use(jwtAuthenticatorWebClient) router.Get(webClientLogoutPath, s.handleWebClientLogout) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientFilesPath, s.handleClientGetFiles) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientViewPDFPath, s.handleClientViewPDF) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientGetPDFPath, s.handleClientGetPDF) router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientFilePath, getUserFile) router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientTasksPath+"/{id}", getWebTask) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). 
Post(webClientFilePath, uploadUserFile) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientExistPath, s.handleClientCheckExist) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientEditFilePath, s.handleClientEditFile) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Delete(webClientFilesPath, deleteUserFile) router.With(s.checkAuthRequirements, compressor.Handler, s.refreshCookie). Get(webClientDirsPath, s.handleClientGetDirContents) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientDirsPath, createUserDir) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Delete(webClientDirsPath, taskDeleteDir) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientFileActionsPath+"/move", taskRenameFsEntry) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientFileActionsPath+"/copy", taskCopyFsEntry) router.With(s.checkAuthRequirements, s.refreshCookie). Post(webClientDownloadZipPath, s.handleWebClientDownloadZip) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientPingPath, handlePingRequest) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientProfilePath, s.handleClientGetProfile) router.With(s.checkAuthRequirements).Post(webClientProfilePath, s.handleWebClientProfilePost) router.With(s.checkHTTPUserPerm(sdk.WebClientPasswordChangeDisabled)). Get(webChangeClientPwdPath, s.handleWebClientChangePwd) router.With(s.checkHTTPUserPerm(sdk.WebClientPasswordChangeDisabled)). Post(webChangeClientPwdPath, s.handleWebClientChangePwdPost) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie). 
Get(webClientMFAPath, s.handleWebClientMFA) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie). Get(webClientMFAPath+"/qrcode", getQRCode) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientTOTPGeneratePath, generateTOTPSecret) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientTOTPValidatePath, validateTOTPPasscode) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientTOTPSavePath, saveTOTPConfig) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader, s.refreshCookie). Get(webClientRecoveryCodesPath, getRecoveryCodes) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientRecoveryCodesPath, generateRecoveryCodes) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), compressor.Handler, s.refreshCookie). Get(webClientSharesPath+jsonAPISuffix, getAllShares) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.refreshCookie). Get(webClientSharesPath, s.handleClientGetShares) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.refreshCookie). Get(webClientSharePath, s.handleClientAddShareGet) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Post(webClientSharePath, s.handleClientAddSharePost) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.refreshCookie). Get(webClientSharePath+"/{id}", s.handleClientUpdateShareGet) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Post(webClientSharePath+"/{id}", s.handleClientUpdateSharePost) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.verifyCSRFHeader). 
Delete(webClientSharePath+"/{id}", deleteShare) }) } } func (s *httpdServer) setupWebAdminRoutes() { if s.enableWebAdmin { s.router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) s.redirectToWebPath(w, r, webAdminLoginPath) }) s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webadmin/logo.png"), func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) renderPNGImage(w, r, dbBrandingConfig.getWebAdminLogo()) }) s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webadmin/favicon.png"), func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) renderPNGImage(w, r, dbBrandingConfig.getWebAdminFavicon()) }) s.router.Get(webAdminLoginPath, s.handleWebAdminLogin) if s.binding.OIDC.hasRoles() && !s.binding.isWebAdminOIDCLoginDisabled() { s.router.Get(webAdminOIDCLoginPath, s.handleWebAdminOIDCLogin) } s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect) s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminSetupPath, s.handleWebAdminSetupPost) if !s.binding.isWebAdminLoginFormDisabled() { s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminLoginPath, s.handleWebAdminLoginPost) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Post(webAdminTwoFactorPath, s.handleWebAdminTwoFactorPost) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). 
Get(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecovery) s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost) s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset) s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost) } s.router.Group(func(router chi.Router) { if s.binding.OIDC.isEnabled() { router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin)) } router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie)) router.Use(jwtAuthenticatorWebAdmin) router.Get(webLogoutPath, s.handleWebAdminLogout) router.With(s.refreshCookie, s.checkAuthRequirements, s.requireBuiltinLogin).Get( webAdminProfilePath, s.handleWebAdminProfile) router.With(s.checkAuthRequirements, s.requireBuiltinLogin).Post(webAdminProfilePath, s.handleWebAdminProfilePost) router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webChangeAdminPwdPath, s.handleWebAdminChangePwd) router.With(s.requireBuiltinLogin).Post(webChangeAdminPwdPath, s.handleWebAdminChangePwdPost) router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath, s.handleWebAdminMFA) router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath+"/qrcode", getQRCode) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin, 
s.refreshCookie).Get(webAdminRecoveryCodesPath, getRecoveryCodes) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes) router.Group(func(router chi.Router) { router.Use(s.checkAuthRequirements) router.With(s.checkPerms(dataprovider.PermAdminViewUsers), s.refreshCookie). Get(webUsersPath, s.handleGetWebUsers) router.With(s.checkPerms(dataprovider.PermAdminViewUsers), compressor.Handler, s.refreshCookie). Get(webUsersPath+jsonAPISuffix, getAllUsers) router.With(s.checkPerms(dataprovider.PermAdminAddUsers), s.refreshCookie). Get(webUserPath, s.handleWebAddUserGet) router.With(s.checkPerms(dataprovider.PermAdminChangeUsers), s.refreshCookie). Get(webUserPath+"/{username}", s.handleWebUpdateUserGet) router.With(s.checkPerms(dataprovider.PermAdminAddUsers)).Post(webUserPath, s.handleWebAddUserPost) router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", s.handleWebUpdateUserPost) router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.refreshCookie). Get(webGroupsPath, s.handleWebGetGroups) router.With(s.checkPerms(dataprovider.PermAdminManageGroups), compressor.Handler, s.refreshCookie). Get(webGroupsPath+jsonAPISuffix, getAllGroups) router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.refreshCookie). Get(webGroupPath, s.handleWebAddGroupGet) router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Post(webGroupPath, s.handleWebAddGroupPost) router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.refreshCookie). Get(webGroupPath+"/{name}", s.handleWebUpdateGroupGet) router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Post(webGroupPath+"/{name}", s.handleWebUpdateGroupPost) router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.verifyCSRFHeader). Delete(webGroupPath+"/{name}", deleteGroup) router.With(s.checkPerms(dataprovider.PermAdminViewConnections), s.refreshCookie). 
Get(webConnectionsPath, s.handleWebGetConnections) router.With(s.checkPerms(dataprovider.PermAdminViewConnections), s.refreshCookie). Get(webConnectionsPath+jsonAPISuffix, getActiveConnections) router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). Get(webFoldersPath, s.handleWebGetFolders) router.With(s.checkPerms(dataprovider.PermAdminManageFolders), compressor.Handler, s.refreshCookie). Get(webFoldersPath+jsonAPISuffix, getAllFolders) router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). Get(webFolderPath, s.handleWebAddFolderGet) router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(webFolderPath, s.handleWebAddFolderPost) router.With(s.checkPerms(dataprovider.PermAdminViewServerStatus), s.refreshCookie). Get(webStatusPath, s.handleWebGetStatus) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminsPath, s.handleGetWebAdmins) router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). Get(webAdminsPath+jsonAPISuffix, getAllAdmins) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminPath, s.handleWebAddAdminGet) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminPath+"/{username}", s.handleWebUpdateAdminGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminPath, s.handleWebAddAdminPost) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminPath+"/{username}", s.handleWebUpdateAdminPost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). Delete(webAdminPath+"/{username}", deleteAdmin) router.With(s.checkPerms(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader). Put(webAdminPath+"/{username}/2fa/disable", disableAdmin2FA) router.With(s.checkPerms(dataprovider.PermAdminCloseConnections), s.verifyCSRFHeader). 
Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection) router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). Get(webFolderPath+"/{name}", s.handleWebUpdateFolderGet) router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(webFolderPath+"/{name}", s.handleWebUpdateFolderPost) router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.verifyCSRFHeader). Delete(webFolderPath+"/{name}", deleteFolder) router.With(s.checkPerms(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader). Post(webScanVFolderPath+"/{name}", startFolderQuotaScan) router.With(s.checkPerms(dataprovider.PermAdminDeleteUsers), s.verifyCSRFHeader). Delete(webUserPath+"/{username}", deleteUser) router.With(s.checkPerms(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader). Put(webUserPath+"/{username}/2fa/disable", disableUser2FA) router.With(s.checkPerms(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader). Post(webQuotaScanPath+"/{username}", startUserQuotaScan) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(webMaintenancePath, s.handleWebMaintenance) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(webBackupPath, dumpData) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webRestorePath, s.handleWebRestore) router.With(s.checkPerms(dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers), s.refreshCookie). Get(webTemplateUser, s.handleWebTemplateUserGet) router.With(s.checkPerms(dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers)). Post(webTemplateUser, s.handleWebTemplateUserPost) router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). 
Get(webTemplateFolder, s.handleWebTemplateFolderGet) router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(webTemplateFolder, s.handleWebTemplateFolderPost) router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(webDefenderPath, s.handleWebDefenderPage) router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(webDefenderHostsPath, getDefenderHosts) router.With(s.checkPerms(dataprovider.PermAdminManageDefender), s.verifyCSRFHeader). Delete(webDefenderHostsPath+"/{id}", deleteDefenderHostByID) router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). Get(webAdminEventActionsPath+jsonAPISuffix, getAllActions) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminEventActionsPath, s.handleWebGetEventActions) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminEventActionPath, s.handleWebAddEventActionGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventActionPath, s.handleWebAddEventActionPost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionPost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). Delete(webAdminEventActionPath+"/{name}", deleteEventAction) router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). Get(webAdminEventRulesPath+jsonAPISuffix, getAllRules) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminEventRulesPath, s.handleWebGetEventRules) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). 
Get(webAdminEventRulePath, s.handleWebAddEventRuleGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventRulePath, s.handleWebAddEventRulePost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRuleGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRulePost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). Delete(webAdminEventRulePath+"/{name}", deleteEventRule) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). Post(webAdminEventRulePath+"/run/{name}", runOnDemandRule) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminRolesPath, s.handleWebGetRoles) router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). Get(webAdminRolesPath+jsonAPISuffix, getAllRoles) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminRolePath, s.handleWebAddRoleGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminRolePath, s.handleWebAddRolePost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). Get(webAdminRolePath+"/{name}", s.handleWebUpdateRoleGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminRolePath+"/{name}", s.handleWebUpdateRolePost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). Delete(webAdminRolePath+"/{name}", deleteRole) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), s.refreshCookie).Get(webEventsPath, s.handleWebGetEvents) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler, s.refreshCookie). Get(webEventsFsSearchPath, searchFsEvents) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler, s.refreshCookie). 
Get(webEventsProviderSearchPath, searchProviderEvents) router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler, s.refreshCookie). Get(webEventsLogSearchPath, searchLogEvents) router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(webIPListsPath, s.handleWebIPListsPage) router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). Get(webIPListsPath+"/{type}", getIPListEntries) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie).Get(webIPListPath+"/{type}", s.handleWebAddIPListEntryGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webIPListPath+"/{type}", s.handleWebAddIPListEntryPost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie).Get(webIPListPath+"/{type}/{ipornet}", s.handleWebUpdateIPListEntryGet) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webIPListPath+"/{type}/{ipornet}", s.handleWebUpdateIPListEntryPost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). Delete(webIPListPath+"/{type}/{ipornet}", deleteIPListEntry) router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie).Get(webConfigsPath, s.handleWebConfigs) router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webConfigsPath, s.handleWebConfigsPost) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader, s.refreshCookie). Post(webConfigsPath+"/smtp/test", testSMTPConfig) router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader, s.refreshCookie). Post(webOAuth2TokenPath, s.handleSMTPOAuth2TokenRequestPost) }) }) } } ================================================ FILE: internal/httpd/token.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "crypto/sha256" "encoding/hex" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) func newTokenManager(isShared int) tokenManager { if isShared == 1 { logger.Info(logSender, "", "using provider token manager") return &dbTokenManager{} } logger.Info(logSender, "", "using memory token manager") return &memoryTokenManager{} } type tokenManager interface { Add(token string, expiresAt time.Time) Get(token string) bool Cleanup() } type memoryTokenManager struct { invalidatedJWTTokens sync.Map } func (m *memoryTokenManager) Add(token string, expiresAt time.Time) { m.invalidatedJWTTokens.Store(token, expiresAt) } func (m *memoryTokenManager) Get(token string) bool { _, ok := m.invalidatedJWTTokens.Load(token) return ok } func (m *memoryTokenManager) Cleanup() { m.invalidatedJWTTokens.Range(func(key, value any) bool { exp, ok := value.(time.Time) if !ok || exp.Before(time.Now().UTC()) { m.invalidatedJWTTokens.Delete(key) } return true }) } type dbTokenManager struct{} func (m *dbTokenManager) getKey(token string) string { digest := sha256.Sum256([]byte(token)) return hex.EncodeToString(digest[:]) } func (m *dbTokenManager) Add(token string, expiresAt time.Time) { key := m.getKey(token) data := map[string]string{ "jwt": token, } session := dataprovider.Session{ Key: key, Data: data, Type: dataprovider.SessionTypeInvalidToken, Timestamp: util.GetTimeAsMsSinceEpoch(expiresAt), } dataprovider.AddSharedSession(session) //nolint:errcheck } func (m *dbTokenManager) Get(token 
string) bool { key := m.getKey(token) _, err := dataprovider.GetSharedSession(key, dataprovider.SessionTypeInvalidToken) return err == nil } func (m *dbTokenManager) Cleanup() { dataprovider.CleanupSharedSessions(dataprovider.SessionTypeInvalidToken, time.Now()) //nolint:errcheck } ================================================ FILE: internal/httpd/web.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd

import (
	"errors"
	"net/http"
	"strings"

	"github.com/go-chi/render"
	"github.com/unrolled/secure"

	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/version"
)

const (
	webDateTimeFormat         = "2006-01-02 15:04:05" // YYYY-MM-DD HH:MM:SS
	redactedSecret            = "[**redacted**]"
	csrfFormToken             = "_form_token"
	csrfHeaderToken           = "X-CSRF-TOKEN"
	templateCommonDir         = "common"
	templateTwoFactor         = "twofactor.html"
	templateTwoFactorRecovery = "twofactor-recovery.html"
	templateForgotPassword    = "forgot-password.html"
	templateResetPassword     = "reset-password.html"
	templateChangePwd         = "changepassword.html"
	templateMessage           = "message.html"
	templateCommonBase        = "base.html"
	templateCommonBaseLogin   = "baselogin.html"
	templateCommonLogin       = "login.html"
)

var (
	errInvalidTokenClaims = errors.New("invalid token claims")
)

// commonBasePage carries the fields built by getCommonBasePage.
type commonBasePage struct {
	CSPNonce  string
	StaticURL string
	Version   string
}

type loginPage struct {
	commonBasePage
	CurrentURL     string
	Error          *util.I18nError
	CSRFToken      string
	AltLoginURL    string
	AltLoginName   string
	ForgotPwdURL   string
	OpenIDLoginURL string
	Title          string
	Branding       UIBranding
	Languages      []string
	FormDisabled   bool
	CheckRedirect  bool
}

type twoFactorPage struct {
	commonBasePage
	CurrentURL    string
	Error         *util.I18nError
	CSRFToken     string
	RecoveryURL   string
	Title         string
	Branding      UIBranding
	Languages     []string
	CheckRedirect bool
}

type forgotPwdPage struct {
	commonBasePage
	CurrentURL    string
	Error         *util.I18nError
	CSRFToken     string
	LoginURL      string
	Title         string
	Branding      UIBranding
	Languages     []string
	CheckRedirect bool
}

type resetPwdPage struct {
	commonBasePage
	CurrentURL    string
	Error         *util.I18nError
	CSRFToken     string
	LoginURL      string
	Title         string
	Branding      UIBranding
	Languages     []string
	CheckRedirect bool
}

// getSliceFromDelimitedValues splits values on delimiter, trims surrounding
// whitespace from every piece and drops the pieces that end up empty.
// The returned slice is never nil: with no usable values it is empty.
func getSliceFromDelimitedValues(values, delimiter string) []string {
	parts := make([]string, 0)
	for item := range strings.SplitSeq(values, delimiter) {
		if item = strings.TrimSpace(item); item == "" {
			continue
		}
		parts = append(parts, item)
	}
	return parts
}

// hasPrefixAndSuffix reports whether key both starts with prefix and ends
// with suffix.
func hasPrefixAndSuffix(key, prefix, suffix string) bool {
	if !strings.HasPrefix(key, prefix) {
		return false
	}
	return strings.HasSuffix(key, suffix)
}

// getCommonBasePage collects the request's CSP nonce, the static files path
// and the server version string.
func getCommonBasePage(r *http.Request) commonBasePage {
	ctx := r.Context()
	page := commonBasePage{StaticURL: webStaticFilesPath}
	page.CSPNonce = secure.CSPNonce(ctx)
	page.Version = version.GetServerVersion(" ", true)
	return page
}

// i18nListDirMsg maps an HTTP status to the translation key used when a
// directory listing fails: a dedicated key for 403, a generic one otherwise.
func i18nListDirMsg(status int) string {
	switch status {
	case http.StatusForbidden:
		return util.I18nErrorDirList403
	default:
		return util.I18nErrorDirListGeneric
	}
}

// i18nFsMsg maps an HTTP status to the translation key used for filesystem
// errors: the generic 403 message for 403, a generic filesystem key otherwise.
func i18nFsMsg(status int) string {
	switch status {
	case http.StatusForbidden:
		return util.I18nError403Message
	default:
		return util.I18nErrorFsGeneric
	}
}

// getI18NErrorString returns the Message of the first *util.I18nError found in
// err's chain, or fallback when there is none.
func getI18NErrorString(err error, fallback string) string {
	if target := (*util.I18nError)(nil); errors.As(err, &target) {
		return target.Message
	}
	return fallback
}

// getI18nError returns nil for a nil err; otherwise it wraps err via
// util.NewI18nError using the generic 500 message key.
func getI18nError(err error) *util.I18nError {
	if err == nil {
		return nil
	}
	return util.NewI18nError(err, util.I18nError500Message)
}

// handlePingRequest answers with the plain-text body "PONG"; the request body
// is capped at maxRequestSize bytes.
func handlePingRequest(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	render.PlainText(w, r, "PONG")
}

================================================ FILE: internal/httpd/webadmin.go ================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd

import (
	"context"
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"html/template"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"slices"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/sftpgo/sdk"
	sdkkms "github.com/sftpgo/sdk/kms"
	"golang.org/x/oauth2"

	"github.com/drakkan/sftpgo/v2/internal/acme"
	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/ftpd"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/mfa"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
	"github.com/drakkan/sftpgo/v2/internal/webdavd"
)

type userPageMode int

const (
	userPageModeAdd userPageMode = iota + 1
	userPageModeUpdate
	userPageModeTemplate
)

type folderPageMode int

const (
	folderPageModeAdd folderPageMode = iota + 1
	folderPageModeUpdate
	folderPageModeTemplate
)

type genericPageMode int

const (
	genericPageModeAdd genericPageMode = iota + 1
	genericPageModeUpdate
)

const (
	templateAdminDir     = "webadmin"
	templateBase         = "base.html"
	templateFsConfig     = "fsconfig.html"
	templateUsers        = "users.html"
	templateUser         = "user.html"
	templateAdmins       = "admins.html"
	templateAdmin        = "admin.html"
	templateConnections  = "connections.html"
	templateGroups       = "groups.html"
	templateGroup        = "group.html"
	templateFolders      = "folders.html"
	templateFolder       = "folder.html"
	templateEventRules   = "eventrules.html"
	templateEventRule    = "eventrule.html"
	templateEventActions = "eventactions.html"
	templateEventAction  = "eventaction.html"
	templateRoles        = "roles.html"
	templateRole         = "role.html"
	templateEvents       = "events.html"
	templateStatus       = "status.html"
	templateDefender     = "defender.html"
	templateIPLists      = "iplists.html"
	templateIPList       = "iplist.html"
	templateConfigs      = "configs.html"
	templateProfile      = "profile.html"
	templateMaintenance  = "maintenance.html"
	templateMFA          = "mfa.html"
	templateSetup        = "adminsetup.html"
	defaultQueryLimit    = 1000
	inversePatternType   = "inverse"
)

var (
	// adminTemplates maps a page template file name to its parsed template;
	// it is filled by loadAdminTemplates.
	adminTemplates = make(map[string]*template.Template)
)

// basePage holds the fields built by getBasePageData.
type basePage struct {
	commonBasePage
	Title               string
	CurrentURL          string
	UsersURL            string
	UserURL             string
	UserTemplateURL     string
	AdminsURL           string
	AdminURL            string
	QuotaScanURL        string
	ConnectionsURL      string
	GroupsURL           string
	GroupURL            string
	FoldersURL          string
	FolderURL           string
	FolderTemplateURL   string
	DefenderURL         string
	IPListsURL          string
	IPListURL           string
	EventsURL           string
	ConfigsURL          string
	LogoutURL           string
	LoginURL            string
	ProfileURL          string
	ChangePwdURL        string
	MFAURL              string
	EventRulesURL       string
	EventRuleURL        string
	EventActionsURL     string
	EventActionURL      string
	RolesURL            string
	RoleURL             string
	FolderQuotaScanURL  string
	StatusURL           string
	MaintenanceURL      string
	CSRFToken           string
	IsEventManagerPage  bool
	IsIPManagerPage     bool
	IsServerManagerPage bool
	HasDefender         bool
	HasSearcher         bool
	HasExternalLogin    bool
	LoggedUser          *dataprovider.Admin
	IsLoggedToShare     bool
	Branding            UIBranding
	Languages           []string
}

type statusPage struct {
	basePage
	Status *ServicesStatus
}

type fsWrapper struct {
	vfs.Filesystem
	IsUserPage      bool
	IsGroupPage     bool
	IsHidden        bool
	HasUsersBaseDir bool
	DirPath         string
}

type userPage struct {
	basePage
	User               *dataprovider.User
	RootPerms          []string
	Error              *util.I18nError
	ValidPerms         []string
	ValidLoginMethods  []string
	ValidProtocols     []string
	TwoFactorProtocols []string
	WebClientOptions   []string
	RootDirPerms       []string
	Mode               userPageMode
	VirtualFolders     []vfs.BaseVirtualFolder
	Groups             []dataprovider.Group
	Roles              []dataprovider.Role
	CanImpersonate     bool
	FsWrapper          fsWrapper
	CanUseTLSCerts     bool
}

type adminPage struct {
	basePage
	Admin  *dataprovider.Admin
	Groups []dataprovider.Group
	Roles  []dataprovider.Role
	Error  *util.I18nError
	IsAdd  bool
}

type profilePage struct {
	basePage
	Error           *util.I18nError
	AllowAPIKeyAuth bool
	Email           string
	Description     string
}

type changePasswordPage struct {
	basePage
	Error *util.I18nError
}

type mfaPage struct {
	basePage
	TOTPConfigs      []string
	TOTPConfig       dataprovider.AdminTOTPConfig
	GenerateTOTPURL  string
	ValidateTOTPURL  string
	SaveTOTPURL      string
	RecCodesURL      string
	RequireTwoFactor bool
}

type maintenancePage struct {
	basePage
	BackupPath  string
	RestorePath string
	Error       *util.I18nError
}

type defenderHostsPage struct {
	basePage
	DefenderHostsURL string
}

type ipListsPage struct {
	basePage
	IPListsSearchURL      string
	RateLimitersStatus    bool
	RateLimitersProtocols string
	IsAllowListEnabled    bool
}

type ipListPage struct {
	basePage
	Entry *dataprovider.IPListEntry
	Error *util.I18nError
	Mode  genericPageMode
}

type setupPage struct {
	commonBasePage
	CurrentURL           string
	Error                *util.I18nError
	CSRFToken            string
	Username             string
	HasInstallationCode  bool
	InstallationCodeHint string
	HideSupportLink      bool
	Title                string
	Branding             UIBranding
	Languages            []string
	CheckRedirect        bool
}

type folderPage struct {
	basePage
	Folder    vfs.BaseVirtualFolder
	Error     *util.I18nError
	Mode      folderPageMode
	FsWrapper fsWrapper
}

type groupPage struct {
	basePage
	Group              *dataprovider.Group
	Error              *util.I18nError
	Mode               genericPageMode
	ValidPerms         []string
	ValidLoginMethods  []string
	ValidProtocols     []string
	TwoFactorProtocols []string
	WebClientOptions   []string
	VirtualFolders     []vfs.BaseVirtualFolder
	FsWrapper          fsWrapper
}

type rolePage struct {
	basePage
	Role  *dataprovider.Role
	Error *util.I18nError
	Mode  genericPageMode
}

type eventActionPage struct {
	basePage
	Action          dataprovider.BaseEventAction
	ActionTypes     []dataprovider.EnumMapping
	FsActions       []dataprovider.EnumMapping
	HTTPMethods     []string
	EnabledCommands []string
	RedactedSecret  string
	Error           *util.I18nError
	Mode            genericPageMode
}

type eventRulePage struct {
	basePage
	Rule            dataprovider.EventRule
	TriggerTypes    []dataprovider.EnumMapping
	Actions         []dataprovider.BaseEventAction
	FsEvents        []string
	Protocols       []string
	ProviderEvents  []string
	ProviderObjects []string
	Error           *util.I18nError
	Mode            genericPageMode
	IsShared        bool
}

type eventsPage struct {
	basePage
	FsEventsSearchURL       string
	ProviderEventsSearchURL string
	LogEventsSearchURL      string
}

type configsPage struct {
	basePage
	Configs           dataprovider.Configs
	ConfigSection     int
	RedactedSecret    string
	OAuth2TokenURL    string
	OAuth2RedirectURL string
	WebClientBranding UIBranding
	Error             *util.I18nError
}

type messagePage struct {
	basePage
	Error   *util.I18nError
	Success string
	Text    string
}

type userTemplateFields struct {
	Username         string
	Password         string
	PublicKeys       []string
	RequirePwdChange bool
}

// loadAdminTemplates parses every WebAdmin page template located below
// templatesPath and registers each one in adminTemplates, keyed by the page
// template file name. All templates are parsed first, in a fixed order, and
// only then registered.
func loadAdminTemplates(templatesPath string) {
	commonFile := func(name string) string {
		return filepath.Join(templatesPath, templateCommonDir, name)
	}
	adminFile := func(name string) string {
		return filepath.Join(templatesPath, templateAdminDir, name)
	}
	// Pages composed with the common base and the admin base layout.
	withAdminLayout := func(files ...string) []string {
		return append([]string{commonFile(templateCommonBase), adminFile(templateBase)}, files...)
	}
	// Pages composed with the common base and the login base layout.
	withLoginLayout := func(files ...string) []string {
		return append([]string{commonFile(templateCommonBase), commonFile(templateCommonBaseLogin)}, files...)
	}

	// Base template exposing HumanizeBytes to the pages parsed on top of it.
	fsBaseTpl := template.New("fsBaseTemplate").Funcs(template.FuncMap{
		"HumanizeBytes": util.ByteCountSI,
	})

	specs := []struct {
		name  string
		base  *template.Template
		paths []string
	}{
		{templateUsers, nil, withAdminLayout(adminFile(templateUsers))},
		{templateUser, fsBaseTpl, withAdminLayout(adminFile(templateFsConfig), adminFile(templateUser))},
		{templateAdmins, nil, withAdminLayout(adminFile(templateAdmins))},
		{templateAdmin, nil, withAdminLayout(adminFile(templateAdmin))},
		{templateConnections, nil, withAdminLayout(adminFile(templateConnections))},
		{templateMessage, nil, withAdminLayout(commonFile(templateMessage))},
		{templateGroups, nil, withAdminLayout(adminFile(templateGroups))},
		{templateGroup, fsBaseTpl, withAdminLayout(adminFile(templateFsConfig), adminFile(templateGroup))},
		{templateFolders, nil, withAdminLayout(adminFile(templateFolders))},
		{templateFolder, fsBaseTpl, withAdminLayout(adminFile(templateFsConfig), adminFile(templateFolder))},
		{templateEventRules, nil, withAdminLayout(adminFile(templateEventRules))},
		{templateEventRule, fsBaseTpl, withAdminLayout(adminFile(templateEventRule))},
		{templateEventActions, nil, withAdminLayout(adminFile(templateEventActions))},
		{templateEventAction, nil, withAdminLayout(adminFile(templateEventAction))},
		{templateStatus, nil, withAdminLayout(adminFile(templateStatus))},
		{templateCommonLogin, nil, withLoginLayout(commonFile(templateCommonLogin))},
		{templateProfile, nil, withAdminLayout(adminFile(templateProfile))},
		{templateChangePwd, nil, withAdminLayout(commonFile(templateChangePwd))},
		{templateMaintenance, nil, withAdminLayout(adminFile(templateMaintenance))},
		{templateDefender, nil, withAdminLayout(adminFile(templateDefender))},
		{templateIPLists, nil, withAdminLayout(adminFile(templateIPLists))},
		{templateIPList, nil, withAdminLayout(adminFile(templateIPList))},
		{templateMFA, nil, withAdminLayout(adminFile(templateMFA))},
		{templateTwoFactor, nil, withLoginLayout(commonFile(templateTwoFactor))},
		{templateTwoFactorRecovery, nil, withLoginLayout(commonFile(templateTwoFactorRecovery))},
		{templateSetup, nil, withLoginLayout(adminFile(templateSetup))},
		{templateForgotPassword, nil, withLoginLayout(commonFile(templateForgotPassword))},
		{templateResetPassword, nil, withLoginLayout(commonFile(templateResetPassword))},
		{templateRoles, nil, withAdminLayout(adminFile(templateRoles))},
		{templateRole, nil, withAdminLayout(adminFile(templateRole))},
		{templateEvents, nil, withAdminLayout(adminFile(templateEvents))},
		{templateConfigs, nil, withAdminLayout(adminFile(templateConfigs))},
	}

	parsed := make([]*template.Template, len(specs))
	for i, spec := range specs {
		parsed[i] = util.LoadTemplate(spec.base, spec.paths...)
	}
	for i, spec := range specs {
		adminTemplates[spec.name] = parsed[i]
	}
}

// isEventManagerResource reports whether currentURL is one of the event
// rules/actions list pages, or the event rule/action path itself or anything
// nested below it.
func isEventManagerResource(currentURL string) bool {
	if currentURL == webAdminEventRulesPath || currentURL == webAdminEventActionsPath {
		return true
	}
	isSelfOrNested := func(base string) bool {
		return currentURL == base || strings.HasPrefix(currentURL, base+"/")
	}
	return isSelfOrNested(webAdminEventRulePath) || isSelfOrNested(webAdminEventActionPath)
}

// isIPListsResource reports whether currentURL is the defender page, the IP
// lists page, or any path nested below the IP list entry path.
func isIPListsResource(currentURL string) bool {
	if strings.HasPrefix(currentURL, webIPListPath+"/") {
		return true
	}
	return currentURL == webDefenderPath || currentURL == webIPListsPath
}

// isServerManagerResource reports whether currentURL is exactly one of the
// events, status, maintenance or configs pages.
func isServerManagerResource(currentURL string) bool {
	serverPages := []string{webEventsPath, webStatusPath, webMaintenancePath, webConfigsPath}
	return slices.Contains(serverPages, currentURL)
}

// getBasePageData builds the data shared by the WebAdmin pages. A CSRF token
// is created only when currentURL is not empty.
func (s *httpdServer) getBasePageData(title, currentURL string, w http.ResponseWriter, r *http.Request) basePage {
	csrfToken := ""
	if currentURL != "" {
		csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
	}
	// The field order below is also the evaluation order of the calls.
	return basePage{
		commonBasePage: getCommonBasePage(r),
		Title:          title,
		CurrentURL:     currentURL,
		// navigation links
		UsersURL:           webUsersPath,
		UserURL:            webUserPath,
		UserTemplateURL:    webTemplateUser,
		AdminsURL:          webAdminsPath,
		AdminURL:           webAdminPath,
		GroupsURL:          webGroupsPath,
		GroupURL:           webGroupPath,
		FoldersURL:         webFoldersPath,
		FolderURL:          webFolderPath,
		FolderTemplateURL:  webTemplateFolder,
		DefenderURL:        webDefenderPath,
		IPListsURL:         webIPListsPath,
		IPListURL:          webIPListPath,
		EventsURL:          webEventsPath,
		ConfigsURL:         webConfigsPath,
		LogoutURL:          webLogoutPath,
		LoginURL:           webAdminLoginPath,
		ProfileURL:         webAdminProfilePath,
		ChangePwdURL:       webChangeAdminPwdPath,
		MFAURL:             webAdminMFAPath,
		EventRulesURL:      webAdminEventRulesPath,
		EventRuleURL:       webAdminEventRulePath,
		EventActionsURL:    webAdminEventActionsPath,
		EventActionURL:     webAdminEventActionPath,
		RolesURL:           webAdminRolesPath,
		RoleURL:            webAdminRolePath,
		QuotaScanURL:       webQuotaScanPath,
		ConnectionsURL:     webConnectionsPath,
		StatusURL:          webStatusPath,
		FolderQuotaScanURL: webScanVFolderPath,
		MaintenanceURL:     webMaintenancePath,
		// session and section state
		LoggedUser:          getAdminFromToken(r),
		IsEventManagerPage:  isEventManagerResource(currentURL),
		IsIPManagerPage:     isIPListsResource(currentURL),
		IsServerManagerPage: isServerManagerResource(currentURL),
		HasDefender:         common.Config.DefenderConfig.Enabled,
		HasSearcher:         plugin.Handler.HasSearcher(),
		HasExternalLogin:    isLoggedInWithOIDC(r),
		CSRFToken:           csrfToken,
		Branding:            s.binding.webAdminBranding(),
		Languages:           s.binding.languages(),
	}
}

// renderAdminTemplate executes the registered template named tmplName with
// data; on failure the error text is sent with a 500 status.
func renderAdminTemplate(w http.ResponseWriter, tmplName string, data any) {
	tmpl := adminTemplates[tmplName]
	if err := tmpl.ExecuteTemplate(w, tmplName, data); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

// renderMessagePageWithString writes statusCode and then renders the message
// page with the given error, success message and extra text. No CSRF token is
// generated for this page.
func (s *httpdServer) renderMessagePageWithString(w http.ResponseWriter, r *http.Request, title string, statusCode int,
	err error, message, text string,
) {
	page := messagePage{basePage: s.getBasePageData(title, "", w, r)}
	page.Error = getI18nError(err)
	page.Success = message
	page.Text = text

	w.WriteHeader(statusCode)
	renderAdminTemplate(w, templateMessage, page)
}

// renderMessagePage is renderMessagePageWithString without extra text.
func (s *httpdServer) renderMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int,
	err error, message string,
) {
	s.renderMessagePageWithString(w, r, title, statusCode, err, message, "")
}

// renderInternalServerErrorPage renders the message page with status 500.
func (s *httpdServer) renderInternalServerErrorPage(w http.ResponseWriter, r *http.Request, err error) {
	wrapped := util.NewI18nError(err, util.I18nError500Message)
	s.renderMessagePage(w, r, util.I18nError500Title, http.StatusInternalServerError, wrapped, "")
}

// renderBadRequestPage renders the message page with status 400.
func (s *httpdServer) renderBadRequestPage(w http.ResponseWriter, r *http.Request, err error) {
	wrapped := util.NewI18nError(err, util.I18nError400Message)
	s.renderMessagePage(w, r, util.I18nError400Title, http.StatusBadRequest, wrapped, "")
}

// renderForbiddenPage renders the message page with status 403.
func (s *httpdServer) renderForbiddenPage(w http.ResponseWriter, r *http.Request, err error) {
	wrapped := util.NewI18nError(err, util.I18nError403Message)
	s.renderMessagePage(w, r, util.I18nError403Title, http.StatusForbidden, wrapped, "")
}

// renderNotFoundPage renders the message page with status 404.
func (s *httpdServer) renderNotFoundPage(w http.ResponseWriter, r *http.Request, err error) {
	wrapped := util.NewI18nError(err, util.I18nError404Message)
	s.renderMessagePage(w, r, util.I18nError404Title, http.StatusNotFound, wrapped, "")
}

// renderForgotPwdPage renders the "forgot password" form. Unlike most pages it
// passes rand.Text() instead of "" as the fourth createCSRFToken argument.
func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	page := forgotPwdPage{commonBasePage: getCommonBasePage(r)}
	page.CurrentURL = webAdminForgotPwdPath
	page.Error = err
	page.CSRFToken = createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath)
	page.LoginURL = webAdminLoginPath
	page.Title = util.I18nForgotPwdTitle
	page.Branding = s.binding.webAdminBranding()
	page.Languages = s.binding.languages()
	renderAdminTemplate(w, templateForgotPassword, page)
}

// renderResetPwdPage renders the "reset password" form.
func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	page := resetPwdPage{commonBasePage: getCommonBasePage(r)}
	page.CurrentURL = webAdminResetPwdPath
	page.Error = err
	page.CSRFToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
	page.LoginURL = webAdminLoginPath
	page.Title = util.I18nResetPwdTitle
	page.Branding = s.binding.webAdminBranding()
	page.Languages = s.binding.languages()
	renderAdminTemplate(w, templateResetPassword, page)
}

// renderTwoFactorPage renders the two-factor code form, linking to the
// recovery-code form.
func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	page := twoFactorPage{commonBasePage: getCommonBasePage(r)}
	page.Title = util.I18n2FATitle
	page.CurrentURL = webAdminTwoFactorPath
	page.Error = err
	page.CSRFToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
	page.RecoveryURL = webAdminTwoFactorRecoveryPath
	page.Branding = s.binding.webAdminBranding()
	page.Languages = s.binding.languages()
	renderAdminTemplate(w, templateTwoFactor, page)
}

// renderTwoFactorRecoveryPage renders the two-factor recovery-code form.
func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	page := twoFactorPage{commonBasePage: getCommonBasePage(r)}
	page.Title = util.I18n2FATitle
	page.CurrentURL = webAdminTwoFactorRecoveryPath
	page.Error = err
	page.CSRFToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
	page.Branding = s.binding.webAdminBranding()
	page.Languages = s.binding.languages()
	renderAdminTemplate(w, templateTwoFactorRecovery, page)
}

// renderMFAPage renders the MFA settings page for the logged-in admin, whose
// TOTP configuration and RequireTwoFactor flag are read from the provider.
func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
	page := mfaPage{basePage: s.getBasePageData(util.I18n2FATitle, webAdminMFAPath, w, r)}
	page.TOTPConfigs = mfa.GetAvailableTOTPConfigNames()
	page.GenerateTOTPURL = webAdminTOTPGeneratePath
	page.ValidateTOTPURL = webAdminTOTPValidatePath
	page.SaveTOTPURL = webAdminTOTPSavePath
	page.RecCodesURL = webAdminRecoveryCodesPath

	admin, lookupErr := dataprovider.AdminExists(page.LoggedUser.Username)
	if lookupErr != nil {
		s.renderInternalServerErrorPage(w, r, lookupErr)
		return
	}
	page.TOTPConfig = admin.Filters.TOTPConfig
	page.RequireTwoFactor = admin.Filters.RequireTwoFactor
	renderAdminTemplate(w, templateMFA, page)
}

// renderProfilePage renders the logged-in admin's profile, showing err (if
// any) and the values currently stored by the provider.
func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) {
	page := profilePage{basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, w, r)}
	page.Error = getI18nError(err)

	admin, lookupErr := dataprovider.AdminExists(page.LoggedUser.Username)
	if lookupErr != nil {
		s.renderInternalServerErrorPage(w, r, lookupErr)
		return
	}
	page.AllowAPIKeyAuth = admin.Filters.AllowAPIKeyAuth
	page.Email = admin.Email
	page.Description = admin.Description
	renderAdminTemplate(w, templateProfile, page)
}

// renderChangePasswordPage renders the change-password form.
func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	page := changePasswordPage{basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, w, r)}
	page.Error = err
	renderAdminTemplate(w, templateChangePwd, page)
}

// renderMaintenancePage renders the backup/restore page.
func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) {
	page := maintenancePage{basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, w, r)}
	page.BackupPath = webBackupPath
	page.RestorePath = webRestorePath
	page.Error = getI18nError(err)
	renderAdminTemplate(w, templateMaintenance, page)
}

// renderConfigsPage renders the configurations page. Nil sections are
// replaced with empty ones, and form defaults are filled in when the SMTP or
// ACME HTTP-01 port is unset.
func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request, configs dataprovider.Configs,
	err error, section int,
) {
	configs.SetNilsToEmpty()
	if configs.SMTP.Port == 0 {
		// 587 is the standard mail submission port; see the SMTP config type
		// for the meaning of the AuthType/Encryption values.
		configs.SMTP.Port = 587
		configs.SMTP.AuthType = 1
		configs.SMTP.Encryption = 2
	}
	if configs.ACME.HTTP01Challenge.Port == 0 {
		// HTTP-01 challenges are validated on port 80.
		configs.ACME.HTTP01Challenge.Port = 80
	}

	page := configsPage{basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, w, r)}
	page.Configs = configs
	page.ConfigSection = section
	page.RedactedSecret = redactedSecret
	page.OAuth2TokenURL = webOAuth2TokenPath
	page.OAuth2RedirectURL = webOAuth2RedirectPath
	page.WebClientBranding = s.binding.webClientBranding()
	page.Error = getI18nError(err)
	renderAdminTemplate(w, templateConfigs, page)
}

// renderAdminSetupPage renders the initial admin setup form. Like the forgot
// password page, it passes rand.Text() as the fourth createCSRFToken argument.
func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username string, err *util.I18nError) {
	page := setupPage{commonBasePage: getCommonBasePage(r)}
	page.Title = util.I18nSetupTitle
	page.CurrentURL = webAdminSetupPath
	page.CSRFToken = createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath)
	page.Username = username
	page.HasInstallationCode = installationCode != ""
	page.InstallationCodeHint = installationCodeHint
	page.HideSupportLink = hideSupportLink
	page.Error = err
	page.Branding = s.binding.webAdminBranding()
	page.Languages = s.binding.languages()
	renderAdminTemplate(w, templateSetup, page)
}

// renderAddUpdateAdminPage renders the add (isAdd true) or update admin form.
// When loading groups or roles fails it returns without rendering this page.
// NOTE(review): the fetch helpers are assumed to have already written an error
// response in that case — confirm in getWebGroups/getWebRoles.
func (s *httpdServer) renderAddUpdateAdminPage(w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin,
	err error, isAdd bool,
) {
	groups, errGroups := s.getWebGroups(w, r, defaultQueryLimit, true)
	if errGroups != nil {
		return
	}
	roles, errRoles := s.getWebRoles(w, r, 10, true)
	if errRoles != nil {
		return
	}

	var title, currentURL string
	if isAdd {
		title = util.I18nAddAdminTitle
		currentURL = webAdminPath
	} else {
		title = util.I18nUpdateAdminTitle
		currentURL = fmt.Sprintf("%v/%v", webAdminPath, url.PathEscape(admin.Username))
	}

	page := adminPage{basePage: s.getBasePageData(title, currentURL, w, r)}
	page.Admin = admin
	page.Groups = groups
	page.Roles = roles
	page.Error = getI18nError(err)
	page.IsAdd = isAdd
	renderAdminTemplate(w, templateAdmin, page)
}

// getUserPageTitleAndURL returns the title key and current URL for the user
// page in the given mode; an unknown mode yields two empty strings.
func (s *httpdServer) getUserPageTitleAndURL(mode userPageMode, username string) (string, string) {
	switch mode {
	case userPageModeAdd:
		return util.I18nAddUserTitle, webUserPath
	case userPageModeUpdate:
		return util.I18nUpdateUserTitle, fmt.Sprintf("%v/%v", webUserPath, url.PathEscape(username))
	case userPageModeTemplate:
		return util.I18nTemplateUserTitle, webTemplateUser
	}
	return "", ""
}
func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, user *dataprovider.User, mode userPageMode, err error, admin *dataprovider.Admin, ) { user.SetEmptySecretsIfNil() title, currentURL := s.getUserPageTitleAndURL(mode, user.Username) if user.Password != "" && user.IsPasswordHashed() { switch mode { case userPageModeUpdate: user.Password = redactedSecret default: user.Password = "" } } user.FsConfig.RedactedSecret = redactedSecret basePage := s.getBasePageData(title, currentURL, w, r) if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil { for _, group := range admin.Groups { user.Groups = append(user.Groups, sdk.GroupMapping{ Name: group.Name, Type: group.Options.GetUserGroupType(), }) } } var roles []dataprovider.Role if basePage.LoggedUser.Role == "" { var errRoles error roles, errRoles = s.getWebRoles(w, r, 10, true) if errRoles != nil { return } } folders, errFolders := s.getWebVirtualFolders(w, r, defaultQueryLimit, true) if errFolders != nil { return } groups, errGroups := s.getWebGroups(w, r, defaultQueryLimit, true) if errGroups != nil { return } data := userPage{ basePage: basePage, Mode: mode, Error: getI18nError(err), User: user, ValidPerms: dataprovider.ValidPerms, ValidLoginMethods: dataprovider.ValidLoginMethods, ValidProtocols: dataprovider.ValidProtocols, TwoFactorProtocols: dataprovider.MFAProtocols, WebClientOptions: sdk.WebClientOptions, RootDirPerms: user.GetPermissionsForPath("/"), VirtualFolders: folders, Groups: groups, Roles: roles, CanImpersonate: os.Getuid() == 0, CanUseTLSCerts: ftpd.GetStatus().IsActive || webdavd.GetStatus().IsActive, FsWrapper: fsWrapper{ Filesystem: user.FsConfig, IsUserPage: true, IsGroupPage: false, IsHidden: basePage.LoggedUser.Filters.Preferences.HideFilesystem(), HasUsersBaseDir: dataprovider.HasUsersBaseDir(), DirPath: user.HomeDir, }, } renderAdminTemplate(w, templateUser, data) } func (s *httpdServer) renderIPListPage(w 
http.ResponseWriter, r *http.Request, entry dataprovider.IPListEntry, mode genericPageMode, err error, ) { var title, currentURL string switch mode { case genericPageModeAdd: title = util.I18nAddIPListTitle currentURL = fmt.Sprintf("%s/%d", webIPListPath, entry.Type) case genericPageModeUpdate: title = util.I18nUpdateIPListTitle currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet)) } data := ipListPage{ basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Entry: &entry, Mode: mode, } renderAdminTemplate(w, templateIPList, data) } func (s *httpdServer) renderRolePage(w http.ResponseWriter, r *http.Request, role dataprovider.Role, mode genericPageMode, err error, ) { var title, currentURL string switch mode { case genericPageModeAdd: title = util.I18nRoleAddTitle currentURL = webAdminRolePath case genericPageModeUpdate: title = util.I18nRoleUpdateTitle currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name)) } data := rolePage{ basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Role: &role, Mode: mode, } renderAdminTemplate(w, templateRole, data) } func (s *httpdServer) renderGroupPage(w http.ResponseWriter, r *http.Request, group dataprovider.Group, mode genericPageMode, err error, ) { folders, errFolders := s.getWebVirtualFolders(w, r, defaultQueryLimit, true) if errFolders != nil { return } group.SetEmptySecretsIfNil() group.UserSettings.FsConfig.RedactedSecret = redactedSecret var title, currentURL string switch mode { case genericPageModeAdd: title = util.I18nAddGroupTitle currentURL = webGroupPath case genericPageModeUpdate: title = util.I18nUpdateGroupTitle currentURL = fmt.Sprintf("%v/%v", webGroupPath, url.PathEscape(group.Name)) } group.UserSettings.FsConfig.RedactedSecret = redactedSecret group.UserSettings.FsConfig.SetEmptySecretsIfNil() data := groupPage{ basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), 
Group: &group, Mode: mode, ValidPerms: dataprovider.ValidPerms, ValidLoginMethods: dataprovider.ValidLoginMethods, ValidProtocols: dataprovider.ValidProtocols, TwoFactorProtocols: dataprovider.MFAProtocols, WebClientOptions: sdk.WebClientOptions, VirtualFolders: folders, FsWrapper: fsWrapper{ Filesystem: group.UserSettings.FsConfig, IsUserPage: false, IsGroupPage: true, HasUsersBaseDir: false, DirPath: group.UserSettings.HomeDir, }, } renderAdminTemplate(w, templateGroup, data) } func (s *httpdServer) renderEventActionPage(w http.ResponseWriter, r *http.Request, action dataprovider.BaseEventAction, mode genericPageMode, err error, ) { action.Options.SetEmptySecretsIfNil() var title, currentURL string switch mode { case genericPageModeAdd: title = util.I18nAddActionTitle currentURL = webAdminEventActionPath case genericPageModeUpdate: title = util.I18nUpdateActionTitle currentURL = fmt.Sprintf("%s/%s", webAdminEventActionPath, url.PathEscape(action.Name)) } if action.Options.HTTPConfig.Timeout == 0 { action.Options.HTTPConfig.Timeout = 20 } if action.Options.CmdConfig.Timeout == 0 { action.Options.CmdConfig.Timeout = 20 } if action.Options.PwdExpirationConfig.Threshold == 0 { action.Options.PwdExpirationConfig.Threshold = 10 } data := eventActionPage{ basePage: s.getBasePageData(title, currentURL, w, r), Action: action, ActionTypes: dataprovider.EventActionTypes, FsActions: dataprovider.FsActionTypes, HTTPMethods: dataprovider.SupportedHTTPActionMethods, EnabledCommands: dataprovider.EnabledActionCommands, RedactedSecret: redactedSecret, Error: getI18nError(err), Mode: mode, } renderAdminTemplate(w, templateEventAction, data) } func (s *httpdServer) renderEventRulePage(w http.ResponseWriter, r *http.Request, rule dataprovider.EventRule, mode genericPageMode, err error, ) { actions, errActions := s.getWebEventActions(w, r, defaultQueryLimit, true) if errActions != nil { return } var title, currentURL string switch mode { case genericPageModeAdd: title = 
util.I18nAddRuleTitle currentURL = webAdminEventRulePath case genericPageModeUpdate: title = util.I18nUpdateRuleTitle currentURL = fmt.Sprintf("%v/%v", webAdminEventRulePath, url.PathEscape(rule.Name)) } data := eventRulePage{ basePage: s.getBasePageData(title, currentURL, w, r), Rule: rule, TriggerTypes: dataprovider.EventTriggerTypes, Actions: actions, FsEvents: dataprovider.SupportedFsEvents, Protocols: dataprovider.SupportedRuleConditionProtocols, ProviderEvents: dataprovider.SupportedProviderEvents, ProviderObjects: dataprovider.SupporteRuleConditionProviderObjects, Error: getI18nError(err), Mode: mode, IsShared: s.isShared > 0, } renderAdminTemplate(w, templateEventRule, data) } func (s *httpdServer) renderFolderPage(w http.ResponseWriter, r *http.Request, folder vfs.BaseVirtualFolder, mode folderPageMode, err error, ) { var title, currentURL string switch mode { case folderPageModeAdd: title = util.I18nAddFolderTitle currentURL = webFolderPath case folderPageModeUpdate: title = util.I18nUpdateFolderTitle currentURL = fmt.Sprintf("%v/%v", webFolderPath, url.PathEscape(folder.Name)) case folderPageModeTemplate: title = util.I18nTemplateFolderTitle currentURL = webTemplateFolder } folder.FsConfig.RedactedSecret = redactedSecret folder.FsConfig.SetEmptySecretsIfNil() data := folderPage{ basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Folder: folder, Mode: mode, FsWrapper: fsWrapper{ Filesystem: folder.FsConfig, IsUserPage: false, IsGroupPage: false, HasUsersBaseDir: false, DirPath: folder.MappedPath, }, } renderAdminTemplate(w, templateFolder, data) } func getFoldersForTemplate(r *http.Request) []string { var res []string for k := range r.Form { if hasPrefixAndSuffix(k, "template_folders[", "][tpl_foldername]") { r.Form.Add("tpl_foldername", r.Form.Get(k)) } } folderNames := r.Form["tpl_foldername"] folders := make(map[string]bool) for _, name := range folderNames { name = strings.TrimSpace(name) if name == "" { continue } if _, 
ok := folders[name]; ok { continue } folders[name] = true res = append(res, name) } return res } func getUsersForTemplate(r *http.Request) []userTemplateFields { var res []userTemplateFields tplUsernames := r.Form["tpl_username"] tplPasswords := r.Form["tpl_password"] tplPublicKeys := r.Form["tpl_public_keys"] users := make(map[string]bool) for idx := range tplUsernames { username := tplUsernames[idx] password := "" publicKey := "" if len(tplPasswords) > idx { password = strings.TrimSpace(tplPasswords[idx]) } if len(tplPublicKeys) > idx { publicKey = strings.TrimSpace(tplPublicKeys[idx]) } if username == "" { continue } if _, ok := users[username]; ok { continue } users[username] = true res = append(res, userTemplateFields{ Username: username, Password: password, PublicKeys: []string{publicKey}, RequirePwdChange: r.Form.Get("tpl_require_password_change") != "", }) } return res } func getVirtualFoldersFromPostFields(r *http.Request) []vfs.VirtualFolder { var virtualFolders []vfs.VirtualFolder folderPaths := r.Form["vfolder_path"] folderNames := r.Form["vfolder_name"] folderQuotaSizes := r.Form["vfolder_quota_size"] folderQuotaFiles := r.Form["vfolder_quota_files"] for idx, p := range folderPaths { name := "" if len(folderNames) > idx { name = folderNames[idx] } if p != "" && name != "" { vfolder := vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: name, }, VirtualPath: p, QuotaFiles: -1, QuotaSize: -1, } if len(folderQuotaSizes) > idx { quotaSize, err := util.ParseBytes(folderQuotaSizes[idx]) if err == nil { vfolder.QuotaSize = quotaSize } } if len(folderQuotaFiles) > idx { quotaFiles, err := strconv.Atoi(folderQuotaFiles[idx]) if err == nil { vfolder.QuotaFiles = quotaFiles } } virtualFolders = append(virtualFolders, vfolder) } } return virtualFolders } func getSubDirPermissionsFromPostFields(r *http.Request) map[string][]string { permissions := make(map[string][]string) for idx, p := range r.Form["sub_perm_path"] { if p != "" { permissions[p] = 
r.Form["sub_perm_permissions"+strconv.Itoa(idx)] } } return permissions } func getUserPermissionsFromPostFields(r *http.Request) map[string][]string { permissions := getSubDirPermissionsFromPostFields(r) permissions["/"] = r.Form["permissions"] return permissions } func getAccessTimeRestrictionsFromPostFields(r *http.Request) []sdk.TimePeriod { var result []sdk.TimePeriod dayOfWeeks := r.Form["access_time_day_of_week"] starts := r.Form["access_time_start"] ends := r.Form["access_time_end"] for idx, dayOfWeek := range dayOfWeeks { dayOfWeek = strings.TrimSpace(dayOfWeek) start := "" if len(starts) > idx { start = strings.TrimSpace(starts[idx]) } end := "" if len(ends) > idx { end = strings.TrimSpace(ends[idx]) } dayNumber, err := strconv.Atoi(dayOfWeek) if err == nil && start != "" && end != "" { result = append(result, sdk.TimePeriod{ DayOfWeek: dayNumber, From: start, To: end, }) } } return result } func getBandwidthLimitsFromPostFields(r *http.Request) ([]sdk.BandwidthLimit, error) { var result []sdk.BandwidthLimit bwSources := r.Form["bandwidth_limit_sources"] uploadSources := r.Form["upload_bandwidth_source"] downloadSources := r.Form["download_bandwidth_source"] for idx, bwSource := range bwSources { sources := getSliceFromDelimitedValues(bwSource, ",") if len(sources) > 0 { bwLimit := sdk.BandwidthLimit{ Sources: sources, } ul := "" dl := "" if len(uploadSources) > idx { ul = uploadSources[idx] } if len(downloadSources) > idx { dl = downloadSources[idx] } if ul != "" { bandwidthUL, err := strconv.ParseInt(ul, 10, 64) if err != nil { return result, fmt.Errorf("invalid upload_bandwidth_source%v %q: %w", idx, ul, err) } bwLimit.UploadBandwidth = bandwidthUL } if dl != "" { bandwidthDL, err := strconv.ParseInt(dl, 10, 64) if err != nil { return result, fmt.Errorf("invalid download_bandwidth_source%v %q: %w", idx, ul, err) } bwLimit.DownloadBandwidth = bandwidthDL } result = append(result, bwLimit) } } return result, nil } func getPatterDenyPolicyFromString(policy 
string) int { denyPolicy := sdk.DenyPolicyDefault if policy == "1" { denyPolicy = sdk.DenyPolicyHide } return denyPolicy } func getFilePatternsFromPostField(r *http.Request) []sdk.PatternsFilter { var result []sdk.PatternsFilter patternPaths := r.Form["pattern_path"] patterns := r.Form["patterns"] patternTypes := r.Form["pattern_type"] policies := r.Form["pattern_policy"] allowedPatterns := make(map[string][]string) deniedPatterns := make(map[string][]string) patternPolicies := make(map[string]string) for idx := range patternPaths { p := patternPaths[idx] filters := strings.ReplaceAll(patterns[idx], " ", "") patternType := patternTypes[idx] patternPolicy := policies[idx] if p != "" && filters != "" { if patternType == "allowed" { allowedPatterns[p] = append(allowedPatterns[p], strings.Split(filters, ",")...) } else { deniedPatterns[p] = append(deniedPatterns[p], strings.Split(filters, ",")...) } if patternPolicy != "" && patternPolicy != "0" { patternPolicies[p] = patternPolicy } } } for dirAllowed, allowPatterns := range allowedPatterns { filter := sdk.PatternsFilter{ Path: dirAllowed, AllowedPatterns: allowPatterns, DenyPolicy: getPatterDenyPolicyFromString(patternPolicies[dirAllowed]), } for dirDenied, denPatterns := range deniedPatterns { if dirAllowed == dirDenied { filter.DeniedPatterns = denPatterns break } } result = append(result, filter) } for dirDenied, denPatterns := range deniedPatterns { found := false for _, res := range result { if res.Path == dirDenied { found = true break } } if !found { result = append(result, sdk.PatternsFilter{ Path: dirDenied, DeniedPatterns: denPatterns, DenyPolicy: getPatterDenyPolicyFromString(patternPolicies[dirDenied]), }) } } return result } func getGroupsFromUserPostFields(r *http.Request) []sdk.GroupMapping { var groups []sdk.GroupMapping primaryGroup := strings.TrimSpace(r.Form.Get("primary_group")) if primaryGroup != "" { groups = append(groups, sdk.GroupMapping{ Name: primaryGroup, Type: sdk.GroupTypePrimary, }) } 
secondaryGroups := r.Form["secondary_groups"] for _, name := range secondaryGroups { groups = append(groups, sdk.GroupMapping{ Name: strings.TrimSpace(name), Type: sdk.GroupTypeSecondary, }) } membershipGroups := r.Form["membership_groups"] for _, name := range membershipGroups { groups = append(groups, sdk.GroupMapping{ Name: strings.TrimSpace(name), Type: sdk.GroupTypeMembership, }) } return groups } func getFiltersFromUserPostFields(r *http.Request) (sdk.BaseUserFilters, error) { var filters sdk.BaseUserFilters bwLimits, err := getBandwidthLimitsFromPostFields(r) if err != nil { return filters, err } maxFileSize, err := util.ParseBytes(r.Form.Get("max_upload_file_size")) if err != nil { return filters, util.NewI18nError(fmt.Errorf("invalid max upload file size: %w", err), util.I18nErrorInvalidMaxFilesize) } defaultSharesExpiration, err := strconv.Atoi(r.Form.Get("default_shares_expiration")) if err != nil { return filters, fmt.Errorf("invalid default shares expiration: %w", err) } maxSharesExpiration, err := strconv.Atoi(r.Form.Get("max_shares_expiration")) if err != nil { return filters, fmt.Errorf("invalid max shares expiration: %w", err) } passwordExpiration, err := strconv.Atoi(r.Form.Get("password_expiration")) if err != nil { return filters, fmt.Errorf("invalid password expiration: %w", err) } passwordStrength, err := strconv.Atoi(r.Form.Get("password_strength")) if err != nil { return filters, fmt.Errorf("invalid password strength: %w", err) } if r.Form.Get("ftp_security") == "1" { filters.FTPSecurity = 1 } filters.BandwidthLimits = bwLimits filters.AllowedIP = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",") filters.DeniedIP = getSliceFromDelimitedValues(r.Form.Get("denied_ip"), ",") filters.DeniedLoginMethods = r.Form["denied_login_methods"] filters.DeniedProtocols = r.Form["denied_protocols"] filters.TwoFactorAuthProtocols = r.Form["required_two_factor_protocols"] filters.FilePatterns = getFilePatternsFromPostField(r) filters.TLSUsername = 
sdk.TLSUsername(strings.TrimSpace(r.Form.Get("tls_username"))) filters.WebClient = r.Form["web_client_options"] filters.DefaultSharesExpiration = defaultSharesExpiration filters.MaxSharesExpiration = maxSharesExpiration filters.PasswordExpiration = passwordExpiration filters.PasswordStrength = passwordStrength filters.AccessTime = getAccessTimeRestrictionsFromPostFields(r) hooks := r.Form["hooks"] if slices.Contains(hooks, "external_auth_disabled") { filters.Hooks.ExternalAuthDisabled = true } if slices.Contains(hooks, "pre_login_disabled") { filters.Hooks.PreLoginDisabled = true } if slices.Contains(hooks, "check_password_disabled") { filters.Hooks.CheckPasswordDisabled = true } filters.IsAnonymous = r.Form.Get("is_anonymous") != "" filters.DisableFsChecks = r.Form.Get("disable_fs_checks") != "" filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != "" filters.StartDirectory = strings.TrimSpace(r.Form.Get("start_directory")) filters.MaxUploadFileSize = maxFileSize filters.ExternalAuthCacheTime, err = strconv.ParseInt(r.Form.Get("external_auth_cache_time"), 10, 64) if err != nil { return filters, fmt.Errorf("invalid external auth cache time: %w", err) } return filters, nil } func getSecretFromFormField(r *http.Request, field string) *kms.Secret { secret := kms.NewPlainSecret(r.Form.Get(field)) if strings.TrimSpace(secret.GetPayload()) == redactedSecret { secret.SetStatus(sdkkms.SecretStatusRedacted) } if strings.TrimSpace(secret.GetPayload()) == "" { secret.SetStatus("") } return secret } func getS3Config(r *http.Request) (vfs.S3FsConfig, error) { var err error config := vfs.S3FsConfig{} config.Bucket = strings.TrimSpace(r.Form.Get("s3_bucket")) config.Region = strings.TrimSpace(r.Form.Get("s3_region")) config.AccessKey = strings.TrimSpace(r.Form.Get("s3_access_key")) config.RoleARN = strings.TrimSpace(r.Form.Get("s3_role_arn")) config.AccessSecret = getSecretFromFormField(r, "s3_access_secret") config.SSECustomerKey = getSecretFromFormField(r, 
"s3_sse_customer_key") config.Endpoint = strings.TrimSpace(r.Form.Get("s3_endpoint")) config.StorageClass = strings.TrimSpace(r.Form.Get("s3_storage_class")) config.ACL = strings.TrimSpace(r.Form.Get("s3_acl")) config.KeyPrefix = strings.TrimSpace(strings.TrimPrefix(r.Form.Get("s3_key_prefix"), "/")) config.UploadPartSize, err = strconv.ParseInt(r.Form.Get("s3_upload_part_size"), 10, 64) if err != nil { return config, fmt.Errorf("invalid s3 upload part size: %w", err) } config.UploadConcurrency, err = strconv.Atoi(r.Form.Get("s3_upload_concurrency")) if err != nil { return config, fmt.Errorf("invalid s3 upload concurrency: %w", err) } config.DownloadPartSize, err = strconv.ParseInt(r.Form.Get("s3_download_part_size"), 10, 64) if err != nil { return config, fmt.Errorf("invalid s3 download part size: %w", err) } config.DownloadConcurrency, err = strconv.Atoi(r.Form.Get("s3_download_concurrency")) if err != nil { return config, fmt.Errorf("invalid s3 download concurrency: %w", err) } config.ForcePathStyle = r.Form.Get("s3_force_path_style") != "" config.SkipTLSVerify = r.Form.Get("s3_skip_tls_verify") != "" config.DownloadPartMaxTime, err = strconv.Atoi(r.Form.Get("s3_download_part_max_time")) if err != nil { return config, fmt.Errorf("invalid s3 download part max time: %w", err) } config.UploadPartMaxTime, err = strconv.Atoi(r.Form.Get("s3_upload_part_max_time")) if err != nil { return config, fmt.Errorf("invalid s3 upload part max time: %w", err) } return config, nil } func getGCSConfig(r *http.Request) (vfs.GCSFsConfig, error) { var err error config := vfs.GCSFsConfig{} config.Bucket = strings.TrimSpace(r.Form.Get("gcs_bucket")) config.StorageClass = strings.TrimSpace(r.Form.Get("gcs_storage_class")) config.ACL = strings.TrimSpace(r.Form.Get("gcs_acl")) config.KeyPrefix = strings.TrimSpace(strings.TrimPrefix(r.Form.Get("gcs_key_prefix"), "/")) uploadPartSize, err := strconv.ParseInt(r.Form.Get("gcs_upload_part_size"), 10, 64) if err == nil { config.UploadPartSize = 
uploadPartSize } uploadPartMaxTime, err := strconv.Atoi(r.Form.Get("gcs_upload_part_max_time")) if err == nil { config.UploadPartMaxTime = uploadPartMaxTime } autoCredentials := r.Form.Get("gcs_auto_credentials") if autoCredentials != "" { config.AutomaticCredentials = 1 } else { config.AutomaticCredentials = 0 } credentials, _, err := r.FormFile("gcs_credential_file") if errors.Is(err, http.ErrMissingFile) { return config, nil } if err != nil { return config, err } defer credentials.Close() fileBytes, err := io.ReadAll(credentials) if err != nil || len(fileBytes) == 0 { if len(fileBytes) == 0 { err = errors.New("credentials file size must be greater than 0") } return config, err } config.Credentials = kms.NewPlainSecret(util.BytesToString(fileBytes)) config.AutomaticCredentials = 0 return config, err } func getSFTPConfig(r *http.Request) (vfs.SFTPFsConfig, error) { var err error config := vfs.SFTPFsConfig{} config.Endpoint = strings.TrimSpace(r.Form.Get("sftp_endpoint")) config.Username = strings.TrimSpace(r.Form.Get("sftp_username")) config.Password = getSecretFromFormField(r, "sftp_password") config.PrivateKey = getSecretFromFormField(r, "sftp_private_key") config.KeyPassphrase = getSecretFromFormField(r, "sftp_key_passphrase") fingerprintsFormValue := r.Form.Get("sftp_fingerprints") config.Fingerprints = getSliceFromDelimitedValues(fingerprintsFormValue, "\n") config.Prefix = strings.TrimSpace(r.Form.Get("sftp_prefix")) config.DisableCouncurrentReads = r.Form.Get("sftp_disable_concurrent_reads") != "" config.BufferSize, err = strconv.ParseInt(r.Form.Get("sftp_buffer_size"), 10, 64) if r.Form.Get("sftp_equality_check_mode") != "" { config.EqualityCheckMode = 1 } else { config.EqualityCheckMode = 0 } if err != nil { return config, fmt.Errorf("invalid SFTP buffer size: %w", err) } return config, nil } func getHTTPFsConfig(r *http.Request) vfs.HTTPFsConfig { config := vfs.HTTPFsConfig{} config.Endpoint = strings.TrimSpace(r.Form.Get("http_endpoint")) 
config.Username = strings.TrimSpace(r.Form.Get("http_username")) config.SkipTLSVerify = r.Form.Get("http_skip_tls_verify") != "" config.Password = getSecretFromFormField(r, "http_password") config.APIKey = getSecretFromFormField(r, "http_api_key") if r.Form.Get("http_equality_check_mode") != "" { config.EqualityCheckMode = 1 } else { config.EqualityCheckMode = 0 } return config } func getAzureConfig(r *http.Request) (vfs.AzBlobFsConfig, error) { var err error config := vfs.AzBlobFsConfig{} config.Container = strings.TrimSpace(r.Form.Get("az_container")) config.AccountName = strings.TrimSpace(r.Form.Get("az_account_name")) config.AccountKey = getSecretFromFormField(r, "az_account_key") config.SASURL = getSecretFromFormField(r, "az_sas_url") config.Endpoint = strings.TrimSpace(r.Form.Get("az_endpoint")) config.KeyPrefix = strings.TrimSpace(strings.TrimPrefix(r.Form.Get("az_key_prefix"), "/")) config.AccessTier = strings.TrimSpace(r.Form.Get("az_access_tier")) config.UseEmulator = r.Form.Get("az_use_emulator") != "" config.UploadPartSize, err = strconv.ParseInt(r.Form.Get("az_upload_part_size"), 10, 64) if err != nil { return config, fmt.Errorf("invalid azure upload part size: %w", err) } config.UploadConcurrency, err = strconv.Atoi(r.Form.Get("az_upload_concurrency")) if err != nil { return config, fmt.Errorf("invalid azure upload concurrency: %w", err) } config.DownloadPartSize, err = strconv.ParseInt(r.Form.Get("az_download_part_size"), 10, 64) if err != nil { return config, fmt.Errorf("invalid azure download part size: %w", err) } config.DownloadConcurrency, err = strconv.Atoi(r.Form.Get("az_download_concurrency")) if err != nil { return config, fmt.Errorf("invalid azure download concurrency: %w", err) } return config, nil } func getOsConfigFromPostFields(r *http.Request, readBufferField, writeBufferField string) sdk.OSFsConfig { config := sdk.OSFsConfig{} readBuffer, err := strconv.Atoi(r.Form.Get(readBufferField)) if err == nil { config.ReadBufferSize = 
readBuffer } writeBuffer, err := strconv.Atoi(r.Form.Get(writeBufferField)) if err == nil { config.WriteBufferSize = writeBuffer } return config } func getFsConfigFromPostFields(r *http.Request) (vfs.Filesystem, error) { var fs vfs.Filesystem fs.Provider = dataprovider.GetProviderFromValue(r.Form.Get("fs_provider")) switch fs.Provider { case sdk.LocalFilesystemProvider: fs.OSConfig = getOsConfigFromPostFields(r, "osfs_read_buffer_size", "osfs_write_buffer_size") case sdk.S3FilesystemProvider: config, err := getS3Config(r) if err != nil { return fs, err } fs.S3Config = config case sdk.AzureBlobFilesystemProvider: config, err := getAzureConfig(r) if err != nil { return fs, err } fs.AzBlobConfig = config case sdk.GCSFilesystemProvider: config, err := getGCSConfig(r) if err != nil { return fs, err } fs.GCSConfig = config case sdk.CryptedFilesystemProvider: fs.CryptConfig.Passphrase = getSecretFromFormField(r, "crypt_passphrase") fs.CryptConfig.OSFsConfig = getOsConfigFromPostFields(r, "cryptfs_read_buffer_size", "cryptfs_write_buffer_size") case sdk.SFTPFilesystemProvider: config, err := getSFTPConfig(r) if err != nil { return fs, err } fs.SFTPConfig = config case sdk.HTTPFilesystemProvider: fs.HTTPConfig = getHTTPFsConfig(r) } return fs, nil } func getAdminHiddenUserPageSections(r *http.Request) int { var result int for _, val := range r.Form["user_page_hidden_sections"] { switch val { case "1": result++ case "2": result += 2 case "3": result += 4 case "4": result += 8 case "5": result += 16 case "6": result += 32 case "7": result += 64 } } return result } func getAdminFromPostFields(r *http.Request) (dataprovider.Admin, error) { var admin dataprovider.Admin err := r.ParseForm() if err != nil { return admin, util.NewI18nError(err, util.I18nErrorInvalidForm) } status, err := strconv.Atoi(r.Form.Get("status")) if err != nil { return admin, fmt.Errorf("invalid status: %w", err) } admin.Username = strings.TrimSpace(r.Form.Get("username")) admin.Password = 
strings.TrimSpace(r.Form.Get("password")) admin.Permissions = r.Form["permissions"] admin.Email = strings.TrimSpace(r.Form.Get("email")) admin.Status = status admin.Role = strings.TrimSpace(r.Form.Get("role")) admin.Filters.AllowList = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",") admin.Filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != "" admin.Filters.RequireTwoFactor = r.Form.Get("require_two_factor") != "" admin.Filters.RequirePasswordChange = r.Form.Get("require_password_change") != "" admin.AdditionalInfo = r.Form.Get("additional_info") admin.Description = r.Form.Get("description") admin.Filters.Preferences.HideUserPageSections = getAdminHiddenUserPageSections(r) admin.Filters.Preferences.DefaultUsersExpiration = 0 if val := r.Form.Get("default_users_expiration"); val != "" { defaultUsersExpiration, err := strconv.Atoi(r.Form.Get("default_users_expiration")) if err != nil { return admin, fmt.Errorf("invalid default users expiration: %w", err) } admin.Filters.Preferences.DefaultUsersExpiration = defaultUsersExpiration } for k := range r.Form { if hasPrefixAndSuffix(k, "groups[", "][group]") { groupName := strings.TrimSpace(r.Form.Get(k)) if groupName != "" { group := dataprovider.AdminGroupMapping{ Name: groupName, } base, _ := strings.CutSuffix(k, "[group]") addAsGroupType := strings.TrimSpace(r.Form.Get(base + "[group_type]")) switch addAsGroupType { case "1": group.Options.AddToUsersAs = dataprovider.GroupAddToUsersAsPrimary case "2": group.Options.AddToUsersAs = dataprovider.GroupAddToUsersAsSecondary default: group.Options.AddToUsersAs = dataprovider.GroupAddToUsersAsMembership } admin.Groups = append(admin.Groups, group) } } } return admin, nil } func replacePlaceholders(field string, replacements map[string]string) string { for k, v := range replacements { field = strings.ReplaceAll(field, k, v) } return field } func getFolderFromTemplate(folder vfs.BaseVirtualFolder, name string) vfs.BaseVirtualFolder { folder.Name = name 
replacements := make(map[string]string) replacements["%name%"] = folder.Name folder.MappedPath = replacePlaceholders(folder.MappedPath, replacements) folder.Description = replacePlaceholders(folder.Description, replacements) switch folder.FsConfig.Provider { case sdk.CryptedFilesystemProvider: folder.FsConfig.CryptConfig = getCryptFsFromTemplate(folder.FsConfig.CryptConfig, replacements) case sdk.S3FilesystemProvider: folder.FsConfig.S3Config = getS3FsFromTemplate(folder.FsConfig.S3Config, replacements) case sdk.GCSFilesystemProvider: folder.FsConfig.GCSConfig = getGCSFsFromTemplate(folder.FsConfig.GCSConfig, replacements) case sdk.AzureBlobFilesystemProvider: folder.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(folder.FsConfig.AzBlobConfig, replacements) case sdk.SFTPFilesystemProvider: folder.FsConfig.SFTPConfig = getSFTPFsFromTemplate(folder.FsConfig.SFTPConfig, replacements) case sdk.HTTPFilesystemProvider: folder.FsConfig.HTTPConfig = getHTTPFsFromTemplate(folder.FsConfig.HTTPConfig, replacements) } return folder } func getCryptFsFromTemplate(fsConfig vfs.CryptFsConfig, replacements map[string]string) vfs.CryptFsConfig { if fsConfig.Passphrase != nil { if fsConfig.Passphrase.IsPlain() { payload := replacePlaceholders(fsConfig.Passphrase.GetPayload(), replacements) fsConfig.Passphrase = kms.NewPlainSecret(payload) } } return fsConfig } func getS3FsFromTemplate(fsConfig vfs.S3FsConfig, replacements map[string]string) vfs.S3FsConfig { fsConfig.KeyPrefix = replacePlaceholders(fsConfig.KeyPrefix, replacements) fsConfig.AccessKey = replacePlaceholders(fsConfig.AccessKey, replacements) if fsConfig.AccessSecret != nil && fsConfig.AccessSecret.IsPlain() { payload := replacePlaceholders(fsConfig.AccessSecret.GetPayload(), replacements) fsConfig.AccessSecret = kms.NewPlainSecret(payload) } if fsConfig.SSECustomerKey != nil && fsConfig.SSECustomerKey.IsPlain() { payload := replacePlaceholders(fsConfig.SSECustomerKey.GetPayload(), replacements) fsConfig.SSECustomerKey = 
kms.NewPlainSecret(payload) } return fsConfig } func getGCSFsFromTemplate(fsConfig vfs.GCSFsConfig, replacements map[string]string) vfs.GCSFsConfig { fsConfig.KeyPrefix = replacePlaceholders(fsConfig.KeyPrefix, replacements) return fsConfig } func getAzBlobFsFromTemplate(fsConfig vfs.AzBlobFsConfig, replacements map[string]string) vfs.AzBlobFsConfig { fsConfig.KeyPrefix = replacePlaceholders(fsConfig.KeyPrefix, replacements) fsConfig.AccountName = replacePlaceholders(fsConfig.AccountName, replacements) if fsConfig.AccountKey != nil && fsConfig.AccountKey.IsPlain() { payload := replacePlaceholders(fsConfig.AccountKey.GetPayload(), replacements) fsConfig.AccountKey = kms.NewPlainSecret(payload) } return fsConfig } func getSFTPFsFromTemplate(fsConfig vfs.SFTPFsConfig, replacements map[string]string) vfs.SFTPFsConfig { fsConfig.Prefix = replacePlaceholders(fsConfig.Prefix, replacements) fsConfig.Username = replacePlaceholders(fsConfig.Username, replacements) if fsConfig.Password != nil && fsConfig.Password.IsPlain() { payload := replacePlaceholders(fsConfig.Password.GetPayload(), replacements) fsConfig.Password = kms.NewPlainSecret(payload) } return fsConfig } func getHTTPFsFromTemplate(fsConfig vfs.HTTPFsConfig, replacements map[string]string) vfs.HTTPFsConfig { fsConfig.Username = replacePlaceholders(fsConfig.Username, replacements) return fsConfig } func getUserFromTemplate(user dataprovider.User, template userTemplateFields) dataprovider.User { user.Username = template.Username user.Password = template.Password user.PublicKeys = template.PublicKeys user.Filters.RequirePasswordChange = template.RequirePwdChange replacements := make(map[string]string) replacements["%username%"] = user.Username if user.Password != "" && !user.IsPasswordHashed() { user.Password = replacePlaceholders(user.Password, replacements) replacements["%password%"] = user.Password } user.HomeDir = replacePlaceholders(user.HomeDir, replacements) var vfolders []vfs.VirtualFolder for _, vfolder := 
range user.VirtualFolders { vfolder.Name = replacePlaceholders(vfolder.Name, replacements) vfolder.VirtualPath = replacePlaceholders(vfolder.VirtualPath, replacements) vfolders = append(vfolders, vfolder) } user.VirtualFolders = vfolders user.Description = replacePlaceholders(user.Description, replacements) user.AdditionalInfo = replacePlaceholders(user.AdditionalInfo, replacements) user.Filters.StartDirectory = replacePlaceholders(user.Filters.StartDirectory, replacements) switch user.FsConfig.Provider { case sdk.CryptedFilesystemProvider: user.FsConfig.CryptConfig = getCryptFsFromTemplate(user.FsConfig.CryptConfig, replacements) case sdk.S3FilesystemProvider: user.FsConfig.S3Config = getS3FsFromTemplate(user.FsConfig.S3Config, replacements) case sdk.GCSFilesystemProvider: user.FsConfig.GCSConfig = getGCSFsFromTemplate(user.FsConfig.GCSConfig, replacements) case sdk.AzureBlobFilesystemProvider: user.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(user.FsConfig.AzBlobConfig, replacements) case sdk.SFTPFilesystemProvider: user.FsConfig.SFTPConfig = getSFTPFsFromTemplate(user.FsConfig.SFTPConfig, replacements) case sdk.HTTPFilesystemProvider: user.FsConfig.HTTPConfig = getHTTPFsFromTemplate(user.FsConfig.HTTPConfig, replacements) } return user } func getTransferLimits(r *http.Request) (int64, int64, int64, error) { dataTransferUL, err := strconv.ParseInt(r.Form.Get("upload_data_transfer"), 10, 64) if err != nil { return 0, 0, 0, fmt.Errorf("invalid upload data transfer: %w", err) } dataTransferDL, err := strconv.ParseInt(r.Form.Get("download_data_transfer"), 10, 64) if err != nil { return 0, 0, 0, fmt.Errorf("invalid download data transfer: %w", err) } dataTransferTotal, err := strconv.ParseInt(r.Form.Get("total_data_transfer"), 10, 64) if err != nil { return 0, 0, 0, fmt.Errorf("invalid total data transfer: %w", err) } return dataTransferUL, dataTransferDL, dataTransferTotal, nil } func getQuotaLimits(r *http.Request) (int64, int, error) { quotaSize, err := 
util.ParseBytes(r.Form.Get("quota_size"))
	if err != nil {
		return 0, 0, util.NewI18nError(fmt.Errorf("invalid quota size: %w", err), util.I18nErrorInvalidQuotaSize)
	}
	quotaFiles, err := strconv.Atoi(r.Form.Get("quota_files"))
	if err != nil {
		return 0, 0, fmt.Errorf("invalid quota files: %w", err)
	}
	return quotaSize, quotaFiles, nil
}

// updateRepeaterFormFields flattens the indexed fields submitted by the
// repeater widgets of the user/group forms (e.g. "public_keys[0][public_key]")
// into the flat, multi-valued keys read by the other form parsers (e.g.
// "public_keys"). All the sub-fields of one repeater entry are appended in the
// same iteration, so their positions in the flat slices stay aligned.
//
// Keys are added to r.Form while ranging over it; per the Go spec such keys
// may or may not be visited, which is harmless here because the added keys
// do not carry the bracketed repeater prefixes checked below.
func updateRepeaterFormFields(r *http.Request) {
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "public_keys[", "][public_key]") {
			key := r.Form.Get(k)
			if strings.TrimSpace(key) != "" {
				r.Form.Add("public_keys", key)
			}
			continue
		}
		if hasPrefixAndSuffix(k, "tls_certs[", "][tls_cert]") {
			cert := strings.TrimSpace(r.Form.Get(k))
			if cert != "" {
				r.Form.Add("tls_certs", cert)
			}
			continue
		}
		if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") {
			email := strings.TrimSpace(r.Form.Get(k))
			if email != "" {
				r.Form.Add("additional_emails", email)
			}
			continue
		}
		if hasPrefixAndSuffix(k, "virtual_folders[", "][vfolder_path]") {
			base, _ := strings.CutSuffix(k, "[vfolder_path]")
			r.Form.Add("vfolder_path", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("vfolder_name", strings.TrimSpace(r.Form.Get(base+"[vfolder_name]")))
			r.Form.Add("vfolder_quota_files", strings.TrimSpace(r.Form.Get(base+"[vfolder_quota_files]")))
			r.Form.Add("vfolder_quota_size", strings.TrimSpace(r.Form.Get(base+"[vfolder_quota_size]")))
			continue
		}
		if hasPrefixAndSuffix(k, "directory_permissions[", "][sub_perm_path]") {
			base, _ := strings.CutSuffix(k, "[sub_perm_path]")
			r.Form.Add("sub_perm_path", strings.TrimSpace(r.Form.Get(k)))
			// the permissions are stored under the index of the path just
			// appended, e.g. "sub_perm_permissions0"
			r.Form["sub_perm_permissions"+strconv.Itoa(len(r.Form["sub_perm_path"])-1)] = r.Form[base+"[sub_perm_permissions][]"]
			continue
		}
		if hasPrefixAndSuffix(k, "directory_patterns[", "][pattern_path]") {
			base, _ := strings.CutSuffix(k, "[pattern_path]")
			r.Form.Add("pattern_path", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("patterns", strings.TrimSpace(r.Form.Get(base+"[patterns]")))
			r.Form.Add("pattern_type", strings.TrimSpace(r.Form.Get(base+"[pattern_type]")))
			r.Form.Add("pattern_policy", strings.TrimSpace(r.Form.Get(base+"[pattern_policy]")))
			continue
		}
		if hasPrefixAndSuffix(k, "access_time_restrictions[", "][access_time_day_of_week]") {
			base, _ := strings.CutSuffix(k, "[access_time_day_of_week]")
			r.Form.Add("access_time_day_of_week", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("access_time_start", strings.TrimSpace(r.Form.Get(base+"[access_time_start]")))
			r.Form.Add("access_time_end", strings.TrimSpace(r.Form.Get(base+"[access_time_end]")))
			continue
		}
		if hasPrefixAndSuffix(k, "src_bandwidth_limits[", "][bandwidth_limit_sources]") {
			base, _ := strings.CutSuffix(k, "[bandwidth_limit_sources]")
			r.Form.Add("bandwidth_limit_sources", r.Form.Get(k))
			r.Form.Add("upload_bandwidth_source", strings.TrimSpace(r.Form.Get(base+"[upload_bandwidth_source]")))
			r.Form.Add("download_bandwidth_source", strings.TrimSpace(r.Form.Get(base+"[download_bandwidth_source]")))
			continue
		}
		if hasPrefixAndSuffix(k, "template_users[", "][tpl_username]") {
			base, _ := strings.CutSuffix(k, "[tpl_username]")
			r.Form.Add("tpl_username", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("tpl_password", strings.TrimSpace(r.Form.Get(base+"[tpl_password]")))
			r.Form.Add("tpl_public_keys", strings.TrimSpace(r.Form.Get(base+"[tpl_public_keys]")))
			continue
		}
	}
}

// getUserFromPostFields parses the multipart form submitted by the WebAdmin
// user page and builds the corresponding user.
//
// Form parsing errors are wrapped as I18nErrorInvalidForm. Numeric fields that
// cannot be parsed make the function return an error naming the offending
// field. The expiration date, if set, must match webDateTimeFormat.
func getUserFromPostFields(r *http.Request) (dataprovider.User, error) {
	user := dataprovider.User{}
	err := r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		return user, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	updateRepeaterFormFields(r)

	uid, err := strconv.Atoi(r.Form.Get("uid"))
	if err != nil {
		return user, fmt.Errorf("invalid uid: %w", err)
	}
	gid, err := strconv.Atoi(r.Form.Get("gid"))
	if err != nil {
		// report the field that actually failed to parse
		return user, fmt.Errorf("invalid gid: %w", err)
	}
	maxSessions, err := strconv.Atoi(r.Form.Get("max_sessions"))
	if err != nil {
		return user, fmt.Errorf("invalid max sessions: %w", err)
	}
	quotaSize, quotaFiles, err := getQuotaLimits(r)
	if err != nil {
		return user, err
	}
	bandwidthUL, err := strconv.ParseInt(r.Form.Get("upload_bandwidth"), 10, 64)
	if err != nil {
		return user, fmt.Errorf("invalid upload bandwidth: %w", err)
	}
	bandwidthDL, err := strconv.ParseInt(r.Form.Get("download_bandwidth"), 10, 64)
	if err != nil {
		return user, fmt.Errorf("invalid download bandwidth: %w", err)
	}
	dataTransferUL, dataTransferDL, dataTransferTotal, err := getTransferLimits(r)
	if err != nil {
		return user, err
	}
	status, err := strconv.Atoi(r.Form.Get("status"))
	if err != nil {
		return user, fmt.Errorf("invalid status: %w", err)
	}
	expirationDateMillis := int64(0)
	expirationDateString := r.Form.Get("expiration_date")
	if strings.TrimSpace(expirationDateString) != "" {
		expirationDate, err := time.Parse(webDateTimeFormat, expirationDateString)
		if err != nil {
			return user, err
		}
		expirationDateMillis = util.GetTimeAsMsSinceEpoch(expirationDate)
	}
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		return user, err
	}
	filters, err := getFiltersFromUserPostFields(r)
	if err != nil {
		return user, err
	}
	filters.TLSCerts = r.Form["tls_certs"]
	user = dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:             strings.TrimSpace(r.Form.Get("username")),
			Email:                strings.TrimSpace(r.Form.Get("email")),
			Password:             strings.TrimSpace(r.Form.Get("password")),
			PublicKeys:           r.Form["public_keys"],
			HomeDir:              strings.TrimSpace(r.Form.Get("home_dir")),
			UID:                  uid,
			GID:                  gid,
			Permissions:          getUserPermissionsFromPostFields(r),
			MaxSessions:          maxSessions,
			QuotaSize:            quotaSize,
			QuotaFiles:           quotaFiles,
			UploadBandwidth:      bandwidthUL,
			DownloadBandwidth:    bandwidthDL,
			UploadDataTransfer:   dataTransferUL,
			DownloadDataTransfer: dataTransferDL,
			TotalDataTransfer:    dataTransferTotal,
			Status:               status,
			ExpirationDate:       expirationDateMillis,
			AdditionalInfo:       r.Form.Get("additional_info"),
			Description:          r.Form.Get("description"),
			Role:                 strings.TrimSpace(r.Form.Get("role")),
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters:       filters,
			RequirePasswordChange: r.Form.Get("require_password_change") != "",
			AdditionalEmails:      r.Form["additional_emails"],
		},
		VirtualFolders: getVirtualFoldersFromPostFields(r),
		FsConfig:       fsConfig,
		Groups:         getGroupsFromUserPostFields(r),
	}
	return user, nil
}

// getGroupFromPostFields parses the multipart form submitted by the WebAdmin
// group page and builds the corresponding group.
func getGroupFromPostFields(r *http.Request) (dataprovider.Group, error) {
	group := dataprovider.Group{}
	err := r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		return group, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	updateRepeaterFormFields(r)

	maxSessions, err := strconv.Atoi(r.Form.Get("max_sessions"))
	if err != nil {
		return group, fmt.Errorf("invalid max sessions: %w", err)
	}
	quotaSize, quotaFiles, err := getQuotaLimits(r)
	if err != nil {
		return group, err
	}
	bandwidthUL, err := strconv.ParseInt(r.Form.Get("upload_bandwidth"), 10, 64)
	if err != nil {
		return group, fmt.Errorf("invalid upload bandwidth: %w", err)
	}
	bandwidthDL, err := strconv.ParseInt(r.Form.Get("download_bandwidth"), 10, 64)
	if err != nil {
		return group, fmt.Errorf("invalid download bandwidth: %w", err)
	}
	dataTransferUL, dataTransferDL, dataTransferTotal, err := getTransferLimits(r)
	if err != nil {
		return group, err
	}
	expiresIn, err := strconv.Atoi(r.Form.Get("expires_in"))
	if err != nil {
		return group, fmt.Errorf("invalid expires in: %w", err)
	}
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		return group, err
	}
	filters, err := getFiltersFromUserPostFields(r)
	if err != nil {
		return group, err
	}
	group = dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name:        strings.TrimSpace(r.Form.Get("name")),
			Description: r.Form.Get("description"),
		},
		UserSettings: dataprovider.GroupUserSettings{
			BaseGroupUserSettings: sdk.BaseGroupUserSettings{
				HomeDir:           strings.TrimSpace(r.Form.Get("home_dir")),
				MaxSessions:       maxSessions,
				QuotaSize:         quotaSize,
				QuotaFiles:        quotaFiles,
				Permissions:       getSubDirPermissionsFromPostFields(r),
				UploadBandwidth:   bandwidthUL,
				DownloadBandwidth: bandwidthDL,
UploadDataTransfer: dataTransferUL,
				DownloadDataTransfer: dataTransferDL,
				TotalDataTransfer:    dataTransferTotal,
				ExpiresIn:            expiresIn,
				Filters:              filters,
			},
			FsConfig: fsConfig,
		},
		VirtualFolders: getVirtualFoldersFromPostFields(r),
	}
	return group, nil
}

// getKeyValsFromPostFields pairs the values of the form fields key and val by
// position and returns the pairs where both key and value are non-empty.
// Keys without a value at the same position are ignored; a crafted request
// could otherwise trigger an index out of range panic.
func getKeyValsFromPostFields(r *http.Request, key, val string) []dataprovider.KeyValue {
	var res []dataprovider.KeyValue
	keys := r.Form[key]
	values := r.Form[val]
	for idx, k := range keys {
		if idx >= len(values) {
			break
		}
		v := values[idx]
		if k != "" && v != "" {
			res = append(res, dataprovider.KeyValue{
				Key:   k,
				Value: v,
			})
		}
	}
	return res
}

// getRenameConfigsFromPostFields returns the rename configurations built from
// the "fs_rename_source"/"fs_rename_target" pairs. UpdateModTime is set when
// the options stored at the same index contain "1". Sources without a target
// at the same position are ignored.
func getRenameConfigsFromPostFields(r *http.Request) []dataprovider.RenameConfig {
	var res []dataprovider.RenameConfig
	keys := r.Form["fs_rename_source"]
	values := r.Form["fs_rename_target"]
	for idx, k := range keys {
		if idx >= len(values) {
			break
		}
		v := values[idx]
		if k != "" && v != "" {
			opts := r.Form["fs_rename_options"+strconv.Itoa(idx)]
			res = append(res, dataprovider.RenameConfig{
				KeyValue: dataprovider.KeyValue{
					Key:   k,
					Value: v,
				},
				UpdateModTime: slices.Contains(opts, "1"),
			})
		}
	}
	return res
}

// getFoldersRetentionFromPostFields returns the retention settings for each
// non-empty "folder_retention_path". The retention value at the same position
// must be an integer; a missing value is treated as empty and therefore
// reported as an invalid retention instead of causing a panic.
func getFoldersRetentionFromPostFields(r *http.Request) ([]dataprovider.FolderRetention, error) {
	var res []dataprovider.FolderRetention
	paths := r.Form["folder_retention_path"]
	values := r.Form["folder_retention_val"]
	for idx, p := range paths {
		if p == "" {
			continue
		}
		var value string
		if idx < len(values) {
			value = values[idx]
		}
		retention, err := strconv.Atoi(value)
		if err != nil {
			return nil, fmt.Errorf("invalid retention for path %q: %w", p, err)
		}
		opts := r.Form["folder_retention_options"+strconv.Itoa(idx)]
		res = append(res, dataprovider.FolderRetention{
			Path:            p,
			Retention:       retention,
			DeleteEmptyDirs: slices.Contains(opts, "1"),
		})
	}
	return res, nil
}

// getHTTPPartsFromPostFields builds the multipart body parts of an HTTP action.
//
// Parts with an empty name or a non-integer order are skipped. Each line of
// the part headers field in "key: value" form becomes a header; lines
// without ":" are ignored. The result is sorted by order, and the sort is
// stable so parts that share an order keep their submission order. Missing
// values at a given position are treated as empty instead of panicking.
func getHTTPPartsFromPostFields(r *http.Request) []dataprovider.HTTPPart {
	var result []dataprovider.HTTPPart
	names := r.Form["http_part_name"]
	files := r.Form["http_part_file"]
	headers := r.Form["http_part_headers"]
	bodies := r.Form["http_part_body"]
	orders := r.Form["http_part_order"]
	// valueAt returns values[idx], or "" if values is too short
	valueAt := func(values []string, idx int) string {
		if idx < len(values) {
			return values[idx]
		}
		return ""
	}
	for idx, partName := range names {
		if partName == "" {
			continue
		}
		order, err := strconv.Atoi(valueAt(orders, idx))
		if err != nil {
			continue
		}
		var partHeaders []dataprovider.KeyValue
		for _, h := range getSliceFromDelimitedValues(valueAt(headers, idx), "\n") {
			kv := strings.SplitN(h, ":", 2)
			if len(kv) > 1 {
				partHeaders = append(partHeaders, dataprovider.KeyValue{
					Key:   strings.TrimSpace(kv[0]),
					Value: strings.TrimSpace(kv[1]),
				})
			}
		}
		result = append(result, dataprovider.HTTPPart{
			Name:     partName,
			Filepath: valueAt(files, idx),
			Headers:  partHeaders,
			Body:     valueAt(bodies, idx),
			Order:    order,
		})
	}
	sort.SliceStable(result, func(i, j int) bool {
		return result[i].Order < result[j].Order
	})
	return result
}

// updateRepeaterFormActionFields flattens the indexed fields submitted by the
// repeater widgets of the event action form into flat, multi-valued keys.
func updateRepeaterFormActionFields(r *http.Request) {
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "http_headers[", "][http_header_key]") {
			base, _ := strings.CutSuffix(k, "[http_header_key]")
			r.Form.Add("http_header_key", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("http_header_value", strings.TrimSpace(r.Form.Get(base+"[http_header_value]")))
			continue
		}
		if hasPrefixAndSuffix(k, "query_parameters[", "][http_query_key]") {
			base, _ := strings.CutSuffix(k, "[http_query_key]")
			r.Form.Add("http_query_key", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("http_query_value", strings.TrimSpace(r.Form.Get(base+"[http_query_value]")))
			continue
		}
		if hasPrefixAndSuffix(k, "multipart_body[", "][http_part_name]") {
			base, _ := strings.CutSuffix(k, "[http_part_name]")
			order, _ := strings.CutPrefix(k, "multipart_body[")
			order, _ = strings.CutSuffix(order, "][http_part_name]")
			r.Form.Add("http_part_name", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("http_part_file", strings.TrimSpace(r.Form.Get(base+"[http_part_file]")))
			r.Form.Add("http_part_headers", strings.TrimSpace(r.Form.Get(base+"[http_part_headers]")))
			r.Form.Add("http_part_body", strings.TrimSpace(r.Form.Get(base+"[http_part_body]")))
			r.Form.Add("http_part_order", order)
			continue
		}
		if hasPrefixAndSuffix(k, "env_vars[", "][cmd_env_key]") {
			base, _ := strings.CutSuffix(k, "[cmd_env_key]")
r.Form.Add("cmd_env_key", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("cmd_env_value", strings.TrimSpace(r.Form.Get(base+"[cmd_env_value]")))
			continue
		}
		if hasPrefixAndSuffix(k, "data_retention[", "][folder_retention_path]") {
			base, _ := strings.CutSuffix(k, "[folder_retention_path]")
			r.Form.Add("folder_retention_path", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("folder_retention_val", strings.TrimSpace(r.Form.Get(base+"[folder_retention_val]")))
			// store the options under the index of the path just appended,
			// e.g. "folder_retention_options0", so they can be matched by position
			r.Form["folder_retention_options"+strconv.Itoa(len(r.Form["folder_retention_path"])-1)] = r.Form[base+"[folder_retention_options][]"]
			continue
		}
		if hasPrefixAndSuffix(k, "fs_rename[", "][fs_rename_source]") {
			base, _ := strings.CutSuffix(k, "[fs_rename_source]")
			r.Form.Add("fs_rename_source", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("fs_rename_target", strings.TrimSpace(r.Form.Get(base+"[fs_rename_target]")))
			// same index-keyed scheme used for the retention options
			r.Form["fs_rename_options"+strconv.Itoa(len(r.Form["fs_rename_source"])-1)] = r.Form[base+"[fs_rename_options][]"]
			continue
		}
		if hasPrefixAndSuffix(k, "fs_copy[", "][fs_copy_source]") {
			base, _ := strings.CutSuffix(k, "[fs_copy_source]")
			r.Form.Add("fs_copy_source", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("fs_copy_target", strings.TrimSpace(r.Form.Get(base+"[fs_copy_target]")))
			continue
		}
	}
}

// getEventActionOptionsFromPostFields builds the options of an event action
// from the submitted form.
//
// r.Form must already be populated; getEventActionFromPostFields calls
// r.ParseForm before invoking this function. The repeater fields are
// flattened first via updateRepeaterFormActionFields.
//
// "http_timeout", "cmd_timeout", "fs_action_type" and
// "pwd_expiration_threshold" must be valid integers, otherwise an error is
// returned. Errors from the folder retention fields are returned as-is.
// Invalid inactivity thresholds silently fall back to 0.
func getEventActionOptionsFromPostFields(r *http.Request) (dataprovider.BaseEventActionOptions, error) {
	updateRepeaterFormActionFields(r)
	httpTimeout, err := strconv.Atoi(r.Form.Get("http_timeout"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid http timeout: %w", err)
	}
	cmdTimeout, err := strconv.Atoi(r.Form.Get("cmd_timeout"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid command timeout: %w", err)
	}
	foldersRetention, err := getFoldersRetentionFromPostFields(r)
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, err
	}
	fsActionType, err := strconv.Atoi(r.Form.Get("fs_action_type"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid fs action type: %w", err)
	}
	pwdExpirationThreshold, err := strconv.Atoi(r.Form.Get("pwd_expiration_threshold"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid password expiration threshold: %w", err)
	}
	// the inactivity thresholds are optional: parse errors leave them at 0
	var disableThreshold, deleteThreshold int
	if val, err := strconv.Atoi(r.Form.Get("inactivity_disable_threshold")); err == nil {
		disableThreshold = val
	}
	if val, err := strconv.Atoi(r.Form.Get("inactivity_delete_threshold")); err == nil {
		deleteThreshold = val
	}
	// attachments and command arguments stay nil when the field is empty
	var emailAttachments []string
	if r.Form.Get("email_attachments") != "" {
		emailAttachments = getSliceFromDelimitedValues(r.Form.Get("email_attachments"), ",")
	}
	var cmdArgs []string
	if r.Form.Get("cmd_arguments") != "" {
		cmdArgs = getSliceFromDelimitedValues(r.Form.Get("cmd_arguments"), ",")
	}
	// any value other than "1" maps to 0
	idpMode := 0
	if r.Form.Get("idp_mode") == "1" {
		idpMode = 1
	}
	emailContentType := 0
	if r.Form.Get("email_content_type") == "1" {
		emailContentType = 1
	}
	options := dataprovider.BaseEventActionOptions{
		HTTPConfig: dataprovider.EventActionHTTPConfig{
			Endpoint:        strings.TrimSpace(r.Form.Get("http_endpoint")),
			Username:        strings.TrimSpace(r.Form.Get("http_username")),
			Password:        getSecretFromFormField(r, "http_password"),
			Headers:         getKeyValsFromPostFields(r, "http_header_key", "http_header_value"),
			Timeout:         httpTimeout,
			SkipTLSVerify:   r.Form.Get("http_skip_tls_verify") != "",
			Method:          r.Form.Get("http_method"),
			QueryParameters: getKeyValsFromPostFields(r, "http_query_key", "http_query_value"),
			Body:            r.Form.Get("http_body"),
			Parts:           getHTTPPartsFromPostFields(r),
		},
		CmdConfig: dataprovider.EventActionCommandConfig{
			Cmd:     strings.TrimSpace(r.Form.Get("cmd_path")),
			Args:    cmdArgs,
			Timeout: cmdTimeout,
			EnvVars: getKeyValsFromPostFields(r, "cmd_env_key", "cmd_env_value"),
		},
		EmailConfig: dataprovider.EventActionEmailConfig{
			Recipients:  getSliceFromDelimitedValues(r.Form.Get("email_recipients"), ","),
			Bcc:         getSliceFromDelimitedValues(r.Form.Get("email_bcc"), ","),
			Subject:     r.Form.Get("email_subject"),
			ContentType: emailContentType,
			Body:        r.Form.Get("email_body"),
			Attachments: emailAttachments,
		},
		RetentionConfig: dataprovider.EventActionDataRetentionConfig{
			Folders: foldersRetention,
		},
		FsConfig: dataprovider.EventActionFilesystemConfig{
			Type:    fsActionType,
			Renames: getRenameConfigsFromPostFields(r),
			Deletes: getSliceFromDelimitedValues(r.Form.Get("fs_delete_paths"), ","),
			MkDirs:  getSliceFromDelimitedValues(r.Form.Get("fs_mkdir_paths"), ","),
			Exist:   getSliceFromDelimitedValues(r.Form.Get("fs_exist_paths"), ","),
			Copy:    getKeyValsFromPostFields(r, "fs_copy_source", "fs_copy_target"),
			Compress: dataprovider.EventActionFsCompress{
				Name:  strings.TrimSpace(r.Form.Get("fs_compress_name")),
				Paths: getSliceFromDelimitedValues(r.Form.Get("fs_compress_paths"), ","),
			},
		},
		PwdExpirationConfig: dataprovider.EventActionPasswordExpiration{
			Threshold: pwdExpirationThreshold,
		},
		UserInactivityConfig: dataprovider.EventActionUserInactivity{
			DisableThreshold: disableThreshold,
			DeleteThreshold:  deleteThreshold,
		},
		IDPConfig: dataprovider.EventActionIDPAccountCheck{
			Mode:          idpMode,
			TemplateUser:  strings.TrimSpace(r.Form.Get("idp_user")),
			TemplateAdmin: strings.TrimSpace(r.Form.Get("idp_admin")),
		},
	}
	return options, nil
}

// getEventActionFromPostFields parses the submitted form and builds the event
// action. Form parsing errors are wrapped as I18nErrorInvalidForm; the
// "type" field must be an integer.
func getEventActionFromPostFields(r *http.Request) (dataprovider.BaseEventAction, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.BaseEventAction{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	actionType, err := strconv.Atoi(r.Form.Get("type"))
	if err != nil {
		return dataprovider.BaseEventAction{}, fmt.Errorf("invalid action type: %w", err)
	}
	options, err := getEventActionOptionsFromPostFields(r)
	if err != nil {
		return dataprovider.BaseEventAction{}, err
	}
	action := dataprovider.BaseEventAction{
		Name:        strings.TrimSpace(r.Form.Get("name")),
		Description: r.Form.Get("description"),
		Type:        actionType,
		Options:     options,
	}
	return action, nil
}

// getIDPLoginEventFromPostField maps the "idp_login_event" form field to 1 or
// 2 for the values "1" and "2", and to 0 for anything else.
func
getIDPLoginEventFromPostField(r *http.Request) int {
	switch r.Form.Get("idp_login_event") {
	case "1":
		return 1
	case "2":
		return 2
	default:
		return 0
	}
}

// getEventRuleConditionsFromPostFields builds the event rule conditions from
// the flat form fields produced by updateRepeaterFormRuleFields.
//
// Repeated fields are matched by position. A field missing at a given
// position is treated as empty instead of triggering an index out of range
// panic, which crafted requests could otherwise cause. Invalid
// "fs_statuses" values are skipped. Invalid min/max sizes return an error.
func getEventRuleConditionsFromPostFields(r *http.Request) (dataprovider.EventConditions, error) {
	// formValueAt returns the idx-th value of the named form field, or "" if
	// the field has fewer values
	formValueAt := func(field string, idx int) string {
		if values := r.Form[field]; idx < len(values) {
			return values[idx]
		}
		return ""
	}
	// getPatterns collects the non-empty patterns of patternField. A pattern
	// is an inverse match when the value at the same position of typeField
	// equals inversePatternType.
	getPatterns := func(patternField, typeField string) []dataprovider.ConditionPattern {
		var patterns []dataprovider.ConditionPattern
		for idx, name := range r.Form[patternField] {
			if name != "" {
				patterns = append(patterns, dataprovider.ConditionPattern{
					Pattern:      name,
					InverseMatch: formValueAt(typeField, idx) == inversePatternType,
				})
			}
		}
		return patterns
	}

	var schedules []dataprovider.Schedule
	for idx, hour := range r.Form["schedule_hour"] {
		if hour != "" {
			schedules = append(schedules, dataprovider.Schedule{
				Hours:      hour,
				DayOfWeek:  formValueAt("schedule_day_of_week", idx),
				DayOfMonth: formValueAt("schedule_day_of_month", idx),
				Month:      formValueAt("schedule_month", idx),
			})
		}
	}
	minFileSize, err := util.ParseBytes(r.Form.Get("fs_min_size"))
	if err != nil {
		return dataprovider.EventConditions{}, util.NewI18nError(fmt.Errorf("invalid min file size: %w", err), util.I18nErrorInvalidMinSize)
	}
	maxFileSize, err := util.ParseBytes(r.Form.Get("fs_max_size"))
	if err != nil {
		return dataprovider.EventConditions{}, util.NewI18nError(fmt.Errorf("invalid max file size: %w", err), util.I18nErrorInvalidMaxSize)
	}
	var eventStatuses []int
	for _, s := range r.Form["fs_statuses"] {
		status, err := strconv.ParseInt(s, 10, 32)
		if err == nil {
			eventStatuses = append(eventStatuses, int(status))
		}
	}
	conditions := dataprovider.EventConditions{
		FsEvents:       r.Form["fs_events"],
		ProviderEvents: r.Form["provider_events"],
		IDPLoginEvent:  getIDPLoginEventFromPostField(r),
		Schedules:      schedules,
		Options: dataprovider.ConditionOptions{
			Names:               getPatterns("name_pattern", "type_name_pattern"),
			GroupNames:          getPatterns("group_name_pattern", "type_group_name_pattern"),
			RoleNames:           getPatterns("role_name_pattern", "type_role_name_pattern"),
			FsPaths:             getPatterns("fs_path_pattern", "type_fs_path_pattern"),
			Protocols:           r.Form["fs_protocols"],
			EventStatuses:       eventStatuses,
			ProviderObjects:     r.Form["provider_objects"],
			MinFileSize:         minFileSize,
			MaxFileSize:         maxFileSize,
			ConcurrentExecution: r.Form.Get("concurrent_execution") != "",
		},
	}
	return conditions, nil
}

// getEventRuleActionsFromPostFields returns the actions referenced by the
// rule. Entries with an empty name, or with a missing or non-integer order,
// are skipped. The 0-based repeater index is stored as a 1-based order, and
// the options stored at the same index ("1" failure action, "2" stop on
// failure, "3" execute sync) set the matching flags.
func getEventRuleActionsFromPostFields(r *http.Request) []dataprovider.EventAction {
	var actions []dataprovider.EventAction
	names := r.Form["action_name"]
	orders := r.Form["action_order"]
	for idx, name := range names {
		if name == "" || idx >= len(orders) {
			continue
		}
		order, err := strconv.Atoi(orders[idx])
		if err != nil {
			continue
		}
		options := r.Form["action_options"+strconv.Itoa(idx)]
		actions = append(actions, dataprovider.EventAction{
			BaseEventAction: dataprovider.BaseEventAction{
				Name: name,
			},
			Order: order + 1,
			Options: dataprovider.EventActionOptions{
				IsFailureAction: slices.Contains(options, "1"),
				StopOnFailure:   slices.Contains(options, "2"),
				ExecuteSync:     slices.Contains(options, "3"),
			},
		})
	}
	return actions
}

// updateRepeaterFormRuleFields flattens the indexed fields submitted by the
// repeater widgets of the event rule form into flat, multi-valued keys. All
// the sub-fields of one entry are appended in the same iteration so their
// positions stay aligned.
func updateRepeaterFormRuleFields(r *http.Request) {
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "schedules[", "][schedule_hour]") {
			base, _ := strings.CutSuffix(k, "[schedule_hour]")
			r.Form.Add("schedule_hour", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("schedule_day_of_week", strings.TrimSpace(r.Form.Get(base+"[schedule_day_of_week]")))
			r.Form.Add("schedule_day_of_month", strings.TrimSpace(r.Form.Get(base+"[schedule_day_of_month]")))
			r.Form.Add("schedule_month", strings.TrimSpace(r.Form.Get(base+"[schedule_month]")))
			continue
		}
		if hasPrefixAndSuffix(k, "name_filters[", "][name_pattern]") {
			base, _ := strings.CutSuffix(k, "[name_pattern]")
			r.Form.Add("name_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_name_pattern", strings.TrimSpace(r.Form.Get(base+"[type_name_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "group_name_filters[", "][group_name_pattern]") {
			base, _ := strings.CutSuffix(k, "[group_name_pattern]")
			r.Form.Add("group_name_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_group_name_pattern", strings.TrimSpace(r.Form.Get(base+"[type_group_name_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "role_name_filters[", "][role_name_pattern]") {
			base, _ := strings.CutSuffix(k, "[role_name_pattern]")
			r.Form.Add("role_name_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_role_name_pattern", strings.TrimSpace(r.Form.Get(base+"[type_role_name_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "path_filters[", "][fs_path_pattern]") {
			base, _ := strings.CutSuffix(k, "[fs_path_pattern]")
			r.Form.Add("fs_path_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_fs_path_pattern", strings.TrimSpace(r.Form.Get(base+"[type_fs_path_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "actions[", "][action_name]") {
			base, _ := strings.CutSuffix(k, "[action_name]")
			order, _ := strings.CutPrefix(k, "actions[")
			order, _ = strings.CutSuffix(order, "][action_name]")
			r.Form.Add("action_name", strings.TrimSpace(r.Form.Get(k)))
			// options are stored under the index of the name just appended
			r.Form["action_options"+strconv.Itoa(len(r.Form["action_name"])-1)] = r.Form[base+"[action_options][]"]
			r.Form.Add("action_order", order)
			continue
		}
	}
}

// getEventRuleFromPostFields parses the submitted form and builds the event
// rule, including its conditions and actions.
func getEventRuleFromPostFields(r *http.Request) (dataprovider.EventRule, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.EventRule{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	updateRepeaterFormRuleFields(r)
status, err := strconv.Atoi(r.Form.Get("status"))
	if err != nil {
		return dataprovider.EventRule{}, fmt.Errorf("invalid status: %w", err)
	}
	trigger, err := strconv.Atoi(r.Form.Get("trigger"))
	if err != nil {
		return dataprovider.EventRule{}, fmt.Errorf("invalid trigger: %w", err)
	}
	conditions, err := getEventRuleConditionsFromPostFields(r)
	if err != nil {
		return dataprovider.EventRule{}, err
	}
	rule := dataprovider.EventRule{
		Name:        strings.TrimSpace(r.Form.Get("name")),
		Status:      status,
		Description: r.Form.Get("description"),
		Trigger:     trigger,
		Conditions:  conditions,
		Actions:     getEventRuleActionsFromPostFields(r),
	}
	return rule, nil
}

// getRoleFromPostFields parses the submitted form and builds the role.
func getRoleFromPostFields(r *http.Request) (dataprovider.Role, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.Role{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}

	return dataprovider.Role{
		Name:        strings.TrimSpace(r.Form.Get("name")),
		Description: r.Form.Get("description"),
	}, nil
}

// getIPListEntryFromPostFields parses the submitted form and builds an IP list
// entry. The "mode" field is read, and must be an integer, only for defender
// lists; other list types always use mode 1. The submitted protocols are
// combined into a single value, and non-integer values are ignored.
func getIPListEntryFromPostFields(r *http.Request, listType dataprovider.IPListType) (dataprovider.IPListEntry, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.IPListEntry{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	var mode int
	if listType == dataprovider.IPListTypeDefender {
		mode, err = strconv.Atoi(r.Form.Get("mode"))
		if err != nil {
			return dataprovider.IPListEntry{}, fmt.Errorf("invalid mode: %w", err)
		}
	} else {
		mode = 1
	}
	protocols := 0
	for _, proto := range r.Form["protocols"] {
		p, err := strconv.Atoi(proto)
		if err == nil {
			// assumes the protocol values are distinct bit flags (TODO confirm
			// against dataprovider): OR-ing them keeps a value submitted more
			// than once from being counted twice, as "+=" would
			protocols |= p
		}
	}

	return dataprovider.IPListEntry{
		IPOrNet:     strings.TrimSpace(r.Form.Get("ipornet")),
		Mode:        mode,
		Protocols:   protocols,
		Description: r.Form.Get("description"),
	}, nil
}

// getSFTPConfigsFromPostFields returns the SFTP algorithms selected in the form.
func getSFTPConfigsFromPostFields(r *http.Request) *dataprovider.SFTPDConfigs {
	return &dataprovider.SFTPDConfigs{
		HostKeyAlgos:   r.Form["sftp_host_key_algos"],
		PublicKeyAlgos: r.Form["sftp_pub_key_algos"],
		KexAlgorithms:  r.Form["sftp_kex_algos"],
		Ciphers:        r.Form["sftp_ciphers"],
		MACs:           r.Form["sftp_macs"],
	}
}

// getACMEConfigsFromPostFields returns the ACME configuration from the form.
// An invalid port falls back to 80. The selected protocols "1", "2" and "3"
// map to the bit flags 1, 2 and 4, and unknown values are ignored.
func getACMEConfigsFromPostFields(r *http.Request) *dataprovider.ACMEConfigs {
	port, err := strconv.Atoi(r.Form.Get("acme_port"))
	if err != nil {
		port = 80
	}
	var protocols int
	for _, val := range r.Form["acme_protocols"] {
		// OR the flags so that a duplicated value cannot corrupt the mask
		switch val {
		case "1":
			protocols |= 1
		case "2":
			protocols |= 2
		case "3":
			protocols |= 4
		}
	}

	return &dataprovider.ACMEConfigs{
		Domain:          strings.TrimSpace(r.Form.Get("acme_domain")),
		Email:           strings.TrimSpace(r.Form.Get("acme_email")),
		HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: port},
		Protocols:       protocols,
	}
}

// getSMTPConfigsFromPostFields returns the SMTP configuration from the form.
// An invalid port falls back to 587, and invalid auth type or encryption
// values fall back to 0.
func getSMTPConfigsFromPostFields(r *http.Request) *dataprovider.SMTPConfigs {
	port, err := strconv.Atoi(r.Form.Get("smtp_port"))
	if err != nil {
		port = 587
	}
	// explicit resets are needed: on range errors strconv.Atoi returns the
	// clamped value, not 0
	authType, err := strconv.Atoi(r.Form.Get("smtp_auth"))
	if err != nil {
		authType = 0
	}
	encryption, err := strconv.Atoi(r.Form.Get("smtp_encryption"))
	if err != nil {
		encryption = 0
	}
	debug := 0
	if r.Form.Get("smtp_debug") != "" {
		debug = 1
	}
	oauth2Provider := 0
	if r.Form.Get("smtp_oauth2_provider") == "1" {
		oauth2Provider = 1
	}
	return &dataprovider.SMTPConfigs{
		Host:       strings.TrimSpace(r.Form.Get("smtp_host")),
		Port:       port,
		From:       strings.TrimSpace(r.Form.Get("smtp_from")),
		User:       strings.TrimSpace(r.Form.Get("smtp_username")),
		Password:   getSecretFromFormField(r, "smtp_password"),
		AuthType:   authType,
		Encryption: encryption,
		Domain:     strings.TrimSpace(r.Form.Get("smtp_domain")),
		Debug:      debug,
		OAuth2: dataprovider.SMTPOAuth2{
			Provider:     oauth2Provider,
			Tenant:       strings.TrimSpace(r.Form.Get("smtp_oauth2_tenant")),
			ClientID:     strings.TrimSpace(r.Form.Get("smtp_oauth2_client_id")),
			ClientSecret: getSecretFromFormField(r, "smtp_oauth2_client_secret"),
			RefreshToken: getSecretFromFormField(r, "smtp_oauth2_refresh_token"),
		},
	}
}

// getImageInputBytes returns the content of the file uploaded as fieldName.
// If no file was uploaded, it returns defaultVal, or nil when removeFieldName
// is set to a value other than "" or "0". Other errors from r.FormFile or
// from reading the file are returned.
func getImageInputBytes(r *http.Request, fieldName, removeFieldName string, defaultVal []byte) ([]byte, error) {
	var result []byte
	remove := r.Form.Get(removeFieldName)
	if remove == "" || remove == "0" {
		result = defaultVal
	}
	f, _, err := r.FormFile(fieldName)
	if err != nil {
		if errors.Is(err, http.ErrMissingFile) {
			return result, nil
		}
		return nil, err
	}
	defer f.Close()

	return io.ReadAll(f)
}

// getBrandingConfigFromPostFields builds the WebAdmin and WebClient branding
// from the submitted form. A nil config is treated as empty. Logos and
// favicons default to the ones in config unless a new file is uploaded or the
// matching "_remove" field is set; file errors are wrapped as
// I18nErrorInvalidForm.
func getBrandingConfigFromPostFields(r *http.Request, config *dataprovider.BrandingConfigs) (
	*dataprovider.BrandingConfigs, error,
) {
	if config == nil {
		config = &dataprovider.BrandingConfigs{}
	}
	adminLogo, err := getImageInputBytes(r, "branding_webadmin_logo", "branding_webadmin_logo_remove", config.WebAdmin.Logo)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	adminFavicon, err := getImageInputBytes(r, "branding_webadmin_favicon", "branding_webadmin_favicon_remove",
		config.WebAdmin.Favicon)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	clientLogo, err := getImageInputBytes(r, "branding_webclient_logo", "branding_webclient_logo_remove",
		config.WebClient.Logo)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	clientFavicon, err := getImageInputBytes(r, "branding_webclient_favicon", "branding_webclient_favicon_remove",
		config.WebClient.Favicon)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}

	branding := &dataprovider.BrandingConfigs{
		WebAdmin: dataprovider.BrandingConfig{
			Name:           strings.TrimSpace(r.Form.Get("branding_webadmin_name")),
			ShortName:      strings.TrimSpace(r.Form.Get("branding_webadmin_short_name")),
			Logo:           adminLogo,
			Favicon:        adminFavicon,
			DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_name")),
			DisclaimerURL:  strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_url")),
		},
		WebClient: dataprovider.BrandingConfig{
			Name:           strings.TrimSpace(r.Form.Get("branding_webclient_name")),
			ShortName:      strings.TrimSpace(r.Form.Get("branding_webclient_short_name")),
			Logo:           clientLogo,
			Favicon:        clientFavicon,
			DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_name")),
			DisclaimerURL:
strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_url")),
		},
	}
	return branding, nil
}

// handleWebAdminForgotPwd renders the WebAdmin forgot password page, or the
// not found page when SMTP is not enabled.
func (s *httpdServer) handleWebAdminForgotPwd(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if !smtp.IsEnabled() {
		s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
		return
	}
	s.renderForgotPwdPage(w, r, nil)
}

// handleWebAdminForgotPwdPost validates the login cookie and CSRF token, then
// starts the password reset for the submitted username. On success it
// redirects to the reset password page.
func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	err := r.ParseForm()
	if err != nil {
		s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = handleForgotPassword(r, r.Form.Get("username"), true)
	if err != nil {
		s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
		return
	}
	http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound)
}

// handleWebAdminPasswordReset renders the WebAdmin reset password page, or
// the not found page when SMTP is not enabled.
func (s *httpdServer) handleWebAdminPasswordReset(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	if !smtp.IsEnabled() {
		s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
		return
	}
	s.renderResetPwdPage(w, r, nil)
}

// handleWebAdminTwoFactor renders the two-factor authentication page.
func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderTwoFactorPage(w, r, nil)
}

// handleWebAdminTwoFactorRecovery renders the two-factor recovery page.
func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderTwoFactorRecoveryPage(w, r, nil)
}

// handleWebAdminMFA renders the multi-factor authentication page.
func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderMFAPage(w, r)
}

// handleWebAdminProfile renders the profile page of the logged in admin.
func (s *httpdServer) handleWebAdminProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderProfilePage(w, r, nil)
}

// handleWebAdminChangePwd renders the change password page.
func (s *httpdServer) handleWebAdminChangePwd(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderChangePasswordPage(w, r, nil)
}

// handleWebAdminProfilePost updates the profile (API key authentication flag,
// email and description) of the admin identified by the token claims.
func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	err := r.ParseForm()
	if err != nil {
		s.renderProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		// err is nil when only the username is missing: always wrap a non-nil
		// error, as the other handlers in this file do
		s.renderProfilePage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		s.renderProfilePage(w, r, err)
		return
	}
	admin.Filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != ""
	admin.Email = r.Form.Get("email")
	admin.Description = r.Form.Get("description")
	err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role)
	if err != nil {
		s.renderProfilePage(w, r, err)
		return
	}
	s.renderMessagePage(w, r, util.I18nProfileTitle, http.StatusOK, nil, util.I18nProfileUpdated)
}

// handleWebMaintenance renders the maintenance (backup/restore) page.
func (s *httpdServer) handleWebMaintenance(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderMaintenancePage(w, r, nil)
}

// handleWebRestore restores the uploaded backup file.
func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	err = r.ParseMultipartForm(MaxRestoreSize)
	if err != nil {
		s.renderMaintenancePage(w, r, util.NewI18nError(err,
util.I18nErrorInvalidForm)) return } defer r.MultipartForm.RemoveAll() //nolint:errcheck ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } restoreMode, err := strconv.Atoi(r.Form.Get("mode")) if err != nil { s.renderMaintenancePage(w, r, err) return } scanQuota, err := strconv.Atoi(r.Form.Get("quota")) if err != nil { s.renderMaintenancePage(w, r, err) return } backupFile, _, err := r.FormFile("backup_file") if err != nil { s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorBackupFile)) return } defer backupFile.Close() backupContent, err := io.ReadAll(backupFile) if err != nil || len(backupContent) == 0 { if len(backupContent) == 0 { err = errors.New("backup file size must be greater than 0") } s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorBackupFile)) return } if err := restoreBackup(backupContent, "", scanQuota, restoreMode, claims.Username, ipAddr, claims.Role); err != nil { s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorRestore)) return } s.renderMessagePage(w, r, util.I18nMaintenanceTitle, http.StatusOK, nil, util.I18nBackupOK) } func getAllAdmins(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden) return } dataGetter := func(limit, offset int) ([]byte, int, error) { results, err := dataprovider.GetAdmins(limit, offset, dataprovider.OrderASC) if err != nil { return nil, 0, err } data, err := json.Marshal(results) return data, len(results), err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := 
s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, w, r)
	renderAdminTemplate(w, templateAdmins, data)
}

// handleWebAdminSetupGet renders the initial admin setup page. As soon as
// at least one admin exists, it redirects to the WebAdmin login page instead.
func (s *httpdServer) handleWebAdminSetupGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	if dataprovider.HasAdmin() {
		http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
		return
	}
	s.renderAdminSetupPage(w, r, "", nil)
}

// handleWebAddAdminGet renders the "add admin" form, pre-filled with
// Status 1 and the PermAdminAny permission.
func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	admin := &dataprovider.Admin{
		Status:      1,
		Permissions: []string{dataprovider.PermAdminAny},
	}
	s.renderAddUpdateAdminPage(w, r, admin, nil, true)
}

// handleWebUpdateAdminGet renders the edit form for the admin named by the
// "username" URL parameter. It shows a not-found page for unknown admins and
// an internal error page for any other lookup failure.
func (s *httpdServer) handleWebUpdateAdminGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	admin, err := dataprovider.AdminExists(username)
	if err == nil {
		s.renderAddUpdateAdminPage(w, r, &admin, nil, false)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebAddAdminPost creates a new admin from the submitted form. The
// acting admin's username and role come from the JWT claims. The CSRF token
// is checked only after getAdminFromPostFields has run (presumably that call
// parses the form the check reads; verify).
func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	admin, err := getAdminFromPostFields(r)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &admin, err, true)
		return
	}
	if admin.Password == "" {
		// Administrators can be used with OpenID Connect or for authentication
		// via API key, in these cases the password is not necessary, we create
		// a non-usable one. This feature is only useful for WebAdmin, in REST
		// API you can create an unusable password externally.
		admin.Password = util.GenerateUniqueID()
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = dataprovider.AddAdmin(&admin, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &admin, err, true)
		return
	}
	http.Redirect(w, r, webAdminsPath, http.StatusSeeOther)
}

// handleWebUpdateAdminPost updates the admin named by the "username" URL
// parameter from the submitted form.
//
// These values are always taken from the stored admin: ID, username, TOTP
// config and recovery codes. The stored password is also kept when the form
// leaves it empty.
//
// When admins edit themselves, they may not change their permissions,
// disable themselves or change their role. Their RequirePasswordChange and
// RequireTwoFactor filters are also kept from the stored record.
func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	username := getURLParam(r, "username")
	admin, err := dataprovider.AdminExists(username)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}

	updatedAdmin, err := getAdminFromPostFields(r)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &updatedAdmin, err, false)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// Identity and 2FA material cannot be changed through this form.
	updatedAdmin.ID = admin.ID
	updatedAdmin.Username = admin.Username
	if updatedAdmin.Password == "" {
		updatedAdmin.Password = admin.Password
	}
	updatedAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig
	updatedAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderAddUpdateAdminPage(w, r, &updatedAdmin, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken), false)
		return
	}
	// Self-edit restrictions.
	if username == claims.Username {
		if !util.SlicesEqual(admin.Permissions, updatedAdmin.Permissions) {
			s.renderAddUpdateAdminPage(w, r, &updatedAdmin,
				util.NewI18nError(errors.New("you cannot change your permissions"),
					util.I18nErrorAdminSelfPerms,
				), false)
			return
		}
		if updatedAdmin.Status == 0 {
			s.renderAddUpdateAdminPage(w, r, &updatedAdmin,
				util.NewI18nError(errors.New("you cannot disable yourself"),
					util.I18nErrorAdminSelfDisable,
				), false)
			return
		}
		if updatedAdmin.Role != claims.Role {
			s.renderAddUpdateAdminPage(w, r, &updatedAdmin, util.NewI18nError(
				errors.New("you cannot add/change your role"),
				util.I18nErrorAdminSelfRole,
			), false)
			return
		}
		updatedAdmin.Filters.RequirePasswordChange = admin.Filters.RequirePasswordChange
		updatedAdmin.Filters.RequireTwoFactor = admin.Filters.RequireTwoFactor
	}
	err = dataprovider.UpdateAdmin(&updatedAdmin, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &updatedAdmin, err, false)
		return
	}
	http.Redirect(w, r, webAdminsPath, http.StatusSeeOther)
}

// handleWebDefenderPage renders the defender page. The page is given
// webDefenderHostsPath as the URL for the hosts list.
func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	data := defenderHostsPage{
		basePage:         s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, w, r),
		DefenderHostsURL: webDefenderHostsPath,
	}

	renderAdminTemplate(w, templateDefender, data)
}

// getAllUsers streams, as a JSON array, the users visible to the acting
// admin. The lookup is filtered by claims.Role. Missing or invalid token
// claims yield a 403 response.
func getAllUsers(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden)
		return
	}
	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetUsers(limit, offset, dataprovider.OrderASC, claims.Role)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleGetWebUsers renders the users list page. It requires valid JWT claims.
func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	data := s.getBasePageData(util.I18nUsersTitle,
webUsersPath, w, r)
	renderAdminTemplate(w, templateUsers, data)
}

// handleWebTemplateFolderGet renders the folder template page. When the
// "from" query parameter is set, the named folder is loaded, its secrets are
// cleared with SetEmptySecrets, and it becomes the starting point. Otherwise
// an empty folder is used.
func (s *httpdServer) handleWebTemplateFolderGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if r.URL.Query().Get("from") != "" {
		name := r.URL.Query().Get("from")
		folder, err := dataprovider.GetFolderByName(name)
		if err == nil {
			folder.FsConfig.SetEmptySecrets()
			s.renderFolderPage(w, r, folder, folderPageModeTemplate, nil)
		} else if errors.Is(err, util.ErrNotFound) {
			s.renderNotFoundPage(w, r, err)
		} else {
			s.renderInternalServerErrorPage(w, r, err)
		}
	} else {
		folder := vfs.BaseVirtualFolder{}
		s.renderFolderPage(w, r, folder, folderPageModeTemplate, nil)
	}
}

// handleWebTemplateFolderPost builds one folder per entry returned by
// getFoldersForTemplate, all sharing the submitted mapped path, description
// and filesystem config. Each folder is validated, and the whole set is then
// passed to RestoreFolders. Any invalid folder, or an empty set, results in
// a 400 message page.
func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	templateFolder := vfs.BaseVirtualFolder{}
	err = r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest, util.NewI18nError(err, util.I18nErrorInvalidForm), "")
		return
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}

	// NOTE(review): unlike handleWebAddFolderPost, mapped_path is not trimmed here.
	templateFolder.MappedPath = r.Form.Get("mapped_path")
	templateFolder.Description = r.Form.Get("description")
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest, err, "")
		return
	}
	templateFolder.FsConfig = fsConfig

	var dump dataprovider.BackupData

	foldersFields := getFoldersForTemplate(r)
	for _, tmpl := range foldersFields {
		f := getFolderFromTemplate(templateFolder, tmpl)
		if err := dataprovider.ValidateFolder(&f); err != nil {
			s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest, err, "")
			return
		}
		dump.Folders = append(dump.Folders, f)
	}

	if len(dump.Folders) == 0 {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest,
			util.NewI18nError(
				errors.New("no valid folder defined, unable to complete the requested action"),
				util.I18nErrorFolderTemplate,
			), "")
		return
	}
	// NOTE(review): the meaning of the "", 1, 0 arguments is defined by
	// RestoreFolders, which is not visible here; confirm before changing them.
	if err = RestoreFolders(dump.Folders, "", 1, 0, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, getRespStatus(err), err, "")
		return
	}
	http.Redirect(w, r, webFoldersPath, http.StatusSeeOther)
}

// handleWebTemplateUserGet renders the user template page.
//
// When the "from" query parameter is set, the named user (looked up within
// the acting admin's role) is copied after clearing: secrets, public keys,
// email, additional emails and description. Otherwise a new user with
// Status 1 and "/" -> PermAny is used.
//
// In both cases, if the admin's DefaultUsersExpiration preference is
// positive, it is applied as a number of days from now. For a copied user
// this happens only when the user had no expiration date.
func (s *httpdServer) handleWebTemplateUserGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	tokenAdmin := getAdminFromToken(r)
	admin, err := dataprovider.AdminExists(tokenAdmin.Username)
	if err != nil {
		s.renderInternalServerErrorPage(w, r, fmt.Errorf("unable to get the admin %q: %w", tokenAdmin.Username, err))
		return
	}
	if r.URL.Query().Get("from") != "" {
		username := r.URL.Query().Get("from")
		user, err := dataprovider.UserExists(username, admin.Role)
		if err == nil {
			user.SetEmptySecrets()
			user.PublicKeys = nil
			user.Email = ""
			user.Filters.AdditionalEmails = nil
			user.Description = ""
			if user.ExpirationDate == 0 && admin.Filters.Preferences.DefaultUsersExpiration > 0 {
				user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
			}
			s.renderUserPage(w, r, &user, userPageModeTemplate, nil, &admin)
		} else if errors.Is(err, util.ErrNotFound) {
			s.renderNotFoundPage(w, r, err)
		} else {
			s.renderInternalServerErrorPage(w, r, err)
		}
	} else {
		user := dataprovider.User{BaseUser: sdk.BaseUser{
			Status: 1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		}}
		if admin.Filters.Preferences.DefaultUsersExpiration > 0 {
			user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
		}
		s.renderUserPage(w, r, &user, userPageModeTemplate, nil, &admin)
	}
}

// handleWebTemplateUserPost builds one user per entry returned by
// getUsersForTemplate, based on the submitted template user. Each user is
// validated, and the whole set is then passed to RestoreUsers. A non-empty
// claims.Role is forced on every generated user. Any invalid user, or an
// empty set, results in a 400 message page.
func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	templateUser, err := getUserFromPostFields(r)
	if err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateUserTitle, http.StatusBadRequest, err, "")
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}

	var dump dataprovider.BackupData

	userTmplFields := getUsersForTemplate(r)
	for _, tmpl := range userTmplFields {
		u := getUserFromTemplate(templateUser, tmpl)
		if err := dataprovider.ValidateUser(&u); err != nil {
			s.renderMessagePage(w, r, util.I18nTemplateUserTitle, http.StatusBadRequest, err, "")
			return
		}
		// The role is applied after validation, so the validated user may not
		// yet carry it.
		if claims.Role != "" {
			u.Role = claims.Role
		}
		dump.Users = append(dump.Users, u)
	}

	if len(dump.Users) == 0 {
		s.renderMessagePage(w, r, util.I18nTemplateUserTitle,
			http.StatusBadRequest, util.NewI18nError(
				errors.New("no valid user defined, unable to complete the requested action"),
				util.I18nErrorUserTemplate,
			), "")
		return
	}
	if err = RestoreUsers(dump.Users, "", 1, 0, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateUserTitle, getRespStatus(err), err, "")
		return
	}
	http.Redirect(w, r, webUsersPath, http.StatusSeeOther)
}

// handleWebAddUserGet renders the "add user" form, pre-filled with Status 1,
// "/" -> PermAny and, when the admin's DefaultUsersExpiration preference is
// positive, an expiration that many days from now.
func (s *httpdServer) handleWebAddUserGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	tokenAdmin := getAdminFromToken(r)
	admin, err :=
dataprovider.AdminExists(tokenAdmin.Username)
	if err != nil {
		s.renderInternalServerErrorPage(w, r, fmt.Errorf("unable to get the admin %q: %w", tokenAdmin.Username, err))
		return
	}
	user := dataprovider.User{BaseUser: sdk.BaseUser{
		Status: 1,
		Permissions: map[string][]string{
			"/": {dataprovider.PermAny},
		}},
	}
	if admin.Filters.Preferences.DefaultUsersExpiration > 0 {
		user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
	}
	s.renderUserPage(w, r, &user, userPageModeAdd, nil, &admin)
}

// handleWebUpdateUserGet renders the edit form for the user named by the
// "username" URL parameter. The lookup is scoped to the acting admin's role.
func (s *httpdServer) handleWebUpdateUserGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	username := getURLParam(r, "username")
	user, err := dataprovider.UserExists(username, claims.Role)
	if err == nil {
		s.renderUserPage(w, r, &user, userPageModeUpdate, nil, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebAddUserPost creates a user from the submitted form.
// The user is first passed through getUserFromTemplate with its own
// username, password, public keys and password-change flag. A non-empty
// claims.Role is then forced on it, and any recovery codes and TOTP
// configuration are cleared before the user is saved.
func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	user, err := getUserFromPostFields(r)
	if err != nil {
		s.renderUserPage(w, r, &user, userPageModeAdd, err, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	user = getUserFromTemplate(user, userTemplateFields{
		Username:         user.Username,
		Password:         user.Password,
		PublicKeys:       user.PublicKeys,
		RequirePwdChange: user.Filters.RequirePasswordChange,
	})
	if claims.Role != "" {
		user.Role = claims.Role
	}
	user.Filters.RecoveryCodes = nil
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled: false,
	}
	err = dataprovider.AddUser(&user, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderUserPage(w, r, &user, userPageModeAdd, err, nil)
		return
	}
	http.Redirect(w, r, webUsersPath, http.StatusSeeOther)
}

// handleWebUpdateUserPost updates the user named by the "username" URL
// parameter. The lookup is scoped to the acting admin's role.
//
// These values always come from the stored user: ID, username, recovery
// codes, TOTP config and last password change. A password submitted as
// redactedSecret keeps the stored password, and updateEncryptedSecrets
// reconciles the filesystem secrets with the stored ones.
//
// If the "disconnect" form field is set, disconnectUser is called for the
// user after a successful update.
func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	username := getURLParam(r, "username")
	user, err := dataprovider.UserExists(username, claims.Role)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedUser, err := getUserFromPostFields(r)
	if err != nil {
		s.renderUserPage(w, r, &user, userPageModeUpdate, err, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	updatedUser.ID = user.ID
	updatedUser.Username = user.Username
	updatedUser.Filters.RecoveryCodes = user.Filters.RecoveryCodes
	updatedUser.Filters.TOTPConfig = user.Filters.TOTPConfig
	updatedUser.LastPasswordChange = user.LastPasswordChange
	updatedUser.SetEmptySecretsIfNil()
	if updatedUser.Password == redactedSecret {
		updatedUser.Password = user.Password
	}
	updateEncryptedSecrets(&updatedUser.FsConfig, &user.FsConfig)

	updatedUser = getUserFromTemplate(updatedUser, userTemplateFields{
		Username:   updatedUser.Username,
		Password:   updatedUser.Password,
		PublicKeys: updatedUser.PublicKeys,
		RequirePwdChange:
updatedUser.Filters.RequirePasswordChange,
	})
	if claims.Role != "" {
		updatedUser.Role = claims.Role
	}

	err = dataprovider.UpdateUser(&updatedUser, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderUserPage(w, r, &updatedUser, userPageModeUpdate, err, nil)
		return
	}
	if r.Form.Get("disconnect") != "" {
		disconnectUser(user.Username, claims.Username, claims.Role)
	}
	http.Redirect(w, r, webUsersPath, http.StatusSeeOther)
}

// handleWebGetStatus renders the services status page using getServicesStatus.
func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	data := statusPage{
		basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, w, r),
		Status:   getServicesStatus(),
	}
	renderAdminTemplate(w, templateStatus, data)
}

// handleWebGetConnections renders the active sessions page. It requires
// valid JWT claims.
func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, w, r)
	renderAdminTemplate(w, templateConnections, data)
}

// handleWebAddFolderGet renders an empty "add folder" form.
func (s *httpdServer) handleWebAddFolderGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderFolderPage(w, r, vfs.BaseVirtualFolder{}, folderPageModeAdd, nil)
}

// handleWebAddFolderPost creates a virtual folder from the submitted
// multipart form. The mapped path and name are trimmed of surrounding
// whitespace. The folder is passed through getFolderFromTemplate with its
// own name before being saved.
func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	folder := vfs.BaseVirtualFolder{}
	err = r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeAdd, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	folder.MappedPath = strings.TrimSpace(r.Form.Get("mapped_path"))
	folder.Name = strings.TrimSpace(r.Form.Get("name"))
	folder.Description = r.Form.Get("description")
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeAdd, err)
		return
	}
	folder.FsConfig = fsConfig
	folder = getFolderFromTemplate(folder, folder.Name)

	err = dataprovider.AddFolder(&folder, claims.Username, ipAddr, claims.Role)
	if err == nil {
		http.Redirect(w, r, webFoldersPath, http.StatusSeeOther)
	} else {
		s.renderFolderPage(w, r, folder, folderPageModeAdd, err)
	}
}

// handleWebUpdateFolderGet renders the edit form for the folder named by the
// "name" URL parameter.
func (s *httpdServer) handleWebUpdateFolderGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	name := getURLParam(r, "name")
	folder, err := dataprovider.GetFolderByName(name)
	if err == nil {
		s.renderFolderPage(w, r, folder, folderPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateFolderPost updates the folder named by the "name" URL
// parameter. ID and name are kept from the stored folder, and
// updateEncryptedSecrets reconciles the filesystem secrets with the stored
// ones. The folder's existing users and groups are passed to
// dataprovider.UpdateFolder.
func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	folder, err := dataprovider.GetFolderByName(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}

	err = r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeUpdate, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeUpdate, err)
		return
	}
	updatedFolder := vfs.BaseVirtualFolder{
		MappedPath:  strings.TrimSpace(r.Form.Get("mapped_path")),
		Description: r.Form.Get("description"),
	}
	updatedFolder.ID = folder.ID
	updatedFolder.Name = folder.Name
	updatedFolder.FsConfig = fsConfig
	updatedFolder.FsConfig.SetEmptySecretsIfNil()
	updateEncryptedSecrets(&updatedFolder.FsConfig, &folder.FsConfig)

	updatedFolder = getFolderFromTemplate(updatedFolder, updatedFolder.Name)

	err = dataprovider.UpdateFolder(&updatedFolder, folder.Users, folder.Groups, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderFolderPage(w, r, updatedFolder, folderPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webFoldersPath, http.StatusSeeOther)
}

// getWebVirtualFolders loads all folders in ascending order, limit at a
// time, using the number already collected as the offset. On error it
// renders the internal server error page itself and returns the folders
// collected so far together with the error.
// NOTE(review): the loop stops only on a page shorter than limit, so limit
// must be > 0.
func (s *httpdServer) getWebVirtualFolders(w http.ResponseWriter, r *http.Request, limit int, minimal bool) ([]vfs.BaseVirtualFolder, error) {
	folders := make([]vfs.BaseVirtualFolder, 0, 50)
	for {
		f, err := dataprovider.GetFolders(limit, len(folders), dataprovider.OrderASC, minimal)
		if err != nil {
			s.renderInternalServerErrorPage(w, r, err)
			return folders, err
		}
		folders = append(folders, f...)
if len(f) < limit { break } } return folders, nil } func getAllFolders(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) dataGetter := func(limit, offset int) ([]byte, int, error) { results, err := dataprovider.GetFolders(limit, offset, dataprovider.OrderASC, false) if err != nil { return nil, 0, err } data, err := json.Marshal(results) return data, len(results), err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, w, r) renderAdminTemplate(w, templateFolders, data) } func (s *httpdServer) getWebGroups(w http.ResponseWriter, r *http.Request, limit int, minimal bool) ([]dataprovider.Group, error) { groups := make([]dataprovider.Group, 0, 50) for { f, err := dataprovider.GetGroups(limit, len(groups), dataprovider.OrderASC, minimal) if err != nil { s.renderInternalServerErrorPage(w, r, err) return groups, err } groups = append(groups, f...) 
if len(f) < limit { break } } return groups, nil } func getAllGroups(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) dataGetter := func(limit, offset int) ([]byte, int, error) { results, err := dataprovider.GetGroups(limit, offset, dataprovider.OrderASC, false) if err != nil { return nil, 0, err } data, err := json.Marshal(results) return data, len(results), err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, w, r) renderAdminTemplate(w, templateGroups, data) } func (s *httpdServer) handleWebAddGroupGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderGroupPage(w, r, dataprovider.Group{}, genericPageModeAdd, nil) } func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } group, err := getGroupFromPostFields(r) if err != nil { s.renderGroupPage(w, r, group, genericPageModeAdd, err) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } err = dataprovider.AddGroup(&group, claims.Username, ipAddr, claims.Role) if err != nil { s.renderGroupPage(w, r, group, genericPageModeAdd, err) return } http.Redirect(w, r, webGroupsPath, http.StatusSeeOther) } func (s *httpdServer) handleWebUpdateGroupGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) name := getURLParam(r, "name") group, err := 
dataprovider.GroupExists(name)
	if err == nil {
		s.renderGroupPage(w, r, group, genericPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateGroupPost updates the group named by the "name" URL
// parameter. ID and name are kept from the stored group, and
// updateEncryptedSecrets reconciles the filesystem secrets in the user
// settings with the stored ones. The group's existing users are passed to
// dataprovider.UpdateGroup.
func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	group, err := dataprovider.GroupExists(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedGroup, err := getGroupFromPostFields(r)
	if err != nil {
		// The stored group, not the submitted one, is re-rendered on form errors.
		s.renderGroupPage(w, r, group, genericPageModeUpdate, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	updatedGroup.ID = group.ID
	updatedGroup.Name = group.Name
	updatedGroup.SetEmptySecretsIfNil()

	updateEncryptedSecrets(&updatedGroup.UserSettings.FsConfig, &group.UserSettings.FsConfig)

	err = dataprovider.UpdateGroup(&updatedGroup, group.Users, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderGroupPage(w, r, updatedGroup, genericPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webGroupsPath, http.StatusSeeOther)
}

// getWebEventActions loads all event actions in ascending order, limit at a
// time, using the number already collected as the offset. On error it
// renders the internal server error page and returns the partial result.
// NOTE(review): make panics on a negative limit, and a zero limit never
// yields a short page.
func (s *httpdServer) getWebEventActions(w http.ResponseWriter, r *http.Request, limit int, minimal bool,
) ([]dataprovider.BaseEventAction, error) {
	actions := make([]dataprovider.BaseEventAction, 0, limit)
	for {
		res, err := dataprovider.GetEventActions(limit, len(actions), dataprovider.OrderASC, minimal)
		if err != nil {
			s.renderInternalServerErrorPage(w, r, err)
			return actions, err
		}
		actions = append(actions, res...)
		if len(res) < limit {
			break
		}
	}
	return actions, nil
}

// getAllActions streams every event action, in ascending order, as a JSON array.
func getAllActions(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetEventActions(limit, offset, dataprovider.OrderASC, false)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleWebGetEventActions renders the event actions list page.
func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, w, r)
	renderAdminTemplate(w, templateEventActions, data)
}

// handleWebAddEventActionGet renders the "add event action" form, with the
// type preset to ActionTypeHTTP.
func (s *httpdServer) handleWebAddEventActionGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	action := dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeHTTP,
	}
	s.renderEventActionPage(w, r, action, genericPageModeAdd, nil)
}

// handleWebAddEventActionPost creates an event action from the submitted form.
func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	action, err := getEventActionFromPostFields(r)
	if err != nil {
		s.renderEventActionPage(w, r, action, genericPageModeAdd, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	if err = dataprovider.AddEventAction(&action, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderEventActionPage(w, r, action, genericPageModeAdd, err)
		return
	}
	http.Redirect(w, r, webAdminEventActionsPath, http.StatusSeeOther)
}

// handleWebUpdateEventActionGet renders the edit form for the event action
// named by the "name" URL parameter.
func (s *httpdServer) handleWebUpdateEventActionGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	name := getURLParam(r, "name")
	action, err := dataprovider.EventActionExists(name)
	if err == nil {
		s.renderEventActionPage(w, r, action, genericPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateEventActionPost updates the event action named by the
// "name" URL parameter. ID and name are kept from the stored action. For
// HTTP actions, a submitted password that is neither plain nor empty
// (IsNotPlainAndNotEmpty) is replaced by the stored password.
func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	action, err := dataprovider.EventActionExists(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedAction, err := getEventActionFromPostFields(r)
	if err != nil {
		s.renderEventActionPage(w, r, updatedAction, genericPageModeUpdate, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	updatedAction.ID = action.ID
	updatedAction.Name = action.Name
	updatedAction.Options.SetEmptySecretsIfNil()
	switch updatedAction.Type {
	case dataprovider.ActionTypeHTTP:
		if updatedAction.Options.HTTPConfig.Password.IsNotPlainAndNotEmpty() {
			updatedAction.Options.HTTPConfig.Password = action.Options.HTTPConfig.Password
		}
	}
	err = dataprovider.UpdateEventAction(&updatedAction, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderEventActionPage(w, r, updatedAction, genericPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webAdminEventActionsPath, http.StatusSeeOther)
}

// getAllRules streams every event rule, in ascending order, as a JSON array.
func getAllRules(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) dataGetter := func(limit, offset int) ([]byte, int, error) { results, err := dataprovider.GetEventRules(limit, offset, dataprovider.OrderASC) if err != nil { return nil, 0, err } data, err := json.Marshal(results) return data, len(results), err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, w, r) renderAdminTemplate(w, templateEventRules, data) } func (s *httpdServer) handleWebAddEventRuleGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) rule := dataprovider.EventRule{ Status: 1, Trigger: dataprovider.EventTriggerFsEvent, } s.renderEventRulePage(w, r, rule, genericPageModeAdd, nil) } func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } rule, err := getEventRuleFromPostFields(r) if err != nil { s.renderEventRulePage(w, r, rule, genericPageModeAdd, err) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err = verifyCSRFToken(r, s.csrfTokenAuth) if err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } if err = dataprovider.AddEventRule(&rule, claims.Username, ipAddr, claims.Role); err != nil { s.renderEventRulePage(w, r, rule, genericPageModeAdd, err) return } http.Redirect(w, r, webAdminEventRulesPath, http.StatusSeeOther) } func (s *httpdServer) handleWebUpdateEventRuleGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) name := getURLParam(r, "name") rule, 
err := dataprovider.EventRuleExists(name) if err == nil { s.renderEventRulePage(w, r, rule, genericPageModeUpdate, nil) } else if errors.Is(err, util.ErrNotFound) { s.renderNotFoundPage(w, r, err) } else { s.renderInternalServerErrorPage(w, r, err) } } func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } name := getURLParam(r, "name") rule, err := dataprovider.EventRuleExists(name) if errors.Is(err, util.ErrNotFound) { s.renderNotFoundPage(w, r, err) return } else if err != nil { s.renderInternalServerErrorPage(w, r, err) return } updatedRule, err := getEventRuleFromPostFields(r) if err != nil { s.renderEventRulePage(w, r, updatedRule, genericPageModeUpdate, err) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } updatedRule.ID = rule.ID updatedRule.Name = rule.Name err = dataprovider.UpdateEventRule(&updatedRule, claims.Username, ipAddr, claims.Role) if err != nil { s.renderEventRulePage(w, r, updatedRule, genericPageModeUpdate, err) return } http.Redirect(w, r, webAdminEventRulesPath, http.StatusSeeOther) } func (s *httpdServer) getWebRoles(w http.ResponseWriter, r *http.Request, limit int, minimal bool) ([]dataprovider.Role, error) { roles := make([]dataprovider.Role, 0, 10) for { res, err := dataprovider.GetRoles(limit, len(roles), dataprovider.OrderASC, minimal) if err != nil { s.renderInternalServerErrorPage(w, r, err) return roles, err } roles = append(roles, res...) 
if len(res) < limit { break } } return roles, nil } func getAllRoles(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) dataGetter := func(limit, offset int) ([]byte, int, error) { results, err := dataprovider.GetRoles(limit, offset, dataprovider.OrderASC, false) if err != nil { return nil, 0, err } data, err := json.Marshal(results) return data, len(results), err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, w, r) renderAdminTemplate(w, templateRoles, data) } func (s *httpdServer) handleWebAddRoleGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderRolePage(w, r, dataprovider.Role{}, genericPageModeAdd, nil) } func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) role, err := getRoleFromPostFields(r) if err != nil { s.renderRolePage(w, r, role, genericPageModeAdd, err) return } claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } err = dataprovider.AddRole(&role, claims.Username, ipAddr, claims.Role) if err != nil { s.renderRolePage(w, r, role, genericPageModeAdd, err) return } http.Redirect(w, r, webAdminRolesPath, http.StatusSeeOther) } func (s *httpdServer) handleWebUpdateRoleGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) role, err := dataprovider.RoleExists(getURLParam(r, 
"name")) if err == nil { s.renderRolePage(w, r, role, genericPageModeUpdate, nil) } else if errors.Is(err, util.ErrNotFound) { s.renderNotFoundPage(w, r, err) } else { s.renderInternalServerErrorPage(w, r, err) } } func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } role, err := dataprovider.RoleExists(getURLParam(r, "name")) if errors.Is(err, util.ErrNotFound) { s.renderNotFoundPage(w, r, err) return } else if err != nil { s.renderInternalServerErrorPage(w, r, err) return } updatedRole, err := getRoleFromPostFields(r) if err != nil { s.renderRolePage(w, r, role, genericPageModeUpdate, err) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } updatedRole.ID = role.ID updatedRole.Name = role.Name err = dataprovider.UpdateRole(&updatedRole, claims.Username, ipAddr, claims.Role) if err != nil { s.renderRolePage(w, r, updatedRole, genericPageModeUpdate, err) return } http.Redirect(w, r, webAdminRolesPath, http.StatusSeeOther) } func (s *httpdServer) handleWebGetEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := eventsPage{ basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, w, r), FsEventsSearchURL: webEventsFsSearchPath, ProviderEventsSearchURL: webEventsProviderSearchPath, LogEventsSearchURL: webEventsLogSearchPath, } renderAdminTemplate(w, templateEvents, data) } func (s *httpdServer) handleWebIPListsPage(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) rtlStatus, rtlProtocols := 
common.Config.GetRateLimitersStatus() data := ipListsPage{ basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, w, r), RateLimitersStatus: rtlStatus, RateLimitersProtocols: strings.Join(rtlProtocols, ", "), IsAllowListEnabled: common.Config.IsAllowListEnabled(), } renderAdminTemplate(w, templateIPLists, data) } func (s *httpdServer) handleWebAddIPListEntryGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) listType, _, err := getIPListPathParams(r) if err != nil { s.renderBadRequestPage(w, r, err) return } s.renderIPListPage(w, r, dataprovider.IPListEntry{Type: listType}, genericPageModeAdd, nil) } func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) listType, _, err := getIPListPathParams(r) if err != nil { s.renderBadRequestPage(w, r, err) return } entry, err := getIPListEntryFromPostFields(r, listType) if err != nil { s.renderIPListPage(w, r, entry, genericPageModeAdd, err) return } entry.Type = listType claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } err = dataprovider.AddIPListEntry(&entry, claims.Username, ipAddr, claims.Role) if err != nil { s.renderIPListPage(w, r, entry, genericPageModeAdd, err) return } http.Redirect(w, r, webIPListsPath, http.StatusSeeOther) } func (s *httpdServer) handleWebUpdateIPListEntryGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) listType, ipOrNet, err := getIPListPathParams(r) if err != nil { s.renderBadRequestPage(w, r, err) return } entry, err := 
dataprovider.IPListEntryExists(ipOrNet, listType) if err == nil { s.renderIPListPage(w, r, entry, genericPageModeUpdate, nil) } else if errors.Is(err, util.ErrNotFound) { s.renderNotFoundPage(w, r, err) } else { s.renderInternalServerErrorPage(w, r, err) } } func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } listType, ipOrNet, err := getIPListPathParams(r) if err != nil { s.renderBadRequestPage(w, r, err) return } entry, err := dataprovider.IPListEntryExists(ipOrNet, listType) if errors.Is(err, util.ErrNotFound) { s.renderNotFoundPage(w, r, err) return } else if err != nil { s.renderInternalServerErrorPage(w, r, err) return } updatedEntry, err := getIPListEntryFromPostFields(r, listType) if err != nil { s.renderIPListPage(w, r, entry, genericPageModeUpdate, err) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } updatedEntry.Type = listType updatedEntry.IPOrNet = ipOrNet err = dataprovider.UpdateIPListEntry(&updatedEntry, claims.Username, ipAddr, claims.Role) if err != nil { s.renderIPListPage(w, r, entry, genericPageModeUpdate, err) return } http.Redirect(w, r, webIPListsPath, http.StatusSeeOther) } func (s *httpdServer) handleWebConfigs(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) configs, err := dataprovider.GetConfigs() if err != nil { s.renderInternalServerErrorPage(w, r, err) return } s.renderConfigsPage(w, r, configs, nil, 0) } func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, 
maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } configs, err := dataprovider.GetConfigs() if err != nil { s.renderInternalServerErrorPage(w, r, err) return } err = r.ParseMultipartForm(maxRequestSize) if err != nil { s.renderBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } defer r.MultipartForm.RemoveAll() //nolint:errcheck ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } var configSection int switch r.Form.Get("form_action") { case "sftp_submit": configSection = 1 sftpConfigs := getSFTPConfigsFromPostFields(r) configs.SFTPD = sftpConfigs case "acme_submit": configSection = 2 acmeConfigs := getACMEConfigsFromPostFields(r) configs.ACME = acmeConfigs if err := acme.GetCertificatesForConfig(acmeConfigs, configurationDir); err != nil { logger.Info(logSender, "", "unable to get ACME certificates: %v", err) s.renderConfigsPage(w, r, configs, util.NewI18nError(err, util.I18nErrorACMEGeneric), configSection) return } case "smtp_submit": configSection = 3 smtpConfigs := getSMTPConfigsFromPostFields(r) updateSMTPSecrets(smtpConfigs, configs.SMTP) configs.SMTP = smtpConfigs case "branding_submit": configSection = 4 brandingConfigs, err := getBrandingConfigFromPostFields(r, configs.Branding) configs.Branding = brandingConfigs if err != nil { logger.Info(logSender, "", "unable to get branding config: %v", err) s.renderConfigsPage(w, r, configs, err, configSection) return } default: s.renderBadRequestPage(w, r, errors.New("unsupported form action")) return } err = dataprovider.UpdateConfigs(&configs, claims.Username, ipAddr, claims.Role) if err != nil { s.renderConfigsPage(w, r, configs, err, configSection) return } 
postConfigsUpdate(configSection, configs) s.renderMessagePage(w, r, util.I18nConfigsTitle, http.StatusOK, nil, util.I18nConfigsOK) } func postConfigsUpdate(section int, configs dataprovider.Configs) { switch section { case 3: err := configs.SMTP.TryDecrypt() if err == nil { smtp.Activate(configs.SMTP) } else { logger.Error(logSender, "", "unable to decrypt SMTP configuration, cannot activate configuration: %v", err) } case 4: dbBrandingConfig.Set(configs.Branding) } } func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) stateToken := r.URL.Query().Get("state") state, err := verifyOAuth2Token(s.csrfTokenAuth, stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr)) if err != nil { s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "") return } pendingAuth, err := oauth2Mgr.getPendingAuth(state) if err != nil { oauth2Mgr.removePendingAuth(state) s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError, util.NewI18nError(err, util.I18nOAuth2ErrorValidateState), "") return } oauth2Mgr.removePendingAuth(state) oauth2Config := smtp.OAuth2Config{ Provider: pendingAuth.Provider, ClientID: pendingAuth.ClientID, ClientSecret: pendingAuth.ClientSecret.GetPayload(), } ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() cfg := oauth2Config.GetOAuth2() cfg.RedirectURL = pendingAuth.RedirectURL token, err := cfg.Exchange(ctx, r.URL.Query().Get("code"), oauth2.VerifierOption(pendingAuth.Verifier)) if err != nil { s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError, util.NewI18nError(err, util.I18nOAuth2ErrTokenExchange), "") return } if token.RefreshToken == "" { errTxt := "the OAuth2 provider returned an empty token. " + "Some providers only return the token when the user first authorizes. 
" + "If you have already registered SFTPGo with this user in the past, revoke access and try again. " + "This way you will invalidate the previous token" s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, util.NewI18nError(errors.New(errTxt), util.I18nOAuth2ErrNoRefreshToken), "") return } s.renderMessagePageWithString(w, r, util.I18nOAuth2Title, http.StatusOK, nil, util.I18nOAuth2OK, fmt.Sprintf("%q", token.RefreshToken)) } func updateSMTPSecrets(newConfigs, currentConfigs *dataprovider.SMTPConfigs) { if currentConfigs == nil { currentConfigs = &dataprovider.SMTPConfigs{} } if newConfigs.Password.IsNotPlainAndNotEmpty() { newConfigs.Password = currentConfigs.Password } if newConfigs.OAuth2.ClientSecret.IsNotPlainAndNotEmpty() { newConfigs.OAuth2.ClientSecret = currentConfigs.OAuth2.ClientSecret } if newConfigs.OAuth2.RefreshToken.IsNotPlainAndNotEmpty() { newConfigs.OAuth2.RefreshToken = currentConfigs.OAuth2.RefreshToken } } ================================================ FILE: internal/httpd/webclient.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd import ( "bytes" "crypto/rand" "encoding/json" "errors" "fmt" "html/template" "io" "math" "net/http" "net/url" "os" "path" "path/filepath" "slices" "strconv" "strings" "time" "github.com/go-chi/render" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( templateClientDir = "webclient" templateClientBase = "base.html" templateClientFiles = "files.html" templateClientProfile = "profile.html" templateClientMFA = "mfa.html" templateClientEditFile = "editfile.html" templateClientShare = "share.html" templateClientShares = "shares.html" templateClientViewPDF = "viewpdf.html" templateShareLogin = "sharelogin.html" templateShareDownload = "sharedownload.html" templateUploadToShare = "shareupload.html" ) // condResult is the result of an HTTP request precondition check. // See https://tools.ietf.org/html/rfc7232 section 3. type condResult int const ( condNone condResult = iota condTrue condFalse ) var ( clientTemplates = make(map[string]*template.Template) unixEpochTime = time.Unix(0, 0) ) // isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). 
func isZeroTime(t time.Time) bool {
	// Both the Go zero time and the Unix epoch count as "not set".
	return t.IsZero() || t.Equal(unixEpochTime)
}

// baseClientPage holds the fields common to web client pages. It is
// embedded by the page-specific structs below.
type baseClientPage struct {
	commonBasePage
	Title           string
	CurrentURL      string
	FilesURL        string
	SharesURL       string
	ShareURL        string
	ProfileURL      string
	PingURL         string
	ChangePwdURL    string
	LogoutURL       string
	LoginURL        string
	EditURL         string
	MFAURL          string
	CSRFToken       string
	LoggedUser      *dataprovider.User
	IsLoggedToShare bool
	Branding        UIBranding
	Languages       []string
}

// dirMapping pairs a directory name with its link.
type dirMapping struct {
	DirName string
	Href    string
}

// viewPDFPage is the data for the PDF viewer page. It embeds
// commonBasePage directly, not baseClientPage.
type viewPDFPage struct {
	commonBasePage
	Title     string
	URL       string
	Branding  UIBranding
	Languages []string
}

// editFilePage is the data for the file editor page.
type editFilePage struct {
	baseClientPage
	CurrentDir string
	FileURL    string
	Path       string
	Name       string
	ReadOnly   bool
	Data       string
}

// filesPage is the data for the file browser page: endpoint URLs,
// per-operation permission flags, breadcrumb paths and quota usage.
type filesPage struct {
	baseClientPage
	CurrentDir         string
	DirsURL            string
	FileActionsURL     string
	CheckExistURL      string
	DownloadURL        string
	ViewPDFURL         string
	FileURL            string
	TasksURL           string
	CanAddFiles        bool
	CanCreateDirs      bool
	CanRename          bool
	CanDelete          bool
	CanDownload        bool
	CanShare           bool
	CanCopy            bool
	ShareUploadBaseURL string
	Error              *util.I18nError
	Paths              []dirMapping
	QuotaUsage         *userQuotaUsage
	KeepAliveInterval  int
}

// shareLoginPage is the data for the share login page. It embeds
// commonBasePage directly, not baseClientPage.
type shareLoginPage struct {
	commonBasePage
	CurrentURL    string
	Error         *util.I18nError
	CSRFToken     string
	Title         string
	Branding      UIBranding
	Languages     []string
	CheckRedirect bool
}

// shareDownloadPage is the data for the share download page.
type shareDownloadPage struct {
	baseClientPage
	DownloadLink string
}

// shareUploadPage is the data for the upload-to-share page.
type shareUploadPage struct {
	baseClientPage
	Share          *dataprovider.Share
	UploadBasePath string
}

// clientMessagePage is the data for a generic message page, carrying an
// optional error, a success message and extra text.
type clientMessagePage struct {
	baseClientPage
	Error   *util.I18nError
	Success string
	Text    string
}

// clientProfilePage is the data for the user profile page.
type clientProfilePage struct {
	baseClientPage
	PublicKeys             []string
	TLSCerts               []string
	CanSubmit              bool
	AllowAPIKeyAuth        bool
	Email                  string
	AdditionalEmails       []string
	AdditionalEmailsString string
	Description            string
	Error                  *util.I18nError
}

// changeClientPasswordPage is the data for the change password page.
type changeClientPasswordPage struct {
	baseClientPage
	Error *util.I18nError
}

// clientMFAPage is the data for the multi-factor authentication page.
type clientMFAPage struct {
	baseClientPage
	TOTPConfigs       []string
	TOTPConfig        dataprovider.UserTOTPConfig
	GenerateTOTPURL   string
	ValidateTOTPURL   string
	SaveTOTPURL       string
	RecCodesURL       string
	Protocols         []string
	RequiredProtocols []string
}

// clientSharesPage is the data for the shares list page.
type clientSharesPage struct {
	baseClientPage
	BasePublicSharesURL string
	BaseURL             string
}

// clientSharePage is the data for the add/edit share page.
type clientSharePage struct {
	baseClientPage
	Share *dataprovider.Share
	Error *util.I18nError
	IsAdd bool
}

// userQuotaUsage is a snapshot of a user's disk and transfer quota.
// The *DataTransfer limits are multiplied by 1048576 before they are
// formatted or compared, while the Used*DataTransfer counters are used
// as-is. The limits therefore appear to be expressed in MiB and the usage
// in bytes — confirm against dataprovider.User.
type userQuotaUsage struct {
	QuotaSize                int64
	QuotaFiles               int
	UsedQuotaSize            int64
	UsedQuotaFiles           int
	UploadDataTransfer       int64
	DownloadDataTransfer     int64
	TotalDataTransfer        int64
	UsedUploadDataTransfer   int64
	UsedDownloadDataTransfer int64
}

// HasQuotaInfo reports whether there is any quota information to show.
// It is always false when dataprovider.GetQuotaTracking() returns 0.
func (u *userQuotaUsage) HasQuotaInfo() bool {
	if dataprovider.GetQuotaTracking() == 0 {
		return false
	}
	if u.HasDiskQuota() {
		return true
	}
	return u.HasTranferQuota()
}

// HasDiskQuota reports whether a size/files limit is set or any size/files
// usage is recorded.
func (u *userQuotaUsage) HasDiskQuota() bool {
	if u.QuotaSize > 0 || u.UsedQuotaSize > 0 {
		return true
	}
	return u.QuotaFiles > 0 || u.UsedQuotaFiles > 0
}

// HasTranferQuota reports whether any transfer limit is set or any transfer
// usage is recorded. The misspelled name is kept because it is an exported
// method, presumably referenced from templates.
func (u *userQuotaUsage) HasTranferQuota() bool {
	if u.TotalDataTransfer > 0 || u.UploadDataTransfer > 0 || u.DownloadDataTransfer > 0 {
		return true
	}
	return u.UsedDownloadDataTransfer > 0 || u.UsedUploadDataTransfer > 0
}

// GetQuotaSize formats disk size usage:
//   - "used/limit" when a limit is set;
//   - only "used" when there is usage but no limit;
//   - "" otherwise.
func (u *userQuotaUsage) GetQuotaSize() string {
	if u.QuotaSize > 0 {
		return fmt.Sprintf("%s/%s", util.ByteCountIEC(u.UsedQuotaSize), util.ByteCountIEC(u.QuotaSize))
	}
	if u.UsedQuotaSize > 0 {
		return util.ByteCountIEC(u.UsedQuotaSize)
	}
	return ""
}

// GetQuotaFiles formats file-count usage the same way as GetQuotaSize.
func (u *userQuotaUsage) GetQuotaFiles() string {
	if u.QuotaFiles > 0 {
		return fmt.Sprintf("%d/%d", u.UsedQuotaFiles, u.QuotaFiles)
	}
	if u.UsedQuotaFiles > 0 {
		return strconv.FormatInt(int64(u.UsedQuotaFiles), 10)
	}
	return ""
}

// GetQuotaSizePercentage returns used/limit size as a rounded percentage.
// It returns 0 when no limit is set. It is not clamped, so it can exceed
// 100.
func (u *userQuotaUsage) GetQuotaSizePercentage() int {
	if u.QuotaSize > 0 {
		return int(math.Round(100 * float64(u.UsedQuotaSize) / float64(u.QuotaSize)))
	}
	return 0
}

// GetQuotaFilesPercentage returns used/limit files as a rounded percentage,
// or 0 when no limit is set.
func (u *userQuotaUsage) GetQuotaFilesPercentage() int {
	if u.QuotaFiles > 0 {
		return int(math.Round(100 * float64(u.UsedQuotaFiles) / float64(u.QuotaFiles)))
	}
	return 0
}

// IsQuotaSizeLow reports whether more than 85% of the size quota is used.
func (u *userQuotaUsage) IsQuotaSizeLow() bool {
	return u.GetQuotaSizePercentage() > 85
}

// IsQuotaFilesLow reports whether more than 85% of the files quota is used.
func (u *userQuotaUsage) IsQuotaFilesLow() bool {
	return u.GetQuotaFilesPercentage() > 85
}

// IsDiskQuotaLow reports whether either disk quota is above 85%.
func (u *userQuotaUsage) IsDiskQuotaLow() bool {
	return u.IsQuotaSizeLow() || u.IsQuotaFilesLow()
}

// GetTotalTransferQuota formats combined upload+download usage against the
// total transfer limit (limit scaled by 1048576). With no limit it shows
// only the usage; with neither it returns "".
func (u *userQuotaUsage) GetTotalTransferQuota() string {
	total := u.UsedUploadDataTransfer + u.UsedDownloadDataTransfer
	if u.TotalDataTransfer > 0 {
		return fmt.Sprintf("%s/%s", util.ByteCountIEC(total), util.ByteCountIEC(u.TotalDataTransfer*1048576))
	}
	if total > 0 {
		return util.ByteCountIEC(total)
	}
	return ""
}

// GetUploadTransferQuota formats upload usage against the upload limit
// (limit scaled by 1048576).
func (u *userQuotaUsage) GetUploadTransferQuota() string {
	if u.UploadDataTransfer > 0 {
		return fmt.Sprintf("%s/%s", util.ByteCountIEC(u.UsedUploadDataTransfer), util.ByteCountIEC(u.UploadDataTransfer*1048576))
	}
	if u.UsedUploadDataTransfer > 0 {
		return util.ByteCountIEC(u.UsedUploadDataTransfer)
	}
	return ""
}

// GetDownloadTransferQuota formats download usage against the download
// limit (limit scaled by 1048576).
func (u *userQuotaUsage) GetDownloadTransferQuota() string {
	if u.DownloadDataTransfer > 0 {
		return fmt.Sprintf("%s/%s", util.ByteCountIEC(u.UsedDownloadDataTransfer), util.ByteCountIEC(u.DownloadDataTransfer*1048576))
	}
	if u.UsedDownloadDataTransfer > 0 {
		return util.ByteCountIEC(u.UsedDownloadDataTransfer)
	}
	return ""
}

// GetTotalTransferQuotaPercentage returns combined transfer usage as a
// rounded percentage of the scaled total limit, or 0 when no limit is set.
func (u *userQuotaUsage) GetTotalTransferQuotaPercentage() int {
	if u.TotalDataTransfer > 0 {
		return int(math.Round(100 * float64(u.UsedDownloadDataTransfer+u.UsedUploadDataTransfer) / float64(u.TotalDataTransfer*1048576)))
	}
	return 0
}

// GetUploadTransferQuotaPercentage returns upload usage as a rounded
// percentage of the scaled upload limit, or 0 when no limit is set.
func (u *userQuotaUsage) GetUploadTransferQuotaPercentage() int {
	if u.UploadDataTransfer > 0 {
		return int(math.Round(100 * float64(u.UsedUploadDataTransfer) / float64(u.UploadDataTransfer*1048576)))
	}
	return 0
}

// GetDownloadTransferQuotaPercentage returns download usage as a rounded
// percentage of the scaled download limit, or 0 when no limit is set.
func (u *userQuotaUsage) GetDownloadTransferQuotaPercentage() int {
	if u.DownloadDataTransfer > 0 {
		return int(math.Round(100 * float64(u.UsedDownloadDataTransfer) / float64(u.DownloadDataTransfer*1048576)))
	}
	return 0
}

// IsTotalTransferQuotaLow reports whether a total transfer limit is set and
// more than 85% of it is used.
func (u *userQuotaUsage) IsTotalTransferQuotaLow() bool {
	if u.TotalDataTransfer > 0 {
		return u.GetTotalTransferQuotaPercentage() > 85
	}
	return false
}

// IsUploadTransferQuotaLow reports whether an upload limit is set and more
// than 85% of it is used.
func (u *userQuotaUsage) IsUploadTransferQuotaLow() bool {
	if u.UploadDataTransfer > 0 {
		return u.GetUploadTransferQuotaPercentage() > 85
	}
	return false
}

// IsDownloadTransferQuotaLow reports whether a download limit is set and
// more than 85% of it is used.
func (u *userQuotaUsage) IsDownloadTransferQuotaLow() bool {
	if u.DownloadDataTransfer > 0 {
		return u.GetDownloadTransferQuotaPercentage() > 85
	}
	return false
}

// IsTransferQuotaLow reports whether any transfer quota is above 85%.
func (u *userQuotaUsage) IsTransferQuotaLow() bool {
	return u.IsTotalTransferQuotaLow() || u.IsUploadTransferQuotaLow() || u.IsDownloadTransferQuotaLow()
}

// IsQuotaLow reports whether any disk or transfer quota is above 85%.
func (u *userQuotaUsage) IsQuotaLow() bool {
	return u.IsDiskQuotaLow() || u.IsTransferQuotaLow()
}

// newUserQuotaUsage copies the quota limits and usage counters from u.
func newUserQuotaUsage(u *dataprovider.User) *userQuotaUsage {
	return &userQuotaUsage{
		QuotaSize:                u.QuotaSize,
		QuotaFiles:               u.QuotaFiles,
		UsedQuotaSize:            u.UsedQuotaSize,
		UsedQuotaFiles:           u.UsedQuotaFiles,
		TotalDataTransfer:        u.TotalDataTransfer,
		UploadDataTransfer:       u.UploadDataTransfer,
		DownloadDataTransfer:     u.DownloadDataTransfer,
		UsedUploadDataTransfer:   u.UsedUploadDataTransfer,
		UsedDownloadDataTransfer: u.UsedDownloadDataTransfer,
	}
}

// getFileObjectURL returns
// "<baseWebPath>?path=<query-escaped baseDir/name>&_=<current Unix seconds>".
// The "_" timestamp presumably defeats caching — confirm.
func getFileObjectURL(baseDir, name, baseWebPath string) string {
	return fmt.Sprintf("%v?path=%v&_=%v", baseWebPath, url.QueryEscape(path.Join(baseDir, name)), time.Now().UTC().Unix())
}

// getFileObjectModTime returns t in milliseconds since the Unix epoch, or 0
// when isZeroTime(t) is true.
func getFileObjectModTime(t time.Time) int64 {
	if isZeroTime(t) {
		return 0
	}
	return t.UnixMilli()
}

// loadClientTemplates builds, for each web client page, the ordered list of
// template files below templatesPath and loads them with util.LoadTemplate.
// Each list starts with the common base template, followed by either the
// client base or the common login base, then the page template. The
// exception is viewPDFPaths, which has only the common base and the PDF
// viewer template.
func loadClientTemplates(templatesPath string) {
	filesPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateClientFiles),
	}
	editFilePath := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateClientEditFile),
	}
	sharesPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateClientShares),
	}
	sharePaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateClientShare),
	}
	profilePaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateClientProfile),
	}
	changePwdPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateCommonDir, templateChangePwd),
	}
	loginPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin),
		filepath.Join(templatesPath, templateCommonDir, templateCommonLogin),
	}
	messagePaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateCommonDir, templateMessage),
	}
	mfaPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateClientMFA),
	}
	twoFactorPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin),
		filepath.Join(templatesPath, templateCommonDir, templateTwoFactor),
	}
	twoFactorRecoveryPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin),
		filepath.Join(templatesPath, templateCommonDir, templateTwoFactorRecovery),
	}
	forgotPwdPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin),
		filepath.Join(templatesPath, templateCommonDir, templateForgotPassword),
	}
	resetPwdPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin),
		filepath.Join(templatesPath, templateCommonDir, templateResetPassword),
	}
	// The PDF viewer does not use the client base layout.
	viewPDFPaths := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientViewPDF),
	}
	shareLoginPath := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin),
		filepath.Join(templatesPath, templateClientDir, templateShareLogin),
	}
	shareUploadPath := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateUploadToShare),
	}
	shareDownloadPath := []string{
		filepath.Join(templatesPath, templateCommonDir, templateCommonBase),
		filepath.Join(templatesPath, templateClientDir, templateClientBase),
		filepath.Join(templatesPath, templateClientDir, templateShareDownload),
	}

	filesTmpl := util.LoadTemplate(nil, filesPaths...)
	profileTmpl := util.LoadTemplate(nil, profilePaths...)
	changePwdTmpl := util.LoadTemplate(nil, changePwdPaths...)
	loginTmpl := util.LoadTemplate(nil, loginPaths...)
	messageTmpl := util.LoadTemplate(nil, messagePaths...)
	mfaTmpl := util.LoadTemplate(nil, mfaPaths...)
	twoFactorTmpl := util.LoadTemplate(nil, twoFactorPaths...)
	twoFactorRecoveryTmpl := util.LoadTemplate(nil, twoFactorRecoveryPaths...)
	editFileTmpl := util.LoadTemplate(nil, editFilePath...)
	shareLoginTmpl := util.LoadTemplate(nil, shareLoginPath...)
	sharesTmpl := util.LoadTemplate(nil, sharesPaths...)
shareTmpl := util.LoadTemplate(nil, sharePaths...) forgotPwdTmpl := util.LoadTemplate(nil, forgotPwdPaths...) resetPwdTmpl := util.LoadTemplate(nil, resetPwdPaths...) viewPDFTmpl := util.LoadTemplate(nil, viewPDFPaths...) shareUploadTmpl := util.LoadTemplate(nil, shareUploadPath...) shareDownloadTmpl := util.LoadTemplate(nil, shareDownloadPath...) clientTemplates[templateClientFiles] = filesTmpl clientTemplates[templateClientProfile] = profileTmpl clientTemplates[templateChangePwd] = changePwdTmpl clientTemplates[templateCommonLogin] = loginTmpl clientTemplates[templateMessage] = messageTmpl clientTemplates[templateClientMFA] = mfaTmpl clientTemplates[templateTwoFactor] = twoFactorTmpl clientTemplates[templateTwoFactorRecovery] = twoFactorRecoveryTmpl clientTemplates[templateClientEditFile] = editFileTmpl clientTemplates[templateClientShares] = sharesTmpl clientTemplates[templateClientShare] = shareTmpl clientTemplates[templateForgotPassword] = forgotPwdTmpl clientTemplates[templateResetPassword] = resetPwdTmpl clientTemplates[templateClientViewPDF] = viewPDFTmpl clientTemplates[templateShareLogin] = shareLoginTmpl clientTemplates[templateUploadToShare] = shareUploadTmpl clientTemplates[templateShareDownload] = shareDownloadTmpl } func (s *httpdServer) getBaseClientPageData(title, currentURL string, w http.ResponseWriter, r *http.Request) baseClientPage { var csrfToken string if currentURL != "" { csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath) } data := baseClientPage{ commonBasePage: getCommonBasePage(r), Title: title, CurrentURL: currentURL, FilesURL: webClientFilesPath, SharesURL: webClientSharesPath, ShareURL: webClientSharePath, ProfileURL: webClientProfilePath, PingURL: webClientPingPath, ChangePwdURL: webChangeClientPwdPath, LogoutURL: webClientLogoutPath, EditURL: webClientEditFilePath, MFAURL: webClientMFAPath, CSRFToken: csrfToken, LoggedUser: getUserFromToken(r), IsLoggedToShare: false, Branding: s.binding.webClientBranding(), 
Languages: s.binding.languages(), } if !strings.HasPrefix(r.RequestURI, webClientPubSharesPath) { data.LoginURL = webClientLoginPath } return data } func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := forgotPwdPage{ commonBasePage: getCommonBasePage(r), CurrentURL: webClientForgotPwdPath, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), LoginURL: webClientLoginPath, Title: util.I18nForgotPwdTitle, Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), } renderClientTemplate(w, templateForgotPassword, data) } func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := resetPwdPage{ commonBasePage: getCommonBasePage(r), CurrentURL: webClientResetPwdPath, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), LoginURL: webClientLoginPath, Title: util.I18nResetPwdTitle, Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), } renderClientTemplate(w, templateResetPassword, data) } func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := shareLoginPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nShareLoginTitle, CurrentURL: r.RequestURI, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), CheckRedirect: false, } renderClientTemplate(w, templateShareLogin, data) } func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) { err := clientTemplates[tmplName].ExecuteTemplate(w, tmplName, data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) { data := 
clientMessagePage{ baseClientPage: s.getBaseClientPageData(title, "", w, r), Error: getI18nError(err), Success: message, } w.WriteHeader(statusCode) renderClientTemplate(w, templateMessage, data) } func (s *httpdServer) renderClientInternalServerErrorPage(w http.ResponseWriter, r *http.Request, err error) { s.renderClientMessagePage(w, r, util.I18nError500Title, http.StatusInternalServerError, util.NewI18nError(err, util.I18nError500Message), "") } func (s *httpdServer) renderClientBadRequestPage(w http.ResponseWriter, r *http.Request, err error) { s.renderClientMessagePage(w, r, util.I18nError400Title, http.StatusBadRequest, util.NewI18nError(err, util.I18nError400Message), "") } func (s *httpdServer) renderClientForbiddenPage(w http.ResponseWriter, r *http.Request, err error) { s.renderClientMessagePage(w, r, util.I18nError403Title, http.StatusForbidden, util.NewI18nError(err, util.I18nError403Message), "") } func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Request, err error) { s.renderClientMessagePage(w, r, util.I18nError404Title, http.StatusNotFound, util.NewI18nError(err, util.I18nError404Message), "") } func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := twoFactorPage{ commonBasePage: getCommonBasePage(r), Title: util.I18n2FATitle, CurrentURL: webClientTwoFactorPath, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), RecoveryURL: webClientTwoFactorRecoveryPath, Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), } if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { data.CurrentURL += "?next=" + url.QueryEscape(next) } renderClientTemplate(w, templateTwoFactor, data) } func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := twoFactorPage{ commonBasePage: getCommonBasePage(r), Title: util.I18n2FATitle, 
CurrentURL: webClientTwoFactorRecoveryPath, Error: err, CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), } renderClientTemplate(w, templateTwoFactorRecovery, data) } func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) { data := clientMFAPage{ baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, w, r), TOTPConfigs: mfa.GetAvailableTOTPConfigNames(), GenerateTOTPURL: webClientTOTPGeneratePath, ValidateTOTPURL: webClientTOTPValidatePath, SaveTOTPURL: webClientTOTPSavePath, RecCodesURL: webClientRecoveryCodesPath, Protocols: dataprovider.MFAProtocols, } user, err := dataprovider.GetUserWithGroupSettings(data.LoggedUser.Username, "") if err != nil { s.renderClientInternalServerErrorPage(w, r, err) return } data.TOTPConfig = user.Filters.TOTPConfig data.RequiredProtocols = user.Filters.TwoFactorAuthProtocols renderClientTemplate(w, templateClientMFA, data) } func (s *httpdServer) renderEditFilePage(w http.ResponseWriter, r *http.Request, fileName, fileData string, readOnly bool) { title := util.I18nViewFileTitle if !readOnly { title = util.I18nEditFileTitle } data := editFilePage{ baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, w, r), Path: fileName, Name: path.Base(fileName), CurrentDir: path.Dir(fileName), FileURL: webClientFilePath, ReadOnly: readOnly, Data: fileData, } renderClientTemplate(w, templateClientEditFile, data) } func (s *httpdServer) renderAddUpdateSharePage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share, err *util.I18nError, isAdd bool) { currentURL := webClientSharePath title := util.I18nShareAddTitle if !isAdd { currentURL = fmt.Sprintf("%v/%v", webClientSharePath, url.PathEscape(share.ShareID)) title = util.I18nShareUpdateTitle } if share.IsPasswordHashed() { share.Password = redactedSecret } data := clientSharePage{ baseClientPage: 
s.getBaseClientPageData(title, currentURL, w, r), Share: share, Error: err, IsAdd: isAdd, } renderClientTemplate(w, templateClientShare, data) } func getDirMapping(dirName, baseWebPath string) []dirMapping { paths := []dirMapping{} if dirName != "/" { paths = append(paths, dirMapping{ DirName: path.Base(dirName), Href: getFileObjectURL("/", dirName, baseWebPath), }) for { dirName = path.Dir(dirName) if dirName == "/" || dirName == "." { break } paths = append([]dirMapping{{ DirName: path.Base(dirName), Href: getFileObjectURL("/", dirName, baseWebPath)}, }, paths...) } } return paths } func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Request, dirName string, err *util.I18nError, share dataprovider.Share, ) { currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse") baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, w, r) baseData.FilesURL = currentURL baseSharePath := path.Join(webClientPubSharesPath, share.ShareID) baseData.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout") baseData.IsLoggedToShare = share.Password != "" data := filesPage{ baseClientPage: baseData, Error: err, CurrentDir: url.QueryEscape(dirName), DownloadURL: path.Join(baseSharePath, "partial"), // dirName must be escaped because the router expects the full path as single argument ShareUploadBaseURL: path.Join(baseSharePath, url.PathEscape(dirName)), ViewPDFURL: path.Join(baseSharePath, "viewpdf"), DirsURL: path.Join(baseSharePath, "dirs"), FileURL: "", FileActionsURL: "", CheckExistURL: path.Join(baseSharePath, "browse", "exist"), TasksURL: "", CanAddFiles: share.Scope == dataprovider.ShareScopeReadWrite, CanCreateDirs: false, CanRename: false, CanDelete: false, CanDownload: share.Scope != dataprovider.ShareScopeWrite, CanShare: false, CanCopy: false, Paths: getDirMapping(dirName, currentURL), QuotaUsage: newUserQuotaUsage(&dataprovider.User{}), KeepAliveInterval: int(cookieRefreshThreshold / time.Millisecond), } 
renderClientTemplate(w, templateClientFiles, data) } func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share, downloadLink string, ) { data := shareDownloadPage{ baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", w, r), DownloadLink: downloadLink, } data.LogoutURL = "" if share.Password != "" { data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout") } renderClientTemplate(w, templateShareDownload, data) } func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) { currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload") data := shareUploadPage{ baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, w, r), Share: share, UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID), } data.LogoutURL = "" if share.Password != "" { data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout") } renderClientTemplate(w, templateUploadToShare, data) } func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string, err *util.I18nError, user *dataprovider.User) { data := filesPage{ baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, w, r), Error: err, CurrentDir: url.QueryEscape(dirName), DownloadURL: webClientDownloadZipPath, ViewPDFURL: webClientViewPDFPath, DirsURL: webClientDirsPath, FileURL: webClientFilePath, FileActionsURL: webClientFileActionsPath, CheckExistURL: webClientExistPath, TasksURL: webClientTasksPath, CanAddFiles: user.CanAddFilesFromWeb(dirName), CanCreateDirs: user.CanAddDirsFromWeb(dirName), CanRename: user.CanRenameFromWeb(dirName, dirName), CanDelete: user.CanDeleteFromWeb(dirName), CanDownload: user.HasPerm(dataprovider.PermDownload, dirName), CanShare: user.CanManageShares(), CanCopy: user.CanCopyFromWeb(dirName, dirName), ShareUploadBaseURL: "", Paths: getDirMapping(dirName, 
webClientFilesPath), QuotaUsage: newUserQuotaUsage(user), KeepAliveInterval: int(cookieRefreshThreshold / time.Millisecond), } renderClientTemplate(w, templateClientFiles, data) } func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := clientProfilePage{ baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, w, r), Error: err, } user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "") if errUser != nil { s.renderClientInternalServerErrorPage(w, r, errUser) return } data.PublicKeys = user.PublicKeys data.TLSCerts = user.Filters.TLSCerts data.AllowAPIKeyAuth = user.Filters.AllowAPIKeyAuth data.Email = user.Email data.AdditionalEmails = user.Filters.AdditionalEmails data.AdditionalEmailsString = strings.Join(data.AdditionalEmails, ", ") data.Description = user.Description data.CanSubmit = userMerged.CanUpdateProfile() renderClientTemplate(w, templateClientProfile, data) } func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := changeClientPasswordPage{ baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, w, r), Error: err, } renderClientTemplate(w, templateChangePwd, data) } func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxMultipartMem) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } if err := r.ParseForm(); err != nil { s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } user, err := 
dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}
	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		s.renderClientForbiddenPage(w, r, err)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	files := r.Form.Get("files")
	var filesList []string
	err = json.Unmarshal(util.StringToBytes(files), &filesList)
	if err != nil {
		s.renderClientBadRequestPage(w, r, err)
		return
	}

	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"",
		getCompressedFileName(connection.GetUsername(), filesList)))
	renderCompressedFiles(w, connection, name, filesList, nil)
}

// handleClientSharePartialDownload sends a compressed archive of the files
// selected (JSON-encoded list in the "files" form field) inside a browsable
// read or read-write share.
// The request is denied when the share connection has no download quota left.
// The share usage counter is incremented before the archive is streamed.
func (s *httpdServer) handleClientSharePartialDownload(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxMultipartMem)
	if err := r.ParseForm(); err != nil {
		s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	transferQuota := connection.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		err = util.NewI18nError(connection.GetReadQuotaExceededError(), util.I18nErrorQuotaRead)
		connection.Log(logger.LevelInfo, "denying share read due to quota limits")
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getMappedStatusCode(err), err, "")
		return
	}
	files := r.Form.Get("files")
	var filesList []string
	err = json.Unmarshal(util.StringToBytes(files), &filesList)
	if err != nil {
		s.renderClientBadRequestPage(w, r, err)
		return
	}

	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"",
		getCompressedFileName(fmt.Sprintf("share-%s", share.Name), filesList)))
	renderCompressedFiles(w, connection, name, filesList, &share)
}

// handleShareGetDirContents streams, as a JSON array, the entries of a
// directory inside a browsable public share. Only directories and regular
// files are listed; other entry types (e.g. symlinks) are skipped.
func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nError500Message), getRespStatus(err))
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nError500Message), getRespStatus(err))
		return
	}
	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nError429Message), http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	lister, err := connection.ReadDir(name)
	if err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirListGeneric), getMappedStatusCode(err))
		return
	}
	defer lister.Close()

	// entryIdx counts every entry read from the lister, skipped ones included,
	// so row ids stay unique across chunks.
	entryIdx := 0
	dataGetter := func(limit, _ int) ([]byte, int, error) {
		results := make([]map[string]any, 0)
		// A whole lister batch may be filtered out (e.g. it only contains
		// symlinks). A chunk with count 0 is how the final, empty chunk ends
		// the stream, so returning one here would truncate the listing. Keep
		// reading until at least one entry is kept or the lister is exhausted.
		for len(results) == 0 {
			contents, err := lister.Next(limit)
			exhausted := errors.Is(err, io.EOF)
			if err != nil && !exhausted {
				return nil, 0, err
			}
			for _, info := range contents {
				entryIdx++
				if !info.Mode().IsDir() && !info.Mode().IsRegular() {
					continue
				}
				res := make(map[string]any)
				res["id"] = entryIdx
				if info.IsDir() {
					res["type"] = "1"
					res["size"] = ""
				} else {
					res["type"] = "2"
					res["size"] = info.Size()
				}
				res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
				res["name"] = info.Name()
				res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(),
					path.Join(webClientPubSharesPath, share.ShareID, "browse"))
				res["last_modified"] = getFileObjectModTime(info.ModTime())
				results = append(results, res)
			}
			if exhausted || len(contents) == 0 {
				break
			}
		}
		data, err := json.Marshal(results)
		count := limit
		if len(results) == 0 {
			count = 0
		}
		return data, count, err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleClientUploadToShare shows the upload page of a write-only share.
// Read-write shares are redirected to their browse page instead.
func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeWrite, dataprovider.ShareScopeReadWrite}
	share, _, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if share.Scope == dataprovider.ShareScopeReadWrite {
		http.Redirect(w, r, path.Join(webClientPubSharesPath, share.ShareID, "browse"), http.StatusFound)
		return
	}
	s.renderUploadToSharePage(w, r, &share)
}

// handleShareGetFiles serves a path inside a browsable share. Directories
// render the shared file browser; files are downloaded.
// The share usage counter is incremented before the download and decremented
// again if the download fails.
func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		s.renderSharedFilesPage(w, r, path.Dir(share.GetRelativePath(name)),
			util.NewI18nError(err, util.I18nError429Message), share)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	var info os.FileInfo
	if name == "/" {
		info = vfs.NewFileInfo(name, true, 0, time.Unix(0, 0), false)
	} else {
		info, err = connection.Stat(name, 1)
	}
	if err != nil {
		s.renderSharedFilesPage(w, r, path.Dir(share.GetRelativePath(name)),
			util.NewI18nError(err, i18nFsMsg(getRespStatus(err))), share)
		return
	}
	if info.IsDir() {
		s.renderSharedFilesPage(w, r, share.GetRelativePath(name), nil, share)
		return
	}
	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	if status, err := downloadFile(w, r, connection, name, info, false, &share); err != nil {
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
		if status > 0 {
			s.renderSharedFilesPage(w, r, path.Dir(share.GetRelativePath(name)),
				util.NewI18nError(err, i18nFsMsg(getRespStatus(err))), share)
		}
	}
}

func (s *httpdServer) handleShareViewPDF(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, _, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	name := util.CleanPath(r.URL.Query().Get("path"))
	data := viewPDFPage{
		commonBasePage: getCommonBasePage(r),
		Title:          path.Base(name),
		URL:
fmt.Sprintf("%s?path=%s&_=%d", path.Join(webClientPubSharesPath, share.ShareID, "getpdf"),
			url.QueryEscape(name), time.Now().UTC().Unix()),
		Branding:  s.binding.webClientBranding(),
		Languages: s.binding.languages(),
	}
	renderClientTemplate(w, templateClientViewPDF, data)
}

// handleShareGetPDF streams a PDF file from a browsable read or read-write
// share. Directories are rejected with a bad request page, and s.ensurePDF
// must succeed before anything is streamed.
// The share usage counter is incremented before the download and decremented
// again if the download fails.
func (s *httpdServer) handleShareGetPDF(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	info, err := connection.Stat(name, 1)
	if err != nil {
		status := getRespStatus(err)
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, status,
			util.NewI18nError(err, i18nFsMsg(status)), "")
		return
	}
	if info.IsDir() {
		s.renderClientBadRequestPage(w, r, util.NewI18nError(fmt.Errorf("%q is not a file", name), util.I18nErrorPDFMessage))
		return
	}
	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	if err := s.ensurePDF(w, r, name, connection); err != nil {
		return
	}
	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	if _, err := downloadFile(w, r, connection, name, info, true, &share); err != nil {
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
	}
}

// handleClientGetDirContents streams, as a JSON array, the entries of the
// directory given by the "path" query parameter for the logged-in user.
// With dirtree=1 only directories are returned. That mode is used to build
// a directory tree, so filtered-out batches must not end the listing early.
func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, nil, util.I18nErrorDirList403, http.StatusForbidden)
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		sendAPIResponse(w, r, nil, util.I18nErrorDirListUser, getRespStatus(err))
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%s_%s", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirList403), http.StatusForbidden)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, util.I18nErrorDirList429, http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	lister, err := connection.ReadDir(name)
	if err != nil {
		statusCode := getMappedStatusCode(err)
		sendAPIResponse(w, r, err, i18nListDirMsg(statusCode), statusCode)
		return
	}
	defer lister.Close()

	dirTree := r.URL.Query().Get("dirtree") == "1"
	// entryIdx counts every entry read from the lister, skipped ones included,
	// so row ids stay unique across chunks.
	entryIdx := 0
	dataGetter := func(limit, _ int) ([]byte, int, error) {
		results := make([]map[string]any, 0)
		// In dirtree mode a whole lister batch may contain only files. A chunk
		// with count 0 is how the final, empty chunk ends the stream, so
		// returning one here would drop every directory read afterwards. Keep
		// reading until at least one entry is kept or the lister is exhausted.
		for len(results) == 0 {
			contents, err := lister.Next(limit)
			exhausted := errors.Is(err, io.EOF)
			if err != nil && !exhausted {
				return nil, 0, err
			}
			for _, info := range contents {
				entryIdx++
				res := make(map[string]any)
				res["id"] = entryIdx
				res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath)
				if info.IsDir() {
					res["type"] = "1"
					res["size"] = ""
					res["dir_path"] = url.QueryEscape(path.Join(name, info.Name()))
				} else {
					if dirTree {
						continue
					}
					res["type"] = "2"
					if info.Mode()&os.ModeSymlink != 0 {
						res["size"] = ""
					} else {
						res["size"] = info.Size()
						if info.Size() < httpdMaxEditFileSize {
							res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1)
						}
					}
				}
				res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
				res["name"] = info.Name()
				res["last_modified"] = getFileObjectModTime(info.ModTime())
				results = append(results, res)
			}
			if exhausted || len(contents) == 0 {
				break
			}
		}
		data, err := json.Marshal(results)
		count := limit
		if len(results) == 0 {
			count = 0
		}
		return data, count, err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleClientGetFiles serves the "path" query parameter for the logged-in
// user. Directories render the file browser; files are downloaded.
// Errors re-render the parent directory's page, except 416, which gets its
// own message page.
func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		s.renderClientForbiddenPage(w, r, err)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	var info os.FileInfo
	if name == "/" {
		info = vfs.NewFileInfo(name, true, 0, time.Unix(0, 0), false)
	} else {
		info, err = connection.Stat(name, 0)
	}
	if err != nil {
		s.renderFilesPage(w, r, path.Dir(name), util.NewI18nError(err, i18nFsMsg(getRespStatus(err))), &user)
		return
	}
	if info.IsDir() {
		s.renderFilesPage(w, r, name, nil, &user)
		return
	}
	if status, err := downloadFile(w, r, connection, name, info, false, nil); err != nil && status != 0 {
		if status > 0 {
			if status == http.StatusRequestedRangeNotSatisfiable {
				s.renderClientMessagePage(w, r, util.I18nError416Title, status,
					util.NewI18nError(err, util.I18nError416Message), "")
				return
			}
			s.renderFilesPage(w, r, path.Dir(name), util.NewI18nError(err, i18nFsMsg(status)), &user)
		}
	}
}

// handleClientEditFile loads the file named by the "path" query parameter
// into the web editor. Directories and files larger than
// httpdMaxEditFileSize are rejected. The editor is read-only when the user
// cannot add files in the parent directory.
func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		s.renderClientForbiddenPage(w, r, err)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))

	info, err := connection.Stat(name, 0)
	if err != nil {
		status := getRespStatus(err)
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, status, util.NewI18nError(err, i18nFsMsg(status)), "")
		return
	}
	if info.IsDir() {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, http.StatusBadRequest,
			util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("The path %q does not point to a file", name)),
				util.I18nErrorEditDir,
			), "")
		return
	}
	if info.Size() > httpdMaxEditFileSize {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, http.StatusBadRequest,
			util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("The file size %v for %q exceeds the maximum allowed size",
					util.ByteCountIEC(info.Size()), name)),
				util.I18nErrorEditSize,
			), "")
		return
	}

	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	reader, err := connection.getFileReader(name, 0, r.Method)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, getRespStatus(err),
			util.NewI18nError(err, util.I18nError500Message), "")
		return
	}
	defer reader.Close()

	var b bytes.Buffer
	_, err = io.Copy(&b, reader)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, getRespStatus(err),
			util.NewI18nError(err, util.I18nError500Message), "")
		return
	}

	s.renderEditFilePage(w, r, name, b.String(), !user.CanAddFilesFromWeb(path.Dir(name)))
}

// handleClientAddShareGet renders the "add share" form, pre-filled as follows:
//   - the scope is read;
//   - the expiration comes from the user's default, or else the maximum,
//     share expiration (in days);
//   - the paths come from the optional "path" and JSON-encoded "files"
//     query parameters.
func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}
	share := &dataprovider.Share{Scope: dataprovider.ShareScopeRead}
	if user.Filters.DefaultSharesExpiration > 0 {
		share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(user.Filters.DefaultSharesExpiration)))
	} else if user.Filters.MaxSharesExpiration > 0 {
		share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(user.Filters.MaxSharesExpiration)))
	}
	dirName := "/"
	if _, ok := r.URL.Query()["path"]; ok {
		dirName = util.CleanPath(r.URL.Query().Get("path"))
	}

	if _, ok := r.URL.Query()["files"]; ok {
		files := r.URL.Query().Get("files")
		var filesList []string
		err := json.Unmarshal(util.StringToBytes(files), &filesList)
		if err != nil {
			s.renderClientBadRequestPage(w, r, err)
			return
		}
		for _, f := range filesList {
			if f != "" {
				share.Paths = append(share.Paths, path.Join(dirName, f))
			}
		}
	}

	s.renderAddUpdateSharePage(w, r, share, nil, true)
}

// handleClientUpdateShareGet renders the edit form for the share identified
// by the "id" URL parameter and owned by the logged-in user.
// A missing share gives a 404 page; any other lookup error gives a 500 page.
func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	shareID := getURLParam(r, "id")
	share, err := dataprovider.ShareExists(shareID, claims.Username)
	if err == nil {
		s.renderAddUpdateSharePage(w, r, &share, nil, false)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderClientNotFoundPage(w, r, err)
	} else {
		s.renderClientInternalServerErrorPage(w, r, err)
	}
}

func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	share, err := getShareFromPostFields(r)
	if err != nil {
		s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(err, util.I18nError500Message), true)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(err,
util.I18nErrorInvalidCSRF)) return } share.ID = 0 share.ShareID = util.GenerateUniqueID() share.LastUseAt = 0 share.Username = claims.Username if share.Password == "" { if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd), true) return } } user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") if err != nil { s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(err, util.I18nErrorGetUser), true) return } if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(share.ExpiresAt)); err != nil { s.renderAddUpdateSharePage(w, r, share, util.NewI18nError( err, util.I18nErrorShareExpirationOutOfRange, util.I18nErrorArgs( map[string]any{ "val": time.Now().Add(24 * time.Hour * time.Duration(user.Filters.MaxSharesExpiration+1)).UnixMilli(), "formatParams": map[string]string{ "year": "numeric", "month": "numeric", "day": "numeric", }, }, ), ), true) return } err = dataprovider.AddShare(share, claims.Username, ipAddr, claims.Role) if err == nil { http.Redirect(w, r, webClientSharesPath, http.StatusSeeOther) } else { s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(err, util.I18nErrorShareGeneric), true) } } func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } shareID := getURLParam(r, "id") share, err := dataprovider.ShareExists(shareID, claims.Username) if errors.Is(err, util.ErrNotFound) { s.renderClientNotFoundPage(w, r, err) return } else if err != nil { s.renderClientInternalServerErrorPage(w, r, err) return } updatedShare, err := 
getShareFromPostFields(r) if err != nil { s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(err, util.I18nError500Message), false) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } updatedShare.ShareID = shareID updatedShare.Username = claims.Username if updatedShare.Password == redactedSecret { updatedShare.Password = share.Password } if updatedShare.Password == "" { if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd), false) return } } user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") if err != nil { s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(err, util.I18nErrorGetUser), false) return } if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(updatedShare.ExpiresAt)); err != nil { s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError( err, util.I18nErrorShareExpirationOutOfRange, util.I18nErrorArgs( map[string]any{ "val": time.Now().Add(24 * time.Hour * time.Duration(user.Filters.MaxSharesExpiration+1)).UnixMilli(), "formatParams": map[string]string{ "year": "numeric", "month": "numeric", "day": "numeric", }, }, ), ), false) return } err = dataprovider.UpdateShare(updatedShare, claims.Username, ipAddr, claims.Role) if err == nil { http.Redirect(w, r, webClientSharesPath, http.StatusSeeOther) } else { s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(err, util.I18nErrorShareGeneric), false) } } func getAllShares(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { 
sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden) return } dataGetter := func(limit, offset int) ([]byte, int, error) { shares, err := dataprovider.GetShares(limit, offset, dataprovider.OrderASC, claims.Username) if err != nil { return nil, 0, err } data, err := json.Marshal(shares) return data, len(shares), err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func (s *httpdServer) handleClientGetShares(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := clientSharesPage{ baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, w, r), BasePublicSharesURL: webClientPubSharesPath, BaseURL: s.binding.BaseURL, } renderClientTemplate(w, templateClientShares, data) } func (s *httpdServer) handleClientGetProfile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderClientProfilePage(w, r, nil) } func (s *httpdServer) handleWebClientChangePwd(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderClientChangePasswordPage(w, r, nil) } func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.Request) { //nolint:gocyclo r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) err := r.ParseForm() if err != nil { s.renderClientProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } user, userMerged, err := dataprovider.GetUserVariants(claims.Username, "") if err != nil { s.renderClientProfilePage(w, r, util.NewI18nError(err, 
util.I18nErrorGetUser)) return } if !userMerged.CanUpdateProfile() { s.renderClientForbiddenPage(w, r, util.NewI18nError( errors.New("you are not allowed to change anything"), util.I18nErrorNoPermissions, )) return } if userMerged.CanManagePublicKeys() { for k := range r.Form { if hasPrefixAndSuffix(k, "public_keys[", "][public_key]") { r.Form.Add("public_keys", r.Form.Get(k)) } } user.PublicKeys = r.Form["public_keys"] } if userMerged.CanManageTLSCerts() { for k := range r.Form { if hasPrefixAndSuffix(k, "tls_certs[", "][tls_cert]") { r.Form.Add("tls_certs", r.Form.Get(k)) } } user.Filters.TLSCerts = r.Form["tls_certs"] } if userMerged.CanChangeAPIKeyAuth() { user.Filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != "" } if userMerged.CanChangeInfo() { user.Email = strings.TrimSpace(r.Form.Get("email")) user.Description = r.Form.Get("description") for k := range r.Form { if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") { email := strings.TrimSpace(r.Form.Get(k)) if email != "" { r.Form.Add("additional_emails", email) } } } user.Filters.AdditionalEmails = r.Form["additional_emails"] } err = dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, ipAddr, user.Role) if err != nil { s.renderClientProfilePage(w, r, util.NewI18nError(err, util.I18nError500Message)) return } s.renderClientMessagePage(w, r, util.I18nProfileTitle, http.StatusOK, nil, util.I18nProfileUpdated) } func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderClientMFAPage(w, r) } func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderClientTwoFactorPage(w, r, nil) } func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) s.renderClientTwoFactorRecoveryPage(w, r, nil) } func 
getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) { share := &dataprovider.Share{} if err := r.ParseForm(); err != nil { return share, util.NewI18nError(err, util.I18nErrorInvalidForm) } for k := range r.Form { if hasPrefixAndSuffix(k, "paths[", "][path]") { r.Form.Add("paths", r.Form.Get(k)) } } share.Name = strings.TrimSpace(r.Form.Get("name")) share.Description = r.Form.Get("description") for _, p := range r.Form["paths"] { if strings.TrimSpace(p) != "" { share.Paths = append(share.Paths, p) } } share.Password = strings.TrimSpace(r.Form.Get("password")) share.AllowFrom = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",") scope, err := strconv.Atoi(r.Form.Get("scope")) if err != nil { return share, util.NewI18nError(err, util.I18nErrorShareScope) } share.Scope = dataprovider.ShareScope(scope) maxTokens, err := strconv.Atoi(r.Form.Get("max_tokens")) if err != nil { return share, util.NewI18nError(err, util.I18nErrorShareMaxTokens) } share.MaxTokens = maxTokens expirationDateMillis := int64(0) expirationDateString := strings.TrimSpace(r.Form.Get("expiration_date")) if expirationDateString != "" { expirationDate, err := time.Parse(webDateTimeFormat, expirationDateString) if err != nil { return share, util.NewI18nError(err, util.I18nErrorShareExpiration) } expirationDateMillis = util.GetTimeAsMsSinceEpoch(expirationDate) } share.ExpiresAt = expirationDateMillis return share, nil } func (s *httpdServer) handleWebClientForgotPwd(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) if !smtp.IsEnabled() { s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) return } s.renderClientForgotPwdPage(w, r, nil) } func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) err := r.ParseForm() if err != nil { s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } 
if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } username := strings.TrimSpace(r.Form.Get("username")) err = handleForgotPassword(r, username, false) if err != nil { s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric)) return } http.Redirect(w, r, webClientResetPwdPath, http.StatusFound) } func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) if !smtp.IsEnabled() { s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) return } s.renderClientResetPwdPage(w, r, nil) } func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) name := r.URL.Query().Get("path") if name == "" { s.renderClientBadRequestPage(w, r, errors.New("no file specified")) return } name = util.CleanPath(name) data := viewPDFPage{ commonBasePage: getCommonBasePage(r), Title: path.Base(name), URL: fmt.Sprintf("%s?path=%s&_=%d", webClientGetPDFPath, url.QueryEscape(name), time.Now().UTC().Unix()), Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), } renderClientTemplate(w, templateClientViewPDF, data) } func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } name := r.URL.Query().Get("path") if name == "" { s.renderClientBadRequestPage(w, r, util.NewI18nError(errors.New("no file specified"), util.I18nError400Message)) return } name = util.CleanPath(name) user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") if err != nil { 
s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err), util.NewI18nError(err, util.I18nErrorGetUser), "") return } connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil { s.renderClientForbiddenPage(w, r, err) return } baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(err, util.I18nError429Message), "") return } defer common.Connections.Remove(connection.GetID()) info, err := connection.Stat(name, 0) if err != nil { status := getRespStatus(err) s.renderClientMessagePage(w, r, util.I18nErrorPDFTitle, status, util.NewI18nError(err, i18nFsMsg(status)), "") return } if info.IsDir() { s.renderClientBadRequestPage(w, r, util.NewI18nError(fmt.Errorf("%q is not a file", name), util.I18nErrorPDFMessage)) return } connection.User.CheckFsRoot(connection.ID) //nolint:errcheck if err := s.ensurePDF(w, r, name, connection); err != nil { return } downloadFile(w, r, connection, name, info, true, nil) //nolint:errcheck } func (s *httpdServer) ensurePDF(w http.ResponseWriter, r *http.Request, name string, connection *Connection) error { reader, err := connection.getFileReader(name, 0, r.Method) if err != nil { s.renderClientMessagePage(w, r, util.I18nErrorPDFTitle, getRespStatus(err), util.NewI18nError(err, util.I18nError500Message), "") return err } defer reader.Close() var b bytes.Buffer _, err = io.CopyN(&b, reader, 128) if err != nil { s.renderClientMessagePage(w, r, util.I18nErrorPDFTitle, getRespStatus(err), util.NewI18nError(err, util.I18nErrorPDFMessage), "") return err } if ctype := http.DetectContentType(b.Bytes()); ctype != "application/pdf" { 
connection.Log(logger.LevelDebug, "detected %q content type, expected PDF, file %q", ctype, name) err := fmt.Errorf("the file %q does not look like a PDF", name) s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorPDFMessage)) return err } return nil } func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) s.renderShareLoginPage(w, r, nil) } func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } invalidateToken(r) shareID := getURLParam(r, "id") share, err := dataprovider.ShareExists(shareID, "") if err != nil { s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials)) return } match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password"))) if !match || err != nil { s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } next := path.Clean(r.URL.Query().Get("next")) baseShareURL := path.Join(webClientPubSharesPath, share.ShareID) isRedirect, redirectTo := checkShareRedirectURL(next, baseShareURL) c := &jwt.Claims{ Username: shareID, } if isRedirect { c.Ref = next } err = createAndSetCookie(w, r, c, s.tokenAuth, tokenAudienceWebShare, ipAddr) if err != nil { s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message)) return } if isRedirect { http.Redirect(w, r, redirectTo, http.StatusFound) return } s.renderClientMessagePage(w, r, util.I18nSharedFilesTitle, http.StatusOK, nil, 
util.I18nShareLoginOK) } func (s *httpdServer) handleClientShareLogout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) shareID := getURLParam(r, "id") ctx, claims, err := s.getShareClaims(r, shareID) if err != nil { s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, http.StatusForbidden, util.NewI18nError(err, util.I18nErrorInvalidToken), "") return } removeCookie(w, r.WithContext(ctx), webBaseClientPath) redirectURL := path.Join(webClientPubSharesPath, shareID, fmt.Sprintf("login?next=%s", url.QueryEscape(claims.Ref))) http.Redirect(w, r, redirectURL, http.StatusFound) } func (s *httpdServer) handleClientSharedFile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead} share, _, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } query := "" if r.URL.RawQuery != "" { query = "?" + r.URL.RawQuery } s.renderShareDownloadPage(w, r, &share, path.Join(webClientPubSharesPath, share.ShareID)+query) } func (s *httpdServer) handleClientCheckExist(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) doCheckExist(w, r, connection, name) } func (s *httpdServer) handleClientShareCheckExist(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeReadWrite} share, connection, err := s.checkPublicShare(w, r, validScopes) if err != nil { return } if err := validateBrowsableShare(share, connection); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } name, err := getBrowsableSharedPath(share.Paths[0], r) if err != nil { sendAPIResponse(w, 
r, err, "", http.StatusBadRequest) return } if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) doCheckExist(w, r, connection, name) } type filesToCheck struct { Files []string `json:"files"` } func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection, name string) { var filesList filesToCheck err := render.DecodeJSON(r.Body, &filesList) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) return } if len(filesList.Files) == 0 { sendAPIResponse(w, r, errors.New("files to be checked are mandatory"), "", http.StatusBadRequest) return } lister, err := connection.ListDir(name) if err != nil { sendAPIResponse(w, r, err, "Unable to get directory contents", getMappedStatusCode(err)) return } defer lister.Close() dataGetter := func(limit, _ int) ([]byte, int, error) { contents, err := lister.Next(limit) if errors.Is(err, io.EOF) { err = nil } if err != nil { return nil, 0, err } existing := make([]map[string]any, 0) for _, info := range contents { if slices.Contains(filesList.Files, info.Name()) { res := make(map[string]any) res["name"] = info.Name() if info.IsDir() { res["type"] = "1" res["size"] = "" } else { res["type"] = "2" res["size"] = info.Size() } existing = append(existing, res) } } data, err := json.Marshal(existing) count := limit if len(existing) == 0 { count = 0 } return data, count, err } streamJSONArray(w, defaultQueryLimit, dataGetter) } func checkShareRedirectURL(next, base string) (bool, string) { if !strings.HasPrefix(next, base) { return false, "" } if next == base { return true, path.Join(next, "download") } baseURL, err := url.Parse(base) if err != nil { return false, "" } nextURL, err := url.Parse(next) if err != nil { return false, "" } if nextURL.Path == baseURL.Path { redirectURL := nextURL.JoinPath("download") return true, redirectURL.String() } return true, 
next } func getWebTask(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } taskID := getURLParam(r, "id") task, err := webTaskMgr.Get(taskID) if err != nil { sendAPIResponse(w, r, err, "Unable to get task", getMappedStatusCode(err)) return } if task.User != claims.Username { sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } render.JSON(w, r, task) } func taskDeleteDir(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) task := webTaskData{ ID: connection.GetID(), User: connection.GetUsername(), Path: name, Timestamp: util.GetTimeAsMsSinceEpoch(time.Now()), Status: 0, } if err := webTaskMgr.Add(task); err != nil { common.Connections.Remove(connection.GetID()) sendAPIResponse(w, r, nil, "Unable to create task", http.StatusInternalServerError) return } go executeDeleteTask(connection, task) sendAPIResponse(w, r, nil, task.ID, http.StatusAccepted) } func taskRenameFsEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) newName := connection.User.GetCleanedPath(r.URL.Query().Get("target")) task := webTaskData{ ID: connection.GetID(), User: connection.GetUsername(), Path: oldName, Target: newName, Timestamp: util.GetTimeAsMsSinceEpoch(time.Now()), Status: 0, } if err := webTaskMgr.Add(task); err != nil { common.Connections.Remove(connection.GetID()) sendAPIResponse(w, r, nil, "Unable to create task", http.StatusInternalServerError) return } go 
executeRenameTask(connection, task) sendAPIResponse(w, r, nil, task.ID, http.StatusAccepted) } func taskCopyFsEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) connection, err := getUserConnection(w, r) if err != nil { return } source := r.URL.Query().Get("path") target := r.URL.Query().Get("target") copyFromSource := strings.HasSuffix(source, "/") copyInTarget := strings.HasSuffix(target, "/") source = connection.User.GetCleanedPath(source) target = connection.User.GetCleanedPath(target) if copyFromSource { source += "/" } if copyInTarget { target += "/" } task := webTaskData{ ID: connection.GetID(), User: connection.GetUsername(), Path: source, Target: target, Timestamp: util.GetTimeAsMsSinceEpoch(time.Now()), Status: 0, } if err := webTaskMgr.Add(task); err != nil { common.Connections.Remove(connection.GetID()) sendAPIResponse(w, r, nil, "Unable to create task", http.StatusInternalServerError) return } go executeCopyTask(connection, task) sendAPIResponse(w, r, nil, task.ID, http.StatusAccepted) } func executeDeleteTask(conn *Connection, task webTaskData) { done := make(chan bool) defer func() { close(done) common.Connections.Remove(conn.GetID()) }() go keepAliveTask(task, done, 2*time.Minute) status := http.StatusOK if err := conn.RemoveAll(task.Path); err != nil { status = getMappedStatusCode(err) } task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) task.Status = status err := webTaskMgr.Add(task) conn.Log(logger.LevelDebug, "delete task finished, status: %d, update task err: %v", status, err) } func executeRenameTask(conn *Connection, task webTaskData) { done := make(chan bool) defer func() { close(done) common.Connections.Remove(conn.GetID()) }() go keepAliveTask(task, done, 2*time.Minute) status := http.StatusOK if !conn.IsSameResource(task.Path, task.Target) { if err := conn.Copy(task.Path, task.Target); err != nil { status = getMappedStatusCode(err) task.Timestamp = 
util.GetTimeAsMsSinceEpoch(time.Now()) task.Status = status err = webTaskMgr.Add(task) conn.Log(logger.LevelDebug, "copy step for rename task finished, status: %d, update task err: %v", status, err) return } if err := conn.RemoveAll(task.Path); err != nil { status = getMappedStatusCode(err) } } else { if err := conn.Rename(task.Path, task.Target); err != nil { status = getMappedStatusCode(err) } } task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) task.Status = status err := webTaskMgr.Add(task) conn.Log(logger.LevelDebug, "rename task finished, status: %d, update task err: %v", status, err) } func executeCopyTask(conn *Connection, task webTaskData) { done := make(chan bool) defer func() { close(done) common.Connections.Remove(conn.GetID()) }() go keepAliveTask(task, done, 2*time.Minute) status := http.StatusOK if err := conn.Copy(task.Path, task.Target); err != nil { status = getMappedStatusCode(err) } task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) task.Status = status err := webTaskMgr.Add(task) conn.Log(logger.LevelDebug, "copy task finished, status: %d, update task err: %v", status, err) } func keepAliveTask(task webTaskData, done chan bool, interval time.Duration) { ticker := time.NewTicker(interval) defer func() { ticker.Stop() }() for { select { case <-done: return case <-ticker.C: task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) err := webTaskMgr.Add(task) logger.Debug(logSender, task.ID, "task timestamp updated, err: %v", err) } } } ================================================ FILE: internal/httpd/webtask.go ================================================ // Copyright (C) 2024 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpd import ( "encoding/json" "fmt" "sync" "time" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) var ( webTaskMgr webTaskManager ) func newWebTaskManager(isShared int) webTaskManager { if isShared == 1 { logger.Info(logSender, "", "using provider task manager") return &dbTaskManager{} } logger.Info(logSender, "", "using memory task manager") return &memoryTaskManager{} } type webTaskManager interface { Add(data webTaskData) error Get(ID string) (webTaskData, error) Cleanup() } type webTaskData struct { ID string `json:"id"` User string `json:"user"` Path string `json:"path"` Target string `json:"target"` Timestamp int64 `json:"ts"` Status int `json:"status"` // 0 in progress or http status code (200 ok, 403 and so on) } type memoryTaskManager struct { tasks sync.Map } func (m *memoryTaskManager) Add(data webTaskData) error { m.tasks.Store(data.ID, &data) return nil } func (m *memoryTaskManager) Get(ID string) (webTaskData, error) { data, ok := m.tasks.Load(ID) if !ok { return webTaskData{}, util.NewRecordNotFoundError(fmt.Sprintf("task for ID %q not found", ID)) } return *data.(*webTaskData), nil } func (m *memoryTaskManager) Cleanup() { m.tasks.Range(func(key, value any) bool { data := value.(*webTaskData) if data.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now().Add(-5*time.Minute)) { m.tasks.Delete(key) } return true }) } type dbTaskManager struct{} func (m *dbTaskManager) Add(data webTaskData) error { session := dataprovider.Session{ Key: data.ID, Data: data, Type: 
dataprovider.SessionTypeWebTask, Timestamp: data.Timestamp, } return dataprovider.AddSharedSession(session) } func (m *dbTaskManager) Get(ID string) (webTaskData, error) { sess, err := dataprovider.GetSharedSession(ID, dataprovider.SessionTypeWebTask) if err != nil { return webTaskData{}, err } d := sess.Data.([]byte) var data webTaskData err = json.Unmarshal(d, &data) return data, err } func (m *dbTaskManager) Cleanup() { dataprovider.CleanupSharedSessions(dataprovider.SessionTypeWebTask, time.Now().Add(-5*time.Minute)) //nolint:errcheck } ================================================ FILE: internal/httpd/webtask_test.go ================================================ // Copyright (C) 2024 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package httpd

import (
	"testing"
	"time"

	"github.com/rs/xid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/util"
)

// TestMemoryWebTaskManager covers the in-memory task manager: Add/Get round
// trips, eviction of stale tasks by Cleanup and keepAliveTask refreshing a
// task so that Cleanup keeps it.
func TestMemoryWebTaskManager(t *testing.T) {
	// isShared == 0 selects the in-memory implementation.
	mgr := newWebTaskManager(0)
	m, ok := mgr.(*memoryTaskManager)
	require.True(t, ok)
	// task is one hour old: the first Cleanup must evict it.
	task := webTaskData{
		ID:        xid.New().String(),
		User:      defeaultUsername,
		Timestamp: time.Now().Add(-1 * time.Hour).UnixMilli(),
		Status:    0,
	}
	// task1 is fresh: it must survive the first Cleanup.
	task1 := webTaskData{
		ID:        xid.New().String(),
		User:      defeaultUsername,
		Timestamp: time.Now().UnixMilli(),
		Status:    0,
	}
	err := m.Add(task)
	require.NoError(t, err)
	err = m.Add(task1)
	require.NoError(t, err)
	taskGet, err := m.Get(task.ID)
	require.NoError(t, err)
	require.Equal(t, task, taskGet)
	m.Cleanup()
	_, err = m.Get(task.ID)
	require.ErrorIs(t, err, util.ErrNotFound)
	taskGet, err = m.Get(task1.ID)
	require.NoError(t, err)
	require.Equal(t, task1, taskGet)
	// Overwrite task1 with a stale timestamp: the next Cleanup must evict it.
	task1.Timestamp = time.Now().Add(-1 * time.Hour).UnixMilli()
	err = m.Add(task1)
	require.NoError(t, err)
	m.Cleanup()
	_, err = m.Get(task1.ID)
	require.ErrorIs(t, err, util.ErrNotFound)
	// test keep alive task
	oldMgr := webTaskMgr
	webTaskMgr = mgr
	// Restore the global manager even if a require assertion aborts the test.
	defer func() {
		webTaskMgr = oldMgr
	}()
	done := make(chan bool)
	// keepAliveTask is expected to refresh the task through webTaskMgr; the
	// assertions below check the stored timestamp moved forward.
	go keepAliveTask(task, done, 50*time.Millisecond)
	time.Sleep(120 * time.Millisecond)
	close(done)
	taskGet, err = m.Get(task.ID)
	require.NoError(t, err)
	assert.Greater(t, taskGet.Timestamp, task.Timestamp)
	m.Cleanup()
	_, err = m.Get(task.ID)
	require.NoError(t, err)
	// Re-adding the original stale task must make it eligible for cleanup again.
	err = m.Add(task)
	require.NoError(t, err)
	m.Cleanup()
	_, err = m.Get(task.ID)
	require.ErrorIs(t, err, util.ErrNotFound)
}

// TestDbWebTaskManager covers the data provider backed task manager with the
// same scenario as the in-memory test, when shared sessions are supported.
func TestDbWebTaskManager(t *testing.T) {
	if !isSharedProviderSupported() {
		t.Skip("this test it is not available with this provider")
	}
	// isShared == 1 selects the data provider implementation.
	mgr := newWebTaskManager(1)
	m, ok := mgr.(*dbTaskManager)
	require.True(t, ok)
	// task is one hour old: Cleanup must remove it.
	task := webTaskData{
		ID:        xid.New().String(),
		User:      defeaultUsername,
		Timestamp: time.Now().Add(-1 * time.Hour).UnixMilli(),
		Status:    0,
	}
	err := m.Add(task)
	require.NoError(t, err)
	taskGet, err := m.Get(task.ID)
	require.NoError(t, err)
	require.Equal(t, task, taskGet)
	m.Cleanup()
	_, err = m.Get(task.ID)
	require.ErrorIs(t, err, util.ErrNotFound)
	err = m.Add(task)
	require.NoError(t, err)
	// test keep alive task
	oldMgr := webTaskMgr
	webTaskMgr = mgr
	// Restore the global manager even if a require assertion aborts the test.
	defer func() {
		webTaskMgr = oldMgr
	}()
	done := make(chan bool)
	go keepAliveTask(task, done, 50*time.Millisecond)
	time.Sleep(120 * time.Millisecond)
	close(done)
	taskGet, err = m.Get(task.ID)
	require.NoError(t, err)
	assert.Greater(t, taskGet.Timestamp, task.Timestamp)
	m.Cleanup()
	_, err = m.Get(task.ID)
	require.NoError(t, err)
	err = m.Add(task)
	require.NoError(t, err)
	m.Cleanup()
	_, err = m.Get(task.ID)
	require.ErrorIs(t, err, util.ErrNotFound)
}

================================================
FILE: internal/httpdtest/httpdtest.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// Package httpdtest provides utilities for testing the supported REST API.
package httpdtest

import (
	"bytes"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"path"
	"slices"
	"strconv"
	"strings"

	"github.com/go-chi/render"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/httpclient"
	"github.com/drakkan/sftpgo/v2/internal/httpd"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/version"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	tokenPath             = "/api/v2/token"
	activeConnectionsPath = "/api/v2/connections"
	quotasBasePath        = "/api/v2/quotas"
	quotaScanPath         = "/api/v2/quotas/users/scans"
	quotaScanVFolderPath  = "/api/v2/quotas/folders/scans"
	userPath              = "/api/v2/users"
	groupPath             = "/api/v2/groups"
	versionPath           = "/api/v2/version"
	folderPath            = "/api/v2/folders"
	serverStatusPath      = "/api/v2/status"
	dumpDataPath          = "/api/v2/dumpdata"
	loadDataPath          = "/api/v2/loaddata"
	defenderHosts         = "/api/v2/defender/hosts"
	adminPath             = "/api/v2/admins"
	adminPwdPath          = "/api/v2/admin/changepwd"
	apiKeysPath           = "/api/v2/apikeys"
	retentionChecksPath   = "/api/v2/retention/users/checks"
	eventActionsPath      = "/api/v2/eventactions"
	eventRulesPath        = "/api/v2/eventrules"
	rolesPath             = "/api/v2/roles"
	ipListsPath           = "/api/v2/iplists"
)

const (
	defaultTokenAuthUser = "admin"
	defaultTokenAuthPass = "password"
)

var (
	httpBaseURL = "http://127.0.0.1:8080"
	jwtToken    = ""
)

// SetBaseURL sets the base url to use for HTTP requests.
// Default URL is "http://127.0.0.1:8080"
func SetBaseURL(url string) {
	httpBaseURL = url
}

// SetJWTToken sets the JWT token to use
func SetJWTToken(token string) {
	jwtToken = token
}

// sendHTTPRequest builds a request and sends it using the shared HTTP client.
// A non-empty contentType is sent as the Content-Type header and a non-empty
// token is sent as a Bearer Authorization header.
func sendHTTPRequest(method, url string, body io.Reader, contentType, token string) (*http.Response, error) {
	req, err := http.NewRequest(method, url, body)
	if err != nil {
		return nil, err
	}
	if contentType != "" {
		// Honor the caller supplied content type instead of always forcing JSON.
		req.Header.Set("Content-Type", contentType)
	}
	if token != "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token))
	}
	return httpclient.GetHTTPClient().Do(req)
}

// buildURLRelativeToBase joins paths and appends the result to httpBaseURL,
// making sure exactly one slash separates them.
func buildURLRelativeToBase(paths ...string) string {
	// we need to use path.Join and not filepath.Join
	// since filepath.Join will use backslash separator on Windows
	p := path.Join(paths...)
	return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/"))
}

// GetToken tries to return a JWT token.
// It authenticates with HTTP basic auth against the token endpoint and
// returns the access token together with the whole decoded response.
// An error is returned if the response has no string "access_token" field.
func GetToken(username, password string) (string, map[string]any, error) {
	req, err := http.NewRequest(http.MethodGet, buildURLRelativeToBase(tokenPath), nil)
	if err != nil {
		return "", nil, err
	}
	req.SetBasicAuth(username, password)
	resp, err := httpclient.GetHTTPClient().Do(req)
	if err != nil {
		return "", nil, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, http.StatusOK)
	if err != nil {
		return "", nil, err
	}
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	if err != nil {
		return "", nil, err
	}
	// Use the comma-ok form: a missing or non-string token must not panic.
	accessToken, ok := responseHolder["access_token"].(string)
	if !ok {
		return "", responseHolder, errors.New("access_token not found in the token response")
	}
	return accessToken, responseHolder, nil
}

// getDefaultToken returns the token configured via SetJWTToken or, if none,
// a token obtained with the default admin credentials. It returns an empty
// string if the token cannot be obtained.
func getDefaultToken() string {
	if jwtToken != "" {
		return jwtToken
	}
	token, _, err := GetToken(defaultTokenAuthUser, defaultTokenAuthPass)
	if err != nil {
		return ""
	}
	return token
}

// AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode.
func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) { var newUser dataprovider.User var body []byte userAsJSON, _ := json.Marshal(user) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(userPath), bytes.NewBuffer(userAsJSON), "application/json", getDefaultToken()) if err != nil { return newUser, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newUser, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newUser) } else { body, _ = getResponseBody(resp) } if err == nil { err = checkUser(&user, &newUser) } return newUser, body, err } // UpdateUserWithJSON update a user using the provided JSON as POST body func UpdateUserWithJSON(user dataprovider.User, expectedStatusCode int, disconnect string, userAsJSON []byte) (dataprovider.User, []byte, error) { var newUser dataprovider.User var body []byte url, err := addUpdateUserQueryParams(buildURLRelativeToBase(userPath, url.PathEscape(user.Username)), disconnect) if err != nil { return user, body, err } resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", getDefaultToken()) if err != nil { return user, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newUser, body, err } if err == nil { newUser, body, err = GetUserByUsername(user.Username, expectedStatusCode) } if err == nil { err = checkUser(&user, &newUser) } return newUser, body, err } // UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode. 
func UpdateUser(user dataprovider.User, expectedStatusCode int, disconnect string) (dataprovider.User, []byte, error) { userAsJSON, _ := json.Marshal(user) return UpdateUserWithJSON(user, expectedStatusCode, disconnect, userAsJSON) } // RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode. func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(userPath, url.PathEscape(user.Username)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetUserByUsername gets a user by username and checks the received HTTP Status code against expectedStatusCode. func GetUserByUsername(username string, expectedStatusCode int) (dataprovider.User, []byte, error) { var user dataprovider.User var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(userPath, url.PathEscape(username)), nil, "", getDefaultToken()) if err != nil { return user, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &user) } else { body, _ = getResponseBody(resp) } return user, body, err } // GetUsers returns a list of users and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetUsers(limit, offset int64, expectedStatusCode int) ([]dataprovider.User, []byte, error) { var users []dataprovider.User var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(userPath), limit, offset) if err != nil { return users, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return users, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &users) } else { body, _ = getResponseBody(resp) } return users, body, err } // AddGroup adds a new group and checks the received HTTP Status code against expectedStatusCode. func AddGroup(group dataprovider.Group, expectedStatusCode int) (dataprovider.Group, []byte, error) { var newGroup dataprovider.Group var body []byte asJSON, _ := json.Marshal(group) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(groupPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newGroup, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newGroup, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newGroup) } else { body, _ = getResponseBody(resp) } if err == nil { group.UserSettings.Filters.TLSCerts = nil err = checkGroup(group, newGroup) } return newGroup, body, err } // UpdateGroup updates an existing group and checks the received HTTP Status code against expectedStatusCode func UpdateGroup(group dataprovider.Group, expectedStatusCode int) (dataprovider.Group, []byte, error) { var newGroup dataprovider.Group var body []byte asJSON, _ := json.Marshal(group) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(groupPath, url.PathEscape(group.Name)), bytes.NewBuffer(asJSON), "application/json", 
getDefaultToken()) if err != nil { return newGroup, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newGroup, body, err } if err == nil { newGroup, body, err = GetGroupByName(group.Name, expectedStatusCode) } if err == nil { err = checkGroup(group, newGroup) } return newGroup, body, err } // RemoveGroup removes an existing group and checks the received HTTP Status code against expectedStatusCode. func RemoveGroup(group dataprovider.Group, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(groupPath, url.PathEscape(group.Name)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetGroupByName gets a group by name and checks the received HTTP Status code against expectedStatusCode. func GetGroupByName(name string, expectedStatusCode int) (dataprovider.Group, []byte, error) { var group dataprovider.Group var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(groupPath, url.PathEscape(name)), nil, "", getDefaultToken()) if err != nil { return group, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &group) } else { body, _ = getResponseBody(resp) } return group, body, err } // GetGroups returns a list of groups and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetGroups(limit, offset int64, expectedStatusCode int) ([]dataprovider.Group, []byte, error) { var groups []dataprovider.Group var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(groupPath), limit, offset) if err != nil { return groups, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return groups, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &groups) } else { body, _ = getResponseBody(resp) } return groups, body, err } // AddRole adds a new role and checks the received HTTP Status code against expectedStatusCode. func AddRole(role dataprovider.Role, expectedStatusCode int) (dataprovider.Role, []byte, error) { var newRole dataprovider.Role var body []byte asJSON, _ := json.Marshal(role) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(rolesPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newRole, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newRole, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newRole) } else { body, _ = getResponseBody(resp) } if err == nil { err = checkRole(role, newRole) } return newRole, body, err } // UpdateRole updates an existing role and checks the received HTTP Status code against expectedStatusCode func UpdateRole(role dataprovider.Role, expectedStatusCode int) (dataprovider.Role, []byte, error) { var newRole dataprovider.Role var body []byte asJSON, _ := json.Marshal(role) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(rolesPath, url.PathEscape(role.Name)), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newRole, body, err } defer 
resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newRole, body, err } if err == nil { newRole, body, err = GetRoleByName(role.Name, expectedStatusCode) } if err == nil { err = checkRole(role, newRole) } return newRole, body, err } // RemoveRole removes an existing role and checks the received HTTP Status code against expectedStatusCode. func RemoveRole(role dataprovider.Role, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(rolesPath, url.PathEscape(role.Name)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetRoleByName gets a role by name and checks the received HTTP Status code against expectedStatusCode. func GetRoleByName(name string, expectedStatusCode int) (dataprovider.Role, []byte, error) { var role dataprovider.Role var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(rolesPath, url.PathEscape(name)), nil, "", getDefaultToken()) if err != nil { return role, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &role) } else { body, _ = getResponseBody(resp) } return role, body, err } // GetRoles returns a list of roles and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetRoles(limit, offset int64, expectedStatusCode int) ([]dataprovider.Role, []byte, error) { var roles []dataprovider.Role var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(rolesPath), limit, offset) if err != nil { return roles, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return roles, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &roles) } else { body, _ = getResponseBody(resp) } return roles, body, err } // AddIPListEntry adds a new IP list entry and checks the received HTTP Status code against expectedStatusCode. func AddIPListEntry(entry dataprovider.IPListEntry, expectedStatusCode int) (dataprovider.IPListEntry, []byte, error) { var newEntry dataprovider.IPListEntry var body []byte asJSON, _ := json.Marshal(entry) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(ipListsPath, strconv.Itoa(int(entry.Type))), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newEntry, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newEntry, body, err } if err == nil { newEntry, body, err = GetIPListEntry(entry.IPOrNet, entry.Type, http.StatusOK) } if err == nil { err = checkIPListEntry(entry, newEntry) } return newEntry, body, err } // UpdateIPListEntry updates an existing IP list entry and checks the received HTTP Status code against expectedStatusCode func UpdateIPListEntry(entry dataprovider.IPListEntry, expectedStatusCode int) (dataprovider.IPListEntry, []byte, error) { var newEntry dataprovider.IPListEntry var body []byte asJSON, _ := json.Marshal(entry) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(ipListsPath, fmt.Sprintf("%d", 
entry.Type), url.PathEscape(entry.IPOrNet)), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newEntry, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newEntry, body, err } if err == nil { newEntry, body, err = GetIPListEntry(entry.IPOrNet, entry.Type, http.StatusOK) } if err == nil { err = checkIPListEntry(entry, newEntry) } return newEntry, body, err } // RemoveIPListEntry removes an existing IP list entry and checks the received HTTP Status code against expectedStatusCode. func RemoveIPListEntry(entry dataprovider.IPListEntry, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(ipListsPath, fmt.Sprintf("%d", entry.Type), url.PathEscape(entry.IPOrNet)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetIPListEntry returns an IP list entry matching the specified parameters, if exists, // and checks the received HTTP Status code against expectedStatusCode. 
func GetIPListEntry(ipOrNet string, listType dataprovider.IPListType, expectedStatusCode int, ) (dataprovider.IPListEntry, []byte, error) { var entry dataprovider.IPListEntry var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(ipListsPath, fmt.Sprintf("%d", listType), url.PathEscape(ipOrNet)), nil, "", getDefaultToken()) if err != nil { return entry, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &entry) } else { body, _ = getResponseBody(resp) } return entry, body, err } // GetIPListEntries returns a list of IP list entries and checks the received HTTP Status code against expectedStatusCode. func GetIPListEntries(listType dataprovider.IPListType, filter, from, order string, limit int64, expectedStatusCode int, ) ([]dataprovider.IPListEntry, []byte, error) { var entries []dataprovider.IPListEntry var body []byte url, err := url.Parse(buildURLRelativeToBase(ipListsPath, strconv.Itoa(int(listType)))) if err != nil { return entries, body, err } q := url.Query() q.Add("filter", filter) q.Add("from", from) q.Add("order", order) if limit > 0 { q.Add("limit", strconv.FormatInt(limit, 10)) } url.RawQuery = q.Encode() resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return entries, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &entries) } else { body, _ = getResponseBody(resp) } return entries, body, err } // AddAdmin adds a new admin and checks the received HTTP Status code against expectedStatusCode. 
func AddAdmin(admin dataprovider.Admin, expectedStatusCode int) (dataprovider.Admin, []byte, error) { var newAdmin dataprovider.Admin var body []byte asJSON, _ := json.Marshal(admin) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(adminPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newAdmin, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newAdmin, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newAdmin) } else { body, _ = getResponseBody(resp) } if err == nil { err = checkAdmin(&admin, &newAdmin) } return newAdmin, body, err } // UpdateAdmin updates an existing admin and checks the received HTTP Status code against expectedStatusCode func UpdateAdmin(admin dataprovider.Admin, expectedStatusCode int) (dataprovider.Admin, []byte, error) { var newAdmin dataprovider.Admin var body []byte asJSON, _ := json.Marshal(admin) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(adminPath, url.PathEscape(admin.Username)), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newAdmin, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newAdmin, body, err } if err == nil { newAdmin, body, err = GetAdminByUsername(admin.Username, expectedStatusCode) } if err == nil { err = checkAdmin(&admin, &newAdmin) } return newAdmin, body, err } // RemoveAdmin removes an existing admin and checks the received HTTP Status code against expectedStatusCode. 
func RemoveAdmin(admin dataprovider.Admin, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(adminPath, url.PathEscape(admin.Username)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetAdminByUsername gets an admin by username and checks the received HTTP Status code against expectedStatusCode. func GetAdminByUsername(username string, expectedStatusCode int) (dataprovider.Admin, []byte, error) { var admin dataprovider.Admin var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(adminPath, url.PathEscape(username)), nil, "", getDefaultToken()) if err != nil { return admin, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &admin) } else { body, _ = getResponseBody(resp) } return admin, body, err } // GetAdmins returns a list of admins and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetAdmins(limit, offset int64, expectedStatusCode int) ([]dataprovider.Admin, []byte, error) { var admins []dataprovider.Admin var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(adminPath), limit, offset) if err != nil { return admins, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return admins, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &admins) } else { body, _ = getResponseBody(resp) } return admins, body, err } // ChangeAdminPassword changes the password for an existing admin func ChangeAdminPassword(currentPassword, newPassword string, expectedStatusCode int) ([]byte, error) { var body []byte pwdChange := make(map[string]string) pwdChange["current_password"] = currentPassword pwdChange["new_password"] = newPassword asJSON, _ := json.Marshal(&pwdChange) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(adminPwdPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) body, _ = getResponseBody(resp) return body, err } // GetAPIKeys returns a list of API keys and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetAPIKeys(limit, offset int64, expectedStatusCode int) ([]dataprovider.APIKey, []byte, error) { var apiKeys []dataprovider.APIKey var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(apiKeysPath), limit, offset) if err != nil { return apiKeys, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return apiKeys, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &apiKeys) } else { body, _ = getResponseBody(resp) } return apiKeys, body, err } // AddAPIKey adds a new API key and checks the received HTTP Status code against expectedStatusCode. func AddAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) (dataprovider.APIKey, []byte, error) { var newAPIKey dataprovider.APIKey var body []byte asJSON, _ := json.Marshal(apiKey) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(apiKeysPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newAPIKey, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newAPIKey, body, err } if err != nil { body, _ = getResponseBody(resp) return newAPIKey, body, err } response := make(map[string]string) err = render.DecodeJSON(resp.Body, &response) if err == nil { newAPIKey, body, err = GetAPIKeyByID(resp.Header.Get("X-Object-ID"), http.StatusOK) } if err == nil { err = checkAPIKey(&apiKey, &newAPIKey) } newAPIKey.Key = response["key"] return newAPIKey, body, err } // UpdateAPIKey updates an existing API key and checks the received HTTP Status code against expectedStatusCode func UpdateAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) (dataprovider.APIKey, []byte, error) { var newAPIKey dataprovider.APIKey var body []byte 
asJSON, _ := json.Marshal(apiKey) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(apiKeysPath, url.PathEscape(apiKey.KeyID)), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newAPIKey, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newAPIKey, body, err } if err == nil { newAPIKey, body, err = GetAPIKeyByID(apiKey.KeyID, expectedStatusCode) } if err == nil { err = checkAPIKey(&apiKey, &newAPIKey) } return newAPIKey, body, err } // RemoveAPIKey removes an existing API key and checks the received HTTP Status code against expectedStatusCode. func RemoveAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(apiKeysPath, url.PathEscape(apiKey.KeyID)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetAPIKeyByID gets a API key by ID and checks the received HTTP Status code against expectedStatusCode. 
func GetAPIKeyByID(keyID string, expectedStatusCode int) (dataprovider.APIKey, []byte, error) { var apiKey dataprovider.APIKey var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(apiKeysPath, url.PathEscape(keyID)), nil, "", getDefaultToken()) if err != nil { return apiKey, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &apiKey) } else { body, _ = getResponseBody(resp) } return apiKey, body, err } // AddEventAction adds a new event action func AddEventAction(action dataprovider.BaseEventAction, expectedStatusCode int) (dataprovider.BaseEventAction, []byte, error) { var newAction dataprovider.BaseEventAction var body []byte asJSON, _ := json.Marshal(action) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(eventActionsPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newAction, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newAction, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newAction) } else { body, _ = getResponseBody(resp) } if err == nil { err = checkEventAction(action, newAction) } return newAction, body, err } // UpdateEventAction updates an existing event action func UpdateEventAction(action dataprovider.BaseEventAction, expectedStatusCode int) (dataprovider.BaseEventAction, []byte, error) { var newAction dataprovider.BaseEventAction var body []byte asJSON, _ := json.Marshal(action) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(eventActionsPath, url.PathEscape(action.Name)), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newAction, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = 
checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newAction, body, err } if err == nil { newAction, body, err = GetEventActionByName(action.Name, expectedStatusCode) } if err == nil { err = checkEventAction(action, newAction) } return newAction, body, err } // RemoveEventAction removes an existing action and checks the received HTTP Status code against expectedStatusCode. func RemoveEventAction(action dataprovider.BaseEventAction, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(eventActionsPath, url.PathEscape(action.Name)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetEventActionByName gets an event action by name and checks the received HTTP Status code against expectedStatusCode. func GetEventActionByName(name string, expectedStatusCode int) (dataprovider.BaseEventAction, []byte, error) { var action dataprovider.BaseEventAction var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(eventActionsPath, url.PathEscape(name)), nil, "", getDefaultToken()) if err != nil { return action, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &action) } else { body, _ = getResponseBody(resp) } return action, body, err } // GetEventActions returns a list of event actions and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetEventActions(limit, offset int64, expectedStatusCode int) ([]dataprovider.BaseEventAction, []byte, error) { var actions []dataprovider.BaseEventAction var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(eventActionsPath), limit, offset) if err != nil { return actions, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return actions, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &actions) } else { body, _ = getResponseBody(resp) } return actions, body, err } // AddEventRule adds a new event rule func AddEventRule(rule dataprovider.EventRule, expectedStatusCode int) (dataprovider.EventRule, []byte, error) { var newRule dataprovider.EventRule var body []byte asJSON, _ := json.Marshal(rule) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(eventRulesPath), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newRule, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newRule, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newRule) } else { body, _ = getResponseBody(resp) } if err == nil { err = checkEventRule(rule, newRule) } return newRule, body, err } // UpdateEventRule updates an existing event rule func UpdateEventRule(rule dataprovider.EventRule, expectedStatusCode int) (dataprovider.EventRule, []byte, error) { var newRule dataprovider.EventRule var body []byte asJSON, _ := json.Marshal(rule) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(eventRulesPath, url.PathEscape(rule.Name)), bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) if err != nil { return newRule, body, err } defer resp.Body.Close() body, 
_ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return newRule, body, err } if err == nil { newRule, body, err = GetEventRuleByName(rule.Name, expectedStatusCode) } if err == nil { err = checkEventRule(rule, newRule) } return newRule, body, err } // RemoveEventRule removes an existing rule and checks the received HTTP Status code against expectedStatusCode. func RemoveEventRule(rule dataprovider.EventRule, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(eventRulesPath, url.PathEscape(rule.Name)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetEventRuleByName gets an event rule by name and checks the received HTTP Status code against expectedStatusCode. func GetEventRuleByName(name string, expectedStatusCode int) (dataprovider.EventRule, []byte, error) { var rule dataprovider.EventRule var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(eventRulesPath, url.PathEscape(name)), nil, "", getDefaultToken()) if err != nil { return rule, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &rule) } else { body, _ = getResponseBody(resp) } return rule, body, err } // GetEventRules returns a list of event rules and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. 
func GetEventRules(limit, offset int64, expectedStatusCode int) ([]dataprovider.EventRule, []byte, error) { var rules []dataprovider.EventRule var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(eventRulesPath), limit, offset) if err != nil { return rules, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return rules, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &rules) } else { body, _ = getResponseBody(resp) } return rules, body, err } // RunOnDemandRule executes the specified on demand rule func RunOnDemandRule(name string, expectedStatusCode int) ([]byte, error) { resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(eventRulesPath, "run", url.PathEscape(name)), nil, "application/json", getDefaultToken()) if err != nil { return nil, err } defer resp.Body.Close() b, err := getResponseBody(resp) if err != nil { return b, err } if err := checkResponse(resp.StatusCode, expectedStatusCode); err != nil { return b, err } return b, nil } // GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode. 
func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) { var quotaScans []common.ActiveQuotaScan var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "", getDefaultToken()) if err != nil { return quotaScans, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, "aScans) } else { body, _ = getResponseBody(resp) } return quotaScans, body, err } // StartQuotaScan starts a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode. func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotasBasePath, "users", user.Username, "scan"), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // UpdateQuotaUsage updates the user used quota limits and checks the received // HTTP Status code against expectedStatusCode. func UpdateQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { var body []byte userAsJSON, _ := json.Marshal(user) url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "users", user.Username, "usage"), mode) if err != nil { return body, err } resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // UpdateTransferQuotaUsage updates the user used transfer quota limits and checks the received // HTTP Status code against expectedStatusCode. 
func UpdateTransferQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { var body []byte userAsJSON, _ := json.Marshal(user) url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "users", user.Username, "transfer-usage"), mode) if err != nil { return body, err } resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetRetentionChecks returns the active retention checks func GetRetentionChecks(expectedStatusCode int) ([]common.ActiveRetentionChecks, []byte, error) { var checks []common.ActiveRetentionChecks var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(retentionChecksPath), nil, "", getDefaultToken()) if err != nil { return checks, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &checks) } else { body, _ = getResponseBody(resp) } return checks, body, err } // GetConnections returns status and stats for active SFTP/SCP connections func GetConnections(expectedStatusCode int) ([]common.ConnectionStatus, []byte, error) { var connections []common.ConnectionStatus var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(activeConnectionsPath), nil, "", getDefaultToken()) if err != nil { return connections, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &connections) } else { body, _ = getResponseBody(resp) } return connections, body, err } // CloseConnection closes an active connection identified by connectionID func CloseConnection(connectionID string, 
expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) body, _ = getResponseBody(resp) return body, err } // AddFolder adds a new folder and checks the received HTTP Status code against expectedStatusCode func AddFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { var newFolder vfs.BaseVirtualFolder var body []byte folderAsJSON, _ := json.Marshal(folder) resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(folderPath), bytes.NewBuffer(folderAsJSON), "application/json", getDefaultToken()) if err != nil { return newFolder, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusCreated { body, _ = getResponseBody(resp) return newFolder, body, err } if err == nil { err = render.DecodeJSON(resp.Body, &newFolder) } else { body, _ = getResponseBody(resp) } if err == nil { err = checkFolder(&folder, &newFolder) } return newFolder, body, err } // UpdateFolder updates an existing folder and checks the received HTTP Status code against expectedStatusCode. 
func UpdateFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { var updatedFolder vfs.BaseVirtualFolder var body []byte folderAsJSON, _ := json.Marshal(folder) resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(folderPath, url.PathEscape(folder.Name)), bytes.NewBuffer(folderAsJSON), "application/json", getDefaultToken()) if err != nil { return updatedFolder, body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) err = checkResponse(resp.StatusCode, expectedStatusCode) if expectedStatusCode != http.StatusOK { return updatedFolder, body, err } if err == nil { updatedFolder, body, err = GetFolderByName(folder.Name, expectedStatusCode) } if err == nil { err = checkFolder(&folder, &updatedFolder) } return updatedFolder, body, err } // RemoveFolder removes an existing user and checks the received HTTP Status code against expectedStatusCode. func RemoveFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(folderPath, url.PathEscape(folder.Name)), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetFolderByName gets a folder by name and checks the received HTTP Status code against expectedStatusCode. 
func GetFolderByName(name string, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { var folder vfs.BaseVirtualFolder var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(folderPath, url.PathEscape(name)), nil, "", getDefaultToken()) if err != nil { return folder, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &folder) } else { body, _ = getResponseBody(resp) } return folder, body, err } // GetFolders returns a list of folders and checks the received HTTP Status code against expectedStatusCode. // The number of results can be limited specifying a limit. // Some results can be skipped specifying an offset. // The results can be filtered specifying a folder path, the folder path filter is an exact match func GetFolders(limit int64, offset int64, expectedStatusCode int) ([]vfs.BaseVirtualFolder, []byte, error) { var folders []vfs.BaseVirtualFolder var body []byte url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(folderPath), limit, offset) if err != nil { return folders, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return folders, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &folders) } else { body, _ = getResponseBody(resp) } return folders, body, err } // GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode. 
func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) { var quotaScans []common.ActiveVirtualFolderQuotaScan var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "", getDefaultToken()) if err != nil { return quotaScans, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, "aScans) } else { body, _ = getResponseBody(resp) } return quotaScans, body, err } // StartFolderQuotaScan start a new quota scan for the given folder and checks the received HTTP Status code against expectedStatusCode. func StartFolderQuotaScan(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) { var body []byte resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotasBasePath, "folders", folder.Name, "scan"), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // UpdateFolderQuotaUsage updates the folder used quota limits and checks the received HTTP Status code against expectedStatusCode. 
func UpdateFolderQuotaUsage(folder vfs.BaseVirtualFolder, mode string, expectedStatusCode int) ([]byte, error) { var body []byte folderAsJSON, _ := json.Marshal(folder) url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "folders", folder.Name, "usage"), mode) if err != nil { return body, err } resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(folderAsJSON), "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // GetVersion returns version details func GetVersion(expectedStatusCode int) (version.Info, []byte, error) { var appVersion version.Info var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(versionPath), nil, "", getDefaultToken()) if err != nil { return appVersion, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &appVersion) } else { body, _ = getResponseBody(resp) } return appVersion, body, err } // GetStatus returns the server status func GetStatus(expectedStatusCode int) (httpd.ServicesStatus, []byte, error) { var response httpd.ServicesStatus var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(serverStatusPath), nil, "", getDefaultToken()) if err != nil { return response, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && (expectedStatusCode == http.StatusOK) { err = render.DecodeJSON(resp.Body, &response) } else { body, _ = getResponseBody(resp) } return response, body, err } // GetDefenderHosts returns hosts that are banned or for which some violations have been detected func GetDefenderHosts(expectedStatusCode int) ([]dataprovider.DefenderEntry, []byte, error) { var response []dataprovider.DefenderEntry var body []byte url, err := 
url.Parse(buildURLRelativeToBase(defenderHosts)) if err != nil { return response, body, err } resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return response, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &response) } else { body, _ = getResponseBody(resp) } return response, body, err } // GetDefenderHostByIP returns the host with the given IP, if it exists func GetDefenderHostByIP(ip string, expectedStatusCode int) (dataprovider.DefenderEntry, []byte, error) { var host dataprovider.DefenderEntry var body []byte id := hex.EncodeToString([]byte(ip)) resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(defenderHosts, id), nil, "", getDefaultToken()) if err != nil { return host, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &host) } else { body, _ = getResponseBody(resp) } return host, body, err } // RemoveDefenderHostByIP removes the host with the given IP from the defender list func RemoveDefenderHostByIP(ip string, expectedStatusCode int) ([]byte, error) { var body []byte id := hex.EncodeToString([]byte(ip)) resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(defenderHosts, id), nil, "", getDefaultToken()) if err != nil { return body, err } defer resp.Body.Close() body, _ = getResponseBody(resp) return body, checkResponse(resp.StatusCode, expectedStatusCode) } // Dumpdata requests a backup to outputFile. 
// outputFile is relative to the configured backups_path func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int, scopes ...string) (map[string]any, []byte, error) { var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(dumpDataPath)) if err != nil { return response, body, err } q := url.Query() if outputData != "" { q.Add("output-data", outputData) } if outputFile != "" { q.Add("output-file", outputFile) } if indent != "" { q.Add("indent", indent) } if len(scopes) > 0 { q.Add("scopes", strings.Join(scopes, ",")) } url.RawQuery = q.Encode() resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return response, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &response) } else { body, _ = getResponseBody(resp) } return response, body, err } // Loaddata restores a backup. 
func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) { var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) if err != nil { return response, body, err } q := url.Query() q.Add("input-file", inputFile) if scanQuota != "" { q.Add("scan-quota", scanQuota) } if mode != "" { q.Add("mode", mode) } url.RawQuery = q.Encode() resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) if err != nil { return response, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &response) } else { body, _ = getResponseBody(resp) } return response, body, err } // LoaddataFromPostBody restores a backup func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) { var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) if err != nil { return response, body, err } q := url.Query() if scanQuota != "" { q.Add("scan-quota", scanQuota) } if mode != "" { q.Add("mode", mode) } url.RawQuery = q.Encode() resp, err := sendHTTPRequest(http.MethodPost, url.String(), bytes.NewReader(data), "", getDefaultToken()) if err != nil { return response, body, err } defer resp.Body.Close() err = checkResponse(resp.StatusCode, expectedStatusCode) if err == nil && expectedStatusCode == http.StatusOK { err = render.DecodeJSON(resp.Body, &response) } else { body, _ = getResponseBody(resp) } return response, body, err } func checkResponse(actual int, expected int) error { if expected != actual { return fmt.Errorf("wrong status code: got %v want %v", actual, expected) } return nil } func getResponseBody(resp *http.Response) ([]byte, error) { return io.ReadAll(resp.Body) } func checkEventAction(expected, actual dataprovider.BaseEventAction) error { if 
expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual action ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("action ID mismatch") } } if dataprovider.ConvertName(expected.Name) != actual.Name { return errors.New("name mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if expected.Type != actual.Type { return errors.New("type mismatch") } if expected.Options.PwdExpirationConfig.Threshold != actual.Options.PwdExpirationConfig.Threshold { return errors.New("password expiration threshold mismatch") } if expected.Options.UserInactivityConfig.DisableThreshold != actual.Options.UserInactivityConfig.DisableThreshold { return errors.New("user inactivity disable threshold mismatch") } if expected.Options.UserInactivityConfig.DeleteThreshold != actual.Options.UserInactivityConfig.DeleteThreshold { return errors.New("user inactivity delete threshold mismatch") } if err := compareEventActionIDPConfigFields(expected.Options.IDPConfig, actual.Options.IDPConfig); err != nil { return err } if err := compareEventActionCmdConfigFields(expected.Options.CmdConfig, actual.Options.CmdConfig); err != nil { return err } if err := compareEventActionEmailConfigFields(expected.Options.EmailConfig, actual.Options.EmailConfig); err != nil { return err } if err := compareEventActionDataRetentionFields(expected.Options.RetentionConfig, actual.Options.RetentionConfig); err != nil { return err } if err := compareEventActionFsConfigFields(expected.Options.FsConfig, actual.Options.FsConfig); err != nil { return err } return compareEventActionHTTPConfigFields(expected.Options.HTTPConfig, actual.Options.HTTPConfig) } func checkEventSchedules(expected, actual []dataprovider.Schedule) error { if len(expected) != len(actual) { return errors.New("schedules mismatch") } for _, ex := range expected { found := false for _, ac := range actual { if ac.DayOfMonth == ex.DayOfMonth && ac.DayOfWeek == ex.DayOfWeek && 
ac.Hours == ex.Hours && ac.Month == ex.Month { found = true break } } if !found { return errors.New("schedules content mismatch") } } return nil } func compareConditionPatternOptions(expected, actual []dataprovider.ConditionPattern) error { if len(expected) != len(actual) { return errors.New("condition pattern mismatch") } for _, ex := range expected { found := false for _, ac := range actual { if ac.Pattern == ex.Pattern && ac.InverseMatch == ex.InverseMatch { found = true break } } if !found { return errors.New("condition pattern content mismatch") } } return nil } func checkEventConditionOptions(expected, actual dataprovider.ConditionOptions) error { //nolint:gocyclo if err := compareConditionPatternOptions(expected.Names, actual.Names); err != nil { return errors.New("condition names mismatch") } if err := compareConditionPatternOptions(expected.GroupNames, actual.GroupNames); err != nil { return errors.New("condition group names mismatch") } if err := compareConditionPatternOptions(expected.RoleNames, actual.RoleNames); err != nil { return errors.New("condition role names mismatch") } if err := compareConditionPatternOptions(expected.FsPaths, actual.FsPaths); err != nil { return errors.New("condition fs_paths mismatch") } if len(expected.Protocols) != len(actual.Protocols) { return errors.New("condition protocols mismatch") } for _, v := range expected.Protocols { if !slices.Contains(actual.Protocols, v) { return errors.New("condition protocols content mismatch") } } if len(expected.EventStatuses) != len(actual.EventStatuses) { return errors.New("condition statuses mismatch") } for _, v := range expected.EventStatuses { if !slices.Contains(actual.EventStatuses, v) { return errors.New("condition statuses content mismatch") } } if len(expected.ProviderObjects) != len(actual.ProviderObjects) { return errors.New("condition provider objects mismatch") } for _, v := range expected.ProviderObjects { if !slices.Contains(actual.ProviderObjects, v) { return 
errors.New("condition provider objects content mismatch") } } if expected.MinFileSize != actual.MinFileSize { return errors.New("condition min file size mismatch") } if expected.MaxFileSize != actual.MaxFileSize { return errors.New("condition max file size mismatch") } return nil } func checkEventConditions(expected, actual dataprovider.EventConditions) error { if len(expected.FsEvents) != len(actual.FsEvents) { return errors.New("fs events mismatch") } for _, v := range expected.FsEvents { if !slices.Contains(actual.FsEvents, v) { return errors.New("fs events content mismatch") } } if len(expected.ProviderEvents) != len(actual.ProviderEvents) { return errors.New("provider events mismatch") } for _, v := range expected.ProviderEvents { if !slices.Contains(actual.ProviderEvents, v) { return errors.New("provider events content mismatch") } } if err := checkEventConditionOptions(expected.Options, actual.Options); err != nil { return err } if expected.IDPLoginEvent != actual.IDPLoginEvent { return errors.New("IDP login event mismatch") } return checkEventSchedules(expected.Schedules, actual.Schedules) } func checkEventRuleActions(expected, actual []dataprovider.EventAction) error { if len(expected) != len(actual) { return errors.New("actions mismatch") } for _, ex := range expected { found := false for _, ac := range actual { if ex.Name == ac.Name && ex.Order == ac.Order && ex.Options.ExecuteSync == ac.Options.ExecuteSync && ex.Options.IsFailureAction == ac.Options.IsFailureAction && ex.Options.StopOnFailure == ac.Options.StopOnFailure { found = true break } } if !found { return errors.New("actions contents mismatch") } } return nil } func checkEventRule(expected, actual dataprovider.EventRule) error { if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual group ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("group ID mismatch") } } if dataprovider.ConvertName(expected.Name) != actual.Name { return errors.New("name mismatch") } 
if expected.Status != actual.Status { return errors.New("status mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if actual.CreatedAt == 0 { return errors.New("created_at unset") } if actual.UpdatedAt == 0 { return errors.New("updated_at unset") } if expected.Trigger != actual.Trigger { return errors.New("trigger mismatch") } if err := checkEventConditions(expected.Conditions, actual.Conditions); err != nil { return err } return checkEventRuleActions(expected.Actions, actual.Actions) } func checkIPListEntry(expected, actual dataprovider.IPListEntry) error { if expected.IPOrNet != actual.IPOrNet { return errors.New("ipornet mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if expected.Type != actual.Type { return errors.New("type mismatch") } if expected.Mode != actual.Mode { return errors.New("mode mismatch") } if expected.Protocols != actual.Protocols { return errors.New("protocols mismatch") } if actual.CreatedAt == 0 { return errors.New("created_at unset") } if actual.UpdatedAt == 0 { return errors.New("updated_at unset") } return nil } func checkRole(expected, actual dataprovider.Role) error { if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual role ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("role ID mismatch") } } if dataprovider.ConvertName(expected.Name) != actual.Name { return errors.New("name mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if actual.CreatedAt == 0 { return errors.New("created_at unset") } if actual.UpdatedAt == 0 { return errors.New("updated_at unset") } return nil } func checkGroup(expected, actual dataprovider.Group) error { if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual group ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("group ID mismatch") } } if 
dataprovider.ConvertName(expected.Name) != actual.Name { return errors.New("name mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if actual.CreatedAt == 0 { return errors.New("created_at unset") } if actual.UpdatedAt == 0 { return errors.New("updated_at unset") } if err := compareEqualGroupSettingsFields(expected.UserSettings.BaseGroupUserSettings, actual.UserSettings.BaseGroupUserSettings); err != nil { return err } if err := compareVirtualFolders(expected.VirtualFolders, actual.VirtualFolders); err != nil { return err } if err := compareUserFilters(expected.UserSettings.Filters, actual.UserSettings.Filters); err != nil { return err } return compareFsConfig(&expected.UserSettings.FsConfig, &actual.UserSettings.FsConfig) } func checkFolder(expected *vfs.BaseVirtualFolder, actual *vfs.BaseVirtualFolder) error { if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual folder ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("folder ID mismatch") } } if dataprovider.ConvertName(expected.Name) != actual.Name { return errors.New("name mismatch") } if expected.MappedPath != actual.MappedPath { return errors.New("mapped path mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } return compareFsConfig(&expected.FsConfig, &actual.FsConfig) } func checkAPIKey(expected, actual *dataprovider.APIKey) error { if actual.Key != "" { return errors.New("key must not be visible") } if actual.KeyID == "" { return errors.New("actual key_id cannot be empty") } if expected.Name != actual.Name { return errors.New("name mismatch") } if expected.Scope != actual.Scope { return errors.New("scope mismatch") } if actual.CreatedAt == 0 { return errors.New("created_at cannot be 0") } if actual.UpdatedAt == 0 { return errors.New("updated_at cannot be 0") } if expected.ExpiresAt != actual.ExpiresAt { return errors.New("expires_at mismatch") } if 
expected.Description != actual.Description { return errors.New("description mismatch") } if expected.User != actual.User { return errors.New("user mismatch") } if expected.Admin != actual.Admin { return errors.New("admin mismatch") } return nil } func checkAdmin(expected, actual *dataprovider.Admin) error { if actual.Password != "" { return errors.New("admin password must not be visible") } if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual admin ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("admin ID mismatch") } } if expected.CreatedAt > 0 { if expected.CreatedAt != actual.CreatedAt { return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) } } if err := compareAdminEqualFields(expected, actual); err != nil { return err } if len(expected.Permissions) != len(actual.Permissions) { return errors.New("permissions mismatch") } for _, p := range expected.Permissions { if !slices.Contains(actual.Permissions, p) { return errors.New("permissions content mismatch") } } if err := compareAdminFilters(expected.Filters, actual.Filters); err != nil { return err } return compareAdminGroups(expected, actual) } func compareAdminFilters(expected, actual dataprovider.AdminFilters) error { if expected.AllowAPIKeyAuth != actual.AllowAPIKeyAuth { return errors.New("allow_api_key_auth mismatch") } if len(expected.AllowList) != len(actual.AllowList) { return errors.New("allow list mismatch") } for _, v := range expected.AllowList { if !slices.Contains(actual.AllowList, v) { return errors.New("allow list content mismatch") } } if expected.Preferences.HideUserPageSections != actual.Preferences.HideUserPageSections { return errors.New("hide user page sections mismatch") } if expected.Preferences.DefaultUsersExpiration != actual.Preferences.DefaultUsersExpiration { return errors.New("default users expiration mismatch") } if expected.RequirePasswordChange != actual.RequirePasswordChange { return errors.New("require 
password change mismatch") } if expected.RequireTwoFactor != actual.RequireTwoFactor { return errors.New("require two factor mismatch") } return nil } func compareAdminEqualFields(expected *dataprovider.Admin, actual *dataprovider.Admin) error { if dataprovider.ConvertName(expected.Username) != actual.Username { return errors.New("sername mismatch") } if expected.Email != actual.Email { return errors.New("email mismatch") } if expected.Status != actual.Status { return errors.New("status mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if expected.AdditionalInfo != actual.AdditionalInfo { return errors.New("additional info mismatch") } if expected.Role != actual.Role { return errors.New("role mismatch") } return nil } func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { if actual.Password != "" { return errors.New("user password must not be visible") } if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual user ID must be > 0") } } else { if actual.ID != expected.ID { return errors.New("user ID mismatch") } } if expected.CreatedAt > 0 { if expected.CreatedAt != actual.CreatedAt { return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) } } if expected.Email != actual.Email { return errors.New("email mismatch") } if !slices.Equal(expected.Filters.AdditionalEmails, actual.Filters.AdditionalEmails) { return errors.New("additional emails mismatch") } if expected.Filters.RequirePasswordChange != actual.Filters.RequirePasswordChange { return errors.New("require_password_change mismatch") } if err := compareUserPermissions(expected.Permissions, actual.Permissions); err != nil { return err } if err := compareUserFilters(expected.Filters.BaseUserFilters, actual.Filters.BaseUserFilters); err != nil { return err } if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { return err } if err := compareUserGroups(expected, actual); err 
!= nil { return err } if err := compareVirtualFolders(expected.VirtualFolders, actual.VirtualFolders); err != nil { return err } return compareEqualsUserFields(expected, actual) } func compareUserPermissions(expected map[string][]string, actual map[string][]string) error { if len(expected) != len(actual) { return errors.New("permissions mismatch") } for dir, perms := range expected { if actualPerms, ok := actual[dir]; ok { for _, v := range actualPerms { if !slices.Contains(perms, v) { return errors.New("permissions contents mismatch") } } } else { return errors.New("permissions directories mismatch") } } return nil } func compareAdminGroups(expected *dataprovider.Admin, actual *dataprovider.Admin) error { if len(actual.Groups) != len(expected.Groups) { return errors.New("groups len mismatch") } for _, g := range actual.Groups { found := false for _, g1 := range expected.Groups { if g1.Name == g.Name { found = true if g1.Options.AddToUsersAs != g.Options.AddToUsersAs { return fmt.Errorf("add to users as field mismatch for group %s", g.Name) } } } if !found { return errors.New("groups mismatch") } } return nil } func compareUserGroups(expected *dataprovider.User, actual *dataprovider.User) error { if len(actual.Groups) != len(expected.Groups) { return errors.New("groups len mismatch") } for _, g := range actual.Groups { found := false for _, g1 := range expected.Groups { if g1.Name == g.Name { found = true if g1.Type != g.Type { return fmt.Errorf("type mismatch for group %s", g.Name) } } } if !found { return errors.New("groups mismatch") } } return nil } func compareVirtualFolders(expected []vfs.VirtualFolder, actual []vfs.VirtualFolder) error { if len(actual) != len(expected) { return errors.New("virtual folders len mismatch") } for _, v := range actual { found := false for _, v1 := range expected { if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) { if dataprovider.ConvertName(v1.Name) != v.Name { return errors.New("virtual folder name mismatch") } if 
v.QuotaSize != v1.QuotaSize { return errors.New("vfolder quota size mismatch") } if (v.QuotaFiles) != (v1.QuotaFiles) { return errors.New("vfolder quota files mismatch") } found = true break } } if !found { return errors.New("virtual folders mismatch") } } return nil } func compareFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { if expected.Provider != actual.Provider { return errors.New("fs provider mismatch") } if expected.OSConfig.ReadBufferSize != actual.OSConfig.ReadBufferSize { return fmt.Errorf("read buffer size mismatch") } if expected.OSConfig.WriteBufferSize != actual.OSConfig.WriteBufferSize { return fmt.Errorf("write buffer size mismatch") } if err := compareS3Config(expected, actual); err != nil { return err } if err := compareGCSConfig(expected, actual); err != nil { return err } if err := compareAzBlobConfig(expected, actual); err != nil { return err } if err := checkEncryptedSecret(expected.CryptConfig.Passphrase, actual.CryptConfig.Passphrase); err != nil { return err } if expected.CryptConfig.ReadBufferSize != actual.CryptConfig.ReadBufferSize { return fmt.Errorf("crypt read buffer size mismatch") } if expected.CryptConfig.WriteBufferSize != actual.CryptConfig.WriteBufferSize { return fmt.Errorf("crypt write buffer size mismatch") } if err := compareSFTPFsConfig(expected, actual); err != nil { return err } return compareHTTPFsConfig(expected, actual) } func compareS3Config(expected *vfs.Filesystem, actual *vfs.Filesystem) error { //nolint:gocyclo if expected.S3Config.Bucket != actual.S3Config.Bucket { return errors.New("fs S3 bucket mismatch") } if expected.S3Config.Region != actual.S3Config.Region { return errors.New("fs S3 region mismatch") } if expected.S3Config.AccessKey != actual.S3Config.AccessKey { return errors.New("fs S3 access key mismatch") } if expected.S3Config.RoleARN != actual.S3Config.RoleARN { return errors.New("fs S3 role ARN mismatch") } if err := checkEncryptedSecret(expected.S3Config.AccessSecret, 
actual.S3Config.AccessSecret); err != nil { return fmt.Errorf("fs S3 access secret mismatch: %v", err) } if err := checkEncryptedSecret(expected.S3Config.SSECustomerKey, actual.S3Config.SSECustomerKey); err != nil { return fmt.Errorf("fs S3 SSE customer key mismatch: %v", err) } if expected.S3Config.Endpoint != actual.S3Config.Endpoint { return errors.New("fs S3 endpoint mismatch") } if expected.S3Config.StorageClass != actual.S3Config.StorageClass { return errors.New("fs S3 storage class mismatch") } if expected.S3Config.ACL != actual.S3Config.ACL { return errors.New("fs S3 ACL mismatch") } if expected.S3Config.UploadPartSize != actual.S3Config.UploadPartSize { return errors.New("fs S3 upload part size mismatch") } if expected.S3Config.UploadConcurrency != actual.S3Config.UploadConcurrency { return errors.New("fs S3 upload concurrency mismatch") } if expected.S3Config.DownloadPartSize != actual.S3Config.DownloadPartSize { return errors.New("fs S3 download part size mismatch") } if expected.S3Config.DownloadConcurrency != actual.S3Config.DownloadConcurrency { return errors.New("fs S3 download concurrency mismatch") } if expected.S3Config.ForcePathStyle != actual.S3Config.ForcePathStyle { return errors.New("fs S3 force path style mismatch") } if expected.S3Config.SkipTLSVerify != actual.S3Config.SkipTLSVerify { return errors.New("fs S3 skip TLS verify mismatch") } if expected.S3Config.DownloadPartMaxTime != actual.S3Config.DownloadPartMaxTime { return errors.New("fs S3 download part max time mismatch") } if expected.S3Config.UploadPartMaxTime != actual.S3Config.UploadPartMaxTime { return errors.New("fs S3 upload part max time mismatch") } if expected.S3Config.KeyPrefix != actual.S3Config.KeyPrefix && expected.S3Config.KeyPrefix+"/" != actual.S3Config.KeyPrefix { return errors.New("fs S3 key prefix mismatch") } return nil } func compareGCSConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { if expected.GCSConfig.Bucket != actual.GCSConfig.Bucket { return 
errors.New("GCS bucket mismatch") } if expected.GCSConfig.StorageClass != actual.GCSConfig.StorageClass { return errors.New("GCS storage class mismatch") } if expected.GCSConfig.ACL != actual.GCSConfig.ACL { return errors.New("GCS ACL mismatch") } if expected.GCSConfig.KeyPrefix != actual.GCSConfig.KeyPrefix && expected.GCSConfig.KeyPrefix+"/" != actual.GCSConfig.KeyPrefix { return errors.New("GCS key prefix mismatch") } if expected.GCSConfig.AutomaticCredentials != actual.GCSConfig.AutomaticCredentials { return errors.New("GCS automatic credentials mismatch") } if expected.GCSConfig.UploadPartSize != actual.GCSConfig.UploadPartSize { return errors.New("GCS upload part size mismatch") } if expected.GCSConfig.UploadPartMaxTime != actual.GCSConfig.UploadPartMaxTime { return errors.New("GCS upload part max time mismatch") } return nil } func compareHTTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { if expected.HTTPConfig.Endpoint != actual.HTTPConfig.Endpoint { return errors.New("HTTPFs endpoint mismatch") } if expected.HTTPConfig.Username != actual.HTTPConfig.Username { return errors.New("HTTPFs username mismatch") } if expected.HTTPConfig.SkipTLSVerify != actual.HTTPConfig.SkipTLSVerify { return errors.New("HTTPFs skip_tls_verify mismatch") } if expected.SFTPConfig.EqualityCheckMode != actual.SFTPConfig.EqualityCheckMode { return errors.New("HTTPFs equality_check_mode mismatch") } if err := checkEncryptedSecret(expected.HTTPConfig.Password, actual.HTTPConfig.Password); err != nil { return fmt.Errorf("HTTPFs password mismatch: %v", err) } if err := checkEncryptedSecret(expected.HTTPConfig.APIKey, actual.HTTPConfig.APIKey); err != nil { return fmt.Errorf("HTTPFs API key mismatch: %v", err) } return nil } func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { if expected.SFTPConfig.Endpoint != actual.SFTPConfig.Endpoint { return errors.New("SFTPFs endpoint mismatch") } if expected.SFTPConfig.Username != 
actual.SFTPConfig.Username { return errors.New("SFTPFs username mismatch") } if expected.SFTPConfig.DisableCouncurrentReads != actual.SFTPConfig.DisableCouncurrentReads { return errors.New("SFTPFs disable_concurrent_reads mismatch") } if expected.SFTPConfig.BufferSize != actual.SFTPConfig.BufferSize { return errors.New("SFTPFs buffer_size mismatch") } if expected.SFTPConfig.EqualityCheckMode != actual.SFTPConfig.EqualityCheckMode { return errors.New("SFTPFs equality_check_mode mismatch") } if err := checkEncryptedSecret(expected.SFTPConfig.Password, actual.SFTPConfig.Password); err != nil { return fmt.Errorf("SFTPFs password mismatch: %v", err) } if err := checkEncryptedSecret(expected.SFTPConfig.PrivateKey, actual.SFTPConfig.PrivateKey); err != nil { return fmt.Errorf("SFTPFs private key mismatch: %v", err) } if err := checkEncryptedSecret(expected.SFTPConfig.KeyPassphrase, actual.SFTPConfig.KeyPassphrase); err != nil { return fmt.Errorf("SFTPFs private key passphrase mismatch: %v", err) } if expected.SFTPConfig.Prefix != actual.SFTPConfig.Prefix { if expected.SFTPConfig.Prefix != "" && actual.SFTPConfig.Prefix != "/" { return errors.New("SFTPFs prefix mismatch") } } if len(expected.SFTPConfig.Fingerprints) != len(actual.SFTPConfig.Fingerprints) { return errors.New("SFTPFs fingerprints mismatch") } for _, value := range actual.SFTPConfig.Fingerprints { if !slices.Contains(expected.SFTPConfig.Fingerprints, value) { return errors.New("SFTPFs fingerprints mismatch") } } return nil } func compareAzBlobConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { if expected.AzBlobConfig.Container != actual.AzBlobConfig.Container { return errors.New("azure Blob container mismatch") } if expected.AzBlobConfig.AccountName != actual.AzBlobConfig.AccountName { return errors.New("azure Blob account name mismatch") } if err := checkEncryptedSecret(expected.AzBlobConfig.AccountKey, actual.AzBlobConfig.AccountKey); err != nil { return fmt.Errorf("azure Blob account key 
mismatch: %v", err) } if expected.AzBlobConfig.Endpoint != actual.AzBlobConfig.Endpoint { return errors.New("azure Blob endpoint mismatch") } if err := checkEncryptedSecret(expected.AzBlobConfig.SASURL, actual.AzBlobConfig.SASURL); err != nil { return fmt.Errorf("azure Blob SAS URL mismatch: %v", err) } if expected.AzBlobConfig.UploadPartSize != actual.AzBlobConfig.UploadPartSize { return errors.New("azure Blob upload part size mismatch") } if expected.AzBlobConfig.UploadConcurrency != actual.AzBlobConfig.UploadConcurrency { return errors.New("azure Blob upload concurrency mismatch") } if expected.AzBlobConfig.DownloadPartSize != actual.AzBlobConfig.DownloadPartSize { return errors.New("azure Blob download part size mismatch") } if expected.AzBlobConfig.DownloadConcurrency != actual.AzBlobConfig.DownloadConcurrency { return errors.New("azure Blob download concurrency mismatch") } if expected.AzBlobConfig.KeyPrefix != actual.AzBlobConfig.KeyPrefix && expected.AzBlobConfig.KeyPrefix+"/" != actual.AzBlobConfig.KeyPrefix { return errors.New("azure Blob key prefix mismatch") } if expected.AzBlobConfig.UseEmulator != actual.AzBlobConfig.UseEmulator { return errors.New("azure Blob use emulator mismatch") } if expected.AzBlobConfig.AccessTier != actual.AzBlobConfig.AccessTier { return errors.New("azure Blob access tier mismatch") } return nil } func areSecretEquals(expected, actual *kms.Secret) bool { if expected == nil && actual == nil { return true } if expected != nil && expected.IsEmpty() && actual == nil { return true } if actual != nil && actual.IsEmpty() && expected == nil { return true } return false } func checkEncryptedSecret(expected, actual *kms.Secret) error { if areSecretEquals(expected, actual) { return nil } if expected == nil && actual != nil && !actual.IsEmpty() { return errors.New("secret mismatch") } if actual == nil && expected != nil && !expected.IsEmpty() { return errors.New("secret mismatch") } if expected.IsPlain() && actual.IsEncrypted() { if 
actual.GetPayload() == "" { return errors.New("invalid secret payload") } if actual.GetAdditionalData() != "" { return errors.New("invalid secret additional data") } if actual.GetKey() != "" { return errors.New("invalid secret key") } } else { if expected.GetStatus() != actual.GetStatus() || expected.GetPayload() != actual.GetPayload() { return errors.New("secret mismatch") } } return nil } func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { for _, IPMask := range expected.AllowedIP { if !slices.Contains(actual.AllowedIP, IPMask) { return errors.New("allowed IP contents mismatch") } } for _, IPMask := range expected.DeniedIP { if !slices.Contains(actual.DeniedIP, IPMask) { return errors.New("denied IP contents mismatch") } } for _, method := range expected.DeniedLoginMethods { if !slices.Contains(actual.DeniedLoginMethods, method) { return errors.New("denied login methods contents mismatch") } } for _, protocol := range expected.DeniedProtocols { if !slices.Contains(actual.DeniedProtocols, protocol) { return errors.New("denied protocols contents mismatch") } } for _, options := range expected.WebClient { if !slices.Contains(actual.WebClient, options) { return errors.New("web client options contents mismatch") } } if len(expected.TLSCerts) != len(actual.TLSCerts) { return errors.New("TLS certs mismatch") } for _, cert := range expected.TLSCerts { if !slices.Contains(actual.TLSCerts, cert) { return errors.New("TLS certs content mismatch") } } return compareUserFiltersEqualFields(expected, actual) } func compareUserFiltersEqualFields(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { if expected.Hooks.ExternalAuthDisabled != actual.Hooks.ExternalAuthDisabled { return errors.New("external_auth_disabled hook mismatch") } if expected.Hooks.PreLoginDisabled != actual.Hooks.PreLoginDisabled { return errors.New("pre_login_disabled hook mismatch") } if expected.Hooks.CheckPasswordDisabled != 
actual.Hooks.CheckPasswordDisabled { return errors.New("check_password_disabled hook mismatch") } if expected.DisableFsChecks != actual.DisableFsChecks { return errors.New("disable_fs_checks mismatch") } if expected.StartDirectory != actual.StartDirectory { return errors.New("start_directory mismatch") } return nil } func compareBaseUserFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { //nolint:gocyclo if len(expected.AllowedIP) != len(actual.AllowedIP) { return errors.New("allowed IP mismatch") } if len(expected.DeniedIP) != len(actual.DeniedIP) { return errors.New("denied IP mismatch") } if len(expected.DeniedLoginMethods) != len(actual.DeniedLoginMethods) { return errors.New("denied login methods mismatch") } if len(expected.DeniedProtocols) != len(actual.DeniedProtocols) { return errors.New("denied protocols mismatch") } if expected.MaxUploadFileSize != actual.MaxUploadFileSize { return errors.New("max upload file size mismatch") } if expected.TLSUsername != actual.TLSUsername { return errors.New("TLSUsername mismatch") } if len(expected.WebClient) != len(actual.WebClient) { return errors.New("WebClient filter mismatch") } if expected.AllowAPIKeyAuth != actual.AllowAPIKeyAuth { return errors.New("allow_api_key_auth mismatch") } if expected.ExternalAuthCacheTime != actual.ExternalAuthCacheTime { return errors.New("external_auth_cache_time mismatch") } if expected.FTPSecurity != actual.FTPSecurity { return errors.New("ftp_security mismatch") } if expected.IsAnonymous != actual.IsAnonymous { return errors.New("is_anonymous mismatch") } if expected.DefaultSharesExpiration != actual.DefaultSharesExpiration { return errors.New("default_shares_expiration mismatch") } if expected.MaxSharesExpiration != actual.MaxSharesExpiration { return errors.New("max_shares_expiration mismatch") } if expected.PasswordExpiration != actual.PasswordExpiration { return errors.New("password_expiration mismatch") } if expected.PasswordStrength != 
actual.PasswordStrength { return errors.New("password_strength mismatch") } return nil } func compareUserFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { if err := compareBaseUserFilters(expected, actual); err != nil { return err } if err := compareUserFilterSubStructs(expected, actual); err != nil { return err } if err := compareUserBandwidthLimitFilters(expected, actual); err != nil { return err } if err := compareAccessTimeFilters(expected, actual); err != nil { return err } return compareUserFilePatternsFilters(expected, actual) } func checkFilterMatch(expected []string, actual []string) bool { if len(expected) != len(actual) { return false } for _, e := range expected { if !slices.Contains(actual, strings.ToLower(e)) { return false } } return true } func compareAccessTimeFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { if len(expected.AccessTime) != len(actual.AccessTime) { return errors.New("access time filters mismatch") } for idx, p := range expected.AccessTime { if actual.AccessTime[idx].DayOfWeek != p.DayOfWeek { return errors.New("access time day of week mismatch") } if actual.AccessTime[idx].From != p.From { return errors.New("access time from mismatch") } if actual.AccessTime[idx].To != p.To { return errors.New("access time to mismatch") } } return nil } func compareUserBandwidthLimitFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { if len(expected.BandwidthLimits) != len(actual.BandwidthLimits) { return errors.New("bandwidth limits filters mismatch") } for idx, l := range expected.BandwidthLimits { if actual.BandwidthLimits[idx].UploadBandwidth != l.UploadBandwidth { return errors.New("bandwidth filters upload_bandwidth mismatch") } if actual.BandwidthLimits[idx].DownloadBandwidth != l.DownloadBandwidth { return errors.New("bandwidth filters download_bandwidth mismatch") } if len(actual.BandwidthLimits[idx].Sources) != len(l.Sources) { return errors.New("bandwidth filters 
sources mismatch") } for _, source := range actual.BandwidthLimits[idx].Sources { if !slices.Contains(l.Sources, source) { return errors.New("bandwidth filters source mismatch") } } } return nil } func compareUserFilePatternsFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { if len(expected.FilePatterns) != len(actual.FilePatterns) { return errors.New("file patterns mismatch") } for _, f := range expected.FilePatterns { found := false for _, f1 := range actual.FilePatterns { if path.Clean(f.Path) == path.Clean(f1.Path) && f.DenyPolicy == f1.DenyPolicy { if !checkFilterMatch(f.AllowedPatterns, f1.AllowedPatterns) || !checkFilterMatch(f.DeniedPatterns, f1.DeniedPatterns) { return errors.New("file patterns contents mismatch") } found = true } } if !found { return errors.New("file patterns contents mismatch") } } return nil } func compareRenameConfigs(expected, actual []dataprovider.RenameConfig) error { if len(expected) != len(actual) { return errors.New("rename configs mismatch") } for _, ex := range expected { found := false for _, ac := range actual { if ac.Key == ex.Key && ac.Value == ex.Value && ac.UpdateModTime == ex.UpdateModTime { found = true break } } if !found { return errors.New("rename configs mismatch") } } return nil } func compareKeyValues(expected, actual []dataprovider.KeyValue) error { if len(expected) != len(actual) { return errors.New("key values mismatch") } for _, ex := range expected { found := false for _, ac := range actual { if ac.Key == ex.Key && ac.Value == ex.Value { found = true break } } if !found { return errors.New("key values mismatch") } } return nil } func compareHTTPparts(expected, actual []dataprovider.HTTPPart) error { for _, p1 := range expected { found := false for _, p2 := range actual { if p1.Name == p2.Name { found = true if err := compareKeyValues(p1.Headers, p2.Headers); err != nil { return fmt.Errorf("http headers mismatch for part %q", p1.Name) } if p1.Body != p2.Body || p1.Filepath != p2.Filepath 
{ return fmt.Errorf("http part %q mismatch", p1.Name) } } } if !found { return fmt.Errorf("expected http part %q not found", p1.Name) } } return nil } func compareEventActionHTTPConfigFields(expected, actual dataprovider.EventActionHTTPConfig) error { if expected.Endpoint != actual.Endpoint { return errors.New("http endpoint mismatch") } if expected.Username != actual.Username { return errors.New("http username mismatch") } if err := checkEncryptedSecret(expected.Password, actual.Password); err != nil { return err } if err := compareKeyValues(expected.Headers, actual.Headers); err != nil { return errors.New("http headers mismatch") } if expected.Timeout != actual.Timeout { return errors.New("http timeout mismatch") } if expected.SkipTLSVerify != actual.SkipTLSVerify { return errors.New("http skip TLS verify mismatch") } if expected.Method != actual.Method { return errors.New("http method mismatch") } if err := compareKeyValues(expected.QueryParameters, actual.QueryParameters); err != nil { return errors.New("http query parameters mismatch") } if expected.Body != actual.Body { return errors.New("http body mismatch") } if len(expected.Parts) != len(actual.Parts) { return errors.New("http parts mismatch") } return compareHTTPparts(expected.Parts, actual.Parts) } func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActionEmailConfig) error { if len(expected.Recipients) != len(actual.Recipients) { return errors.New("email recipients mismatch") } for _, v := range expected.Recipients { if !slices.Contains(actual.Recipients, v) { return errors.New("email recipients content mismatch") } } if len(expected.Bcc) != len(actual.Bcc) { return errors.New("email bcc mismatch") } for _, v := range expected.Bcc { if !slices.Contains(actual.Bcc, v) { return errors.New("email bcc content mismatch") } } if expected.Subject != actual.Subject { return errors.New("email subject mismatch") } if expected.ContentType != actual.ContentType { return errors.New("email 
content type mismatch") } if expected.Body != actual.Body { return errors.New("email body mismatch") } if len(expected.Attachments) != len(actual.Attachments) { return errors.New("email attachments mismatch") } for _, v := range expected.Attachments { if !slices.Contains(actual.Attachments, v) { return errors.New("email attachments content mismatch") } } return nil } func compareEventActionFsCompressFields(expected, actual dataprovider.EventActionFsCompress) error { if expected.Name != actual.Name { return errors.New("fs compress name mismatch") } if len(expected.Paths) != len(actual.Paths) { return errors.New("fs compress paths mismatch") } for _, v := range expected.Paths { if !slices.Contains(actual.Paths, v) { return errors.New("fs compress paths content mismatch") } } return nil } func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionFilesystemConfig) error { if expected.Type != actual.Type { return errors.New("fs type mismatch") } if err := compareRenameConfigs(expected.Renames, actual.Renames); err != nil { return errors.New("fs renames mismatch") } if err := compareKeyValues(expected.Copy, actual.Copy); err != nil { return errors.New("fs copy mismatch") } if len(expected.Deletes) != len(actual.Deletes) { return errors.New("fs deletes mismatch") } for _, v := range expected.Deletes { if !slices.Contains(actual.Deletes, v) { return errors.New("fs deletes content mismatch") } } if len(expected.MkDirs) != len(actual.MkDirs) { return errors.New("fs mkdirs mismatch") } for _, v := range expected.MkDirs { if !slices.Contains(actual.MkDirs, v) { return errors.New("fs mkdir content mismatch") } } if len(expected.Exist) != len(actual.Exist) { return errors.New("fs exist mismatch") } for _, v := range expected.Exist { if !slices.Contains(actual.Exist, v) { return errors.New("fs exist content mismatch") } } return compareEventActionFsCompressFields(expected.Compress, actual.Compress) } func compareEventActionIDPConfigFields(expected, actual 
dataprovider.EventActionIDPAccountCheck) error { if expected.Mode != actual.Mode { return errors.New("mode mismatch") } if expected.TemplateAdmin != actual.TemplateAdmin { return errors.New("admin template mismatch") } if expected.TemplateUser != actual.TemplateUser { return errors.New("user template mismatch") } return nil } func compareEventActionCmdConfigFields(expected, actual dataprovider.EventActionCommandConfig) error { if expected.Cmd != actual.Cmd { return errors.New("command mismatch") } if expected.Timeout != actual.Timeout { return errors.New("cmd timeout mismatch") } if len(expected.Args) != len(actual.Args) { return errors.New("cmd args mismatch") } for _, v := range expected.Args { if !slices.Contains(actual.Args, v) { return errors.New("cmd args content mismatch") } } if err := compareKeyValues(expected.EnvVars, actual.EnvVars); err != nil { return errors.New("cmd env vars mismatch") } return nil } func compareEventActionDataRetentionFields(expected, actual dataprovider.EventActionDataRetentionConfig) error { if len(expected.Folders) != len(actual.Folders) { return errors.New("retention folders mismatch") } for _, f1 := range expected.Folders { found := false for _, f2 := range actual.Folders { if f1.Path == f2.Path { found = true if f1.Retention != f2.Retention { return fmt.Errorf("retention mismatch for folder %s", f1.Path) } if f1.DeleteEmptyDirs != f2.DeleteEmptyDirs { return fmt.Errorf("delete_empty_dirs mismatch for folder %s", f1.Path) } break } } if !found { return errors.New("retention folders mismatch") } } return nil } func compareEqualGroupSettingsFields(expected sdk.BaseGroupUserSettings, actual sdk.BaseGroupUserSettings) error { if expected.HomeDir != actual.HomeDir { return errors.New("home dir mismatch") } if expected.MaxSessions != actual.MaxSessions { return errors.New("MaxSessions mismatch") } if expected.QuotaSize != actual.QuotaSize { return errors.New("QuotaSize mismatch") } if expected.QuotaFiles != actual.QuotaFiles { return 
errors.New("QuotaFiles mismatch") } if expected.UploadBandwidth != actual.UploadBandwidth { return errors.New("UploadBandwidth mismatch") } if expected.DownloadBandwidth != actual.DownloadBandwidth { return errors.New("DownloadBandwidth mismatch") } if expected.UploadDataTransfer != actual.UploadDataTransfer { return errors.New("upload_data_transfer mismatch") } if expected.DownloadDataTransfer != actual.DownloadDataTransfer { return errors.New("download_data_transfer mismatch") } if expected.TotalDataTransfer != actual.TotalDataTransfer { return errors.New("total_data_transfer mismatch") } if expected.ExpiresIn != actual.ExpiresIn { return errors.New("expires_in mismatch") } return compareUserPermissions(expected.Permissions, actual.Permissions) } func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error { if dataprovider.ConvertName(expected.Username) != actual.Username { return errors.New("username mismatch") } if expected.HomeDir != actual.HomeDir { return errors.New("home dir mismatch") } if expected.UID != actual.UID { return errors.New("UID mismatch") } if expected.GID != actual.GID { return errors.New("GID mismatch") } if expected.MaxSessions != actual.MaxSessions { return errors.New("MaxSessions mismatch") } if len(expected.Permissions) != len(actual.Permissions) { return errors.New("permissions mismatch") } if expected.UploadBandwidth != actual.UploadBandwidth { return errors.New("UploadBandwidth mismatch") } if expected.DownloadBandwidth != actual.DownloadBandwidth { return errors.New("DownloadBandwidth mismatch") } if expected.Status != actual.Status { return errors.New("status mismatch") } if expected.ExpirationDate != actual.ExpirationDate { return errors.New("ExpirationDate mismatch") } if expected.AdditionalInfo != actual.AdditionalInfo { return errors.New("AdditionalInfo mismatch") } if expected.Description != actual.Description { return errors.New("description mismatch") } if expected.Role != actual.Role { return 
errors.New("role mismatch") } return compareQuotaUserFields(expected, actual) } func compareQuotaUserFields(expected *dataprovider.User, actual *dataprovider.User) error { if expected.QuotaSize != actual.QuotaSize { return errors.New("QuotaSize mismatch") } if expected.QuotaFiles != actual.QuotaFiles { return errors.New("QuotaFiles mismatch") } if expected.UploadDataTransfer != actual.UploadDataTransfer { return errors.New("upload_data_transfer mismatch") } if expected.DownloadDataTransfer != actual.DownloadDataTransfer { return errors.New("download_data_transfer mismatch") } if expected.TotalDataTransfer != actual.TotalDataTransfer { return errors.New("total_data_transfer mismatch") } return nil } func addLimitAndOffsetQueryParams(rawurl string, limit, offset int64) (*url.URL, error) { url, err := url.Parse(rawurl) if err != nil { return nil, err } q := url.Query() if limit > 0 { q.Add("limit", strconv.FormatInt(limit, 10)) } if offset > 0 { q.Add("offset", strconv.FormatInt(offset, 10)) } url.RawQuery = q.Encode() return url, err } func addModeQueryParam(rawurl, mode string) (*url.URL, error) { url, err := url.Parse(rawurl) if err != nil { return nil, err } q := url.Query() if len(mode) > 0 { q.Add("mode", mode) } url.RawQuery = q.Encode() return url, err } func addUpdateUserQueryParams(rawurl, disconnect string) (*url.URL, error) { url, err := url.Parse(rawurl) if err != nil { return nil, err } q := url.Query() if disconnect != "" { q.Add("disconnect", disconnect) } url.RawQuery = q.Encode() return url, err } ================================================ FILE: internal/httpdtest/httpfsimpl.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package httpdtest import ( "context" "errors" "fmt" "io" "mime" "net" "net/http" "net/url" "os" "path/filepath" "strconv" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/shirou/gopsutil/v3/disk" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( statPath = "/api/v1/stat" openPath = "/api/v1/open" createPath = "/api/v1/create" renamePath = "/api/v1/rename" removePath = "/api/v1/remove" mkdirPath = "/api/v1/mkdir" chmodPath = "/api/v1/chmod" chtimesPath = "/api/v1/chtimes" truncatePath = "/api/v1/truncate" readdirPath = "/api/v1/readdir" dirsizePath = "/api/v1/dirsize" mimetypePath = "/api/v1/mimetype" statvfsPath = "/api/v1/statvfs" ) // HTTPFsCallbacks defines additional callbacks to customize the HTTPfs responses type HTTPFsCallbacks struct { Readdir func(string) []os.FileInfo } // StartTestHTTPFs starts a test HTTP service that implements httpfs // and listens on the specified port func StartTestHTTPFs(port int, callbacks *HTTPFsCallbacks) error { fs := httpFsImpl{ port: port, callbacks: callbacks, } return fs.Run() } // StartTestHTTPFsOverUnixSocket starts a test HTTP service that implements httpfs // and listens on the specified UNIX domain socket path func StartTestHTTPFsOverUnixSocket(socketPath string) error { fs := httpFsImpl{ unixSocketPath: socketPath, } return fs.Run() } type httpFsImpl struct { router *chi.Mux basePath string port int unixSocketPath string callbacks *HTTPFsCallbacks } type apiResponse struct { Error string `json:"error,omitempty"` Message string `json:"message,omitempty"` } func (fs *httpFsImpl) 
sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { var errorString string if err != nil { errorString = err.Error() } resp := apiResponse{ Error: errorString, Message: message, } ctx := context.WithValue(r.Context(), render.StatusCtxKey, code) render.JSON(w, r.WithContext(ctx), resp) } func (fs *httpFsImpl) getUsername(r *http.Request) (string, error) { username, _, ok := r.BasicAuth() if !ok || username == "" { return "", os.ErrPermission } rootPath := filepath.Join(fs.basePath, username) _, err := os.Stat(rootPath) if errors.Is(err, os.ErrNotExist) { err = os.MkdirAll(rootPath, os.ModePerm) if err != nil { return username, err } } return username, nil } func (fs *httpFsImpl) getRespStatus(err error) int { if errors.Is(err, os.ErrPermission) { return http.StatusForbidden } if errors.Is(err, os.ErrNotExist) { return http.StatusNotFound } return http.StatusInternalServerError } func (fs *httpFsImpl) stat(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) info, err := os.Stat(fsPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } render.JSON(w, r, getStatFromInfo(info)) } func (fs *httpFsImpl) open(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } var offset int64 if r.URL.Query().Has("offset") { offset, err = strconv.ParseInt(r.URL.Query().Get("offset"), 10, 64) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) f, err := os.Open(fsPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } defer f.Close() if offset > 0 { _, err = f.Seek(offset, io.SeekStart) 
if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } } ctype := mime.TypeByExtension(filepath.Ext(name)) if ctype != "" { ctype = "application/octet-stream" } w.Header().Set("Content-Type", ctype) _, err = io.Copy(w, f) if err != nil { panic(http.ErrAbortHandler) } } func (fs *httpFsImpl) create(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } flags := os.O_RDWR | os.O_CREATE | os.O_TRUNC if r.URL.Query().Has("flags") { openFlags, err := strconv.ParseInt(r.URL.Query().Get("flags"), 10, 32) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } if openFlags > 0 { flags = int(openFlags) } } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) f, err := os.OpenFile(fsPath, flags, 0666) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } defer f.Close() _, err = io.Copy(f, r.Body) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } fs.sendAPIResponse(w, r, nil, "upload OK", http.StatusOK) } func (fs *httpFsImpl) rename(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } target := r.URL.Query().Get("target") if target == "" { fs.sendAPIResponse(w, r, nil, "target path cannot be empty", http.StatusBadRequest) return } name := getNameURLParam(r) sourcePath := filepath.Join(fs.basePath, username, name) targetPath := filepath.Join(fs.basePath, username, target) err = os.Rename(sourcePath, targetPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } fs.sendAPIResponse(w, r, nil, "rename OK", http.StatusOK) } func (fs *httpFsImpl) remove(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) 
return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) err = os.Remove(fsPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } fs.sendAPIResponse(w, r, nil, "remove OK", http.StatusOK) } func (fs *httpFsImpl) mkdir(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) err = os.Mkdir(fsPath, os.ModePerm) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } fs.sendAPIResponse(w, r, nil, "mkdir OK", http.StatusOK) } func (fs *httpFsImpl) chmod(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } mode, err := strconv.ParseUint(r.URL.Query().Get("mode"), 10, 32) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) err = os.Chmod(fsPath, os.FileMode(mode)) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } fs.sendAPIResponse(w, r, nil, "chmod OK", http.StatusOK) } func (fs *httpFsImpl) chtimes(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } atime, err := time.Parse(time.RFC3339, r.URL.Query().Get("access_time")) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } mtime, err := time.Parse(time.RFC3339, r.URL.Query().Get("modification_time")) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) err = os.Chtimes(fsPath, atime, mtime) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) 
return } fs.sendAPIResponse(w, r, nil, "chtimes OK", http.StatusOK) } func (fs *httpFsImpl) truncate(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } size, err := strconv.ParseInt(r.URL.Query().Get("size"), 10, 64) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) err = os.Truncate(fsPath, size) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } fs.sendAPIResponse(w, r, nil, "chmod OK", http.StatusOK) } func (fs *httpFsImpl) readdir(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) f, err := os.Open(fsPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } list, err := f.Readdir(-1) f.Close() if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } result := make([]map[string]any, 0, len(list)) for _, fi := range list { result = append(result, getStatFromInfo(fi)) } if fs.callbacks != nil && fs.callbacks.Readdir != nil { for _, fi := range fs.callbacks.Readdir(name) { result = append(result, getStatFromInfo(fi)) } } render.JSON(w, r, result) } func (fs *httpFsImpl) dirsize(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) info, err := os.Stat(fsPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } numFiles := 0 size := int64(0) if info.IsDir() { err = filepath.Walk(fsPath, func(_ string, info os.FileInfo, err error) error { if err != nil { return err } if 
info != nil && info.Mode().IsRegular() { size += info.Size() numFiles++ } return err }) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } } render.JSON(w, r, map[string]any{ "files": numFiles, "size": size, }) } func (fs *httpFsImpl) mimetype(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) f, err := os.OpenFile(fsPath, os.O_RDONLY, 0) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } defer f.Close() var buf [512]byte n, err := io.ReadFull(f, buf[:]) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } ctype := http.DetectContentType(buf[:n]) render.JSON(w, r, map[string]any{ "mime": ctype, }) } func (fs *httpFsImpl) statvfs(w http.ResponseWriter, r *http.Request) { username, err := fs.getUsername(r) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } name := getNameURLParam(r) fsPath := filepath.Join(fs.basePath, username, name) usage, err := disk.Usage(fsPath) if err != nil { fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) return } // we assume block size = 4096 bsize := uint64(4096) blocks := usage.Total / bsize bfree := usage.Free / bsize files := usage.InodesTotal ffree := usage.InodesFree if files == 0 { // these assumptions are wrong but still better than returning 0 files = blocks / 4 ffree = bfree / 4 } render.JSON(w, r, map[string]any{ "bsize": bsize, "frsize": bsize, "blocks": blocks, "bfree": bfree, "bavail": bfree, "files": files, "ffree": ffree, "favail": ffree, "namemax": 255, }) } func (fs *httpFsImpl) configureRouter() { fs.router = chi.NewRouter() fs.router.Use(middleware.Recoverer) fs.router.Get(statPath+"/{name}", fs.stat) //nolint:goconst fs.router.Get(openPath+"/{name}", fs.open) 
fs.router.Post(createPath+"/{name}", fs.create) fs.router.Patch(renamePath+"/{name}", fs.rename) fs.router.Delete(removePath+"/{name}", fs.remove) fs.router.Post(mkdirPath+"/{name}", fs.mkdir) fs.router.Patch(chmodPath+"/{name}", fs.chmod) fs.router.Patch(chtimesPath+"/{name}", fs.chtimes) fs.router.Patch(truncatePath+"/{name}", fs.truncate) fs.router.Get(readdirPath+"/{name}", fs.readdir) fs.router.Get(dirsizePath+"/{name}", fs.dirsize) fs.router.Get(mimetypePath+"/{name}", fs.mimetype) fs.router.Get(statvfsPath+"/{name}", fs.statvfs) } func (fs *httpFsImpl) Run() error { fs.basePath = filepath.Join(os.TempDir(), "httpfs") if err := os.RemoveAll(fs.basePath); err != nil { return err } if err := os.MkdirAll(fs.basePath, os.ModePerm); err != nil { return err } fs.configureRouter() httpServer := http.Server{ Addr: fmt.Sprintf(":%d", fs.port), Handler: fs.router, ReadTimeout: 60 * time.Second, WriteTimeout: 60 * time.Second, IdleTimeout: 120 * time.Second, MaxHeaderBytes: 1 << 16, // 64KB } if fs.unixSocketPath == "" { return httpServer.ListenAndServe() } err := os.Remove(fs.unixSocketPath) if err != nil && !os.IsNotExist(err) { return err } listener, err := net.Listen("unix", fs.unixSocketPath) if err != nil { return err } return httpServer.Serve(listener) } func getStatFromInfo(info os.FileInfo) map[string]any { return map[string]any{ "name": info.Name(), "size": info.Size(), "mode": info.Mode(), "last_modified": info.ModTime(), } } func getNameURLParam(r *http.Request) string { v := chi.URLParam(r, "name") unescaped, err := url.PathUnescape(v) if err != nil { return util.CleanPath(v) } return util.CleanPath(unescaped) } ================================================ FILE: internal/jwt/jwt.go ================================================ // Copyright (C) 2025 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, 
version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package jwt provides functionality for creating, parsing, and validating // JSON Web Tokens (JWT) used in authentication and authorization workflows. package jwt import ( "context" "errors" "fmt" "net/http" "slices" "strings" "time" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "github.com/rs/xid" ) const ( CookieKey = "jwt" ) var ( TokenCtxKey = &contextKey{"Token"} ErrorCtxKey = &contextKey{"Error"} ) // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. This technique // for defining context keys was copied from Go 1.7's new use of context in net/http. 
type contextKey struct { name string } func (k *contextKey) String() string { return "jwt context value " + k.name } func NewClaims(audience, ip string, duration time.Duration) *Claims { now := time.Now() claims := &Claims{} claims.IssuedAt = jwt.NewNumericDate(now) claims.NotBefore = jwt.NewNumericDate(now.Add(-10 * time.Second)) claims.Expiry = jwt.NewNumericDate(now.Add(duration)) claims.Audience = []string{audience, ip} return claims } type Claims struct { jwt.Claims Username string `json:"username,omitempty"` Permissions []string `json:"permissions,omitempty"` Role string `json:"role,omitempty"` APIKeyID string `json:"api_key,omitempty"` NodeID string `json:"node_id,omitempty"` MustSetTwoFactorAuth bool `json:"2fa_required,omitempty"` MustChangePassword bool `json:"chpwd,omitempty"` RequiredTwoFactorProtocols []string `json:"2fa_protos,omitempty"` HideUserPageSections int `json:"hus,omitempty"` Ref string `json:"ref,omitempty"` } func (c *Claims) SetIssuedAt(t time.Time) { c.IssuedAt = jwt.NewNumericDate(t) } func (c *Claims) SetNotBefore(t time.Time) { c.NotBefore = jwt.NewNumericDate(t) } func (c *Claims) SetExpiry(t time.Time) { c.Expiry = jwt.NewNumericDate(t) } func (c *Claims) HasPerm(perm string) bool { for _, p := range c.Permissions { if p == "*" || p == perm { return true } } return false } func (c *Claims) HasAnyAudience(audiences []string) bool { for _, a := range c.Audience { if slices.Contains(audiences, a) { return true } } return false } func (c *Claims) GenerateTokenResponse(signer *Signer) (TokenResponse, error) { token, err := signer.Sign(c) if err != nil { return TokenResponse{}, err } return c.BuildTokenResponse(token), nil } func (c *Claims) BuildTokenResponse(token string) TokenResponse { return TokenResponse{Token: token, Expiry: c.Expiry.Time().UTC().Format(time.RFC3339)} } type TokenResponse struct { Token string `json:"access_token"` Expiry string `json:"expires_at"` } func NewSigner(algo jose.SignatureAlgorithm, key any) (*Signer, 
error) { opts := (&jose.SignerOptions{}).WithType("JWT") signer, err := jose.NewSigner(jose.SigningKey{Algorithm: algo, Key: key}, opts) if err != nil { return nil, err } return &Signer{ signer: signer, algo: []jose.SignatureAlgorithm{algo}, key: key, }, nil } type Signer struct { algo []jose.SignatureAlgorithm signer jose.Signer key any } func (s *Signer) Sign(claims *Claims) (string, error) { if claims.ID == "" { claims.ID = xid.New().String() } if claims.IssuedAt == nil { claims.IssuedAt = jwt.NewNumericDate(time.Now()) } if claims.NotBefore == nil { claims.NotBefore = jwt.NewNumericDate(time.Now().Add(-10 * time.Second)) } if claims.Expiry == nil { return "", errors.New("expiration must be set") } if len(claims.Audience) == 0 { return "", errors.New("audience must be set") } return jwt.Signed(s.signer).Claims(claims).Serialize() } func (s *Signer) Signer() jose.Signer { return s.signer } func (s *Signer) SetSigner(signer jose.Signer) { s.signer = signer } func (s *Signer) SignWithParams(claims *Claims, audience, ip string, duration time.Duration) (string, error) { claims.Expiry = jwt.NewNumericDate(time.Now().Add(duration)) claims.Audience = []string{audience, ip} return s.Sign(claims) } func NewContext(ctx context.Context, claims *Claims, err error) context.Context { ctx = context.WithValue(ctx, TokenCtxKey, claims) ctx = context.WithValue(ctx, ErrorCtxKey, err) return ctx } func FromContext(ctx context.Context) (*Claims, error) { val := ctx.Value(TokenCtxKey) token, ok := val.(*Claims) if !ok && val != nil { return nil, fmt.Errorf("invalid type for TokenCtxKey: %T", val) } valErr := ctx.Value(ErrorCtxKey) err, ok := valErr.(error) if !ok && valErr != nil { return nil, fmt.Errorf("invalid type for ErrorCtxKey: %T", valErr) } if token == nil { return nil, errors.New("no token found") } return token, err } func Verify(s *Signer, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { hfn := 
func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() token, err := VerifyRequest(s, r, findTokenFns...) ctx = NewContext(ctx, token, err) next.ServeHTTP(w, r.WithContext(ctx)) } return http.HandlerFunc(hfn) } } func VerifyRequest(s *Signer, r *http.Request, findTokenFns ...func(r *http.Request) string) (*Claims, error) { var tokenString string for _, fn := range findTokenFns { tokenString = fn(r) if tokenString != "" { break } } if tokenString == "" { return nil, errors.New("no token found") } return VerifyToken(s, tokenString) } func VerifyToken(s *Signer, payload string) (*Claims, error) { return VerifyTokenWithKey(payload, s.algo, s.key) } func VerifyTokenWithKey(payload string, algo []jose.SignatureAlgorithm, key any) (*Claims, error) { token, err := jwt.ParseSigned(payload, algo) if err != nil { return nil, err } var claims Claims err = token.Claims(key, &claims) if err != nil { return nil, err } if err := claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 30*time.Second); err != nil { return nil, err } return &claims, nil } // TokenFromCookie tries to retrieve the token string from a cookie named // "jwt". func TokenFromCookie(r *http.Request) string { cookie, err := r.Cookie(CookieKey) if err != nil { return "" } return cookie.Value } // TokenFromHeader tries to retrieve the token string from the // "Authorization" request header: "Authorization: BEARER T". func TokenFromHeader(r *http.Request) string { // Get token from authorization header. 
bearer := r.Header.Get("Authorization") const prefix = "Bearer " if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) { return bearer[len(prefix):] } return "" } ================================================ FILE: internal/jwt/jwt_test.go ================================================ package jwt import ( "context" "errors" "fmt" "io/fs" "net/http" "net/http/httptest" "testing" "time" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/util" ) type failingJoseSigner struct{} func (s *failingJoseSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { return nil, errors.New("sign test error") } func (s *failingJoseSigner) Options() jose.SignerOptions { return jose.SignerOptions{} } func TestJWTToken(t *testing.T) { s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) username := util.GenerateUniqueID() claims := Claims{ Username: username, Claims: jwt.Claims{ Audience: jwt.Audience{"test"}, Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), NotBefore: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()), }, } token, err := s.Sign(&claims) require.NoError(t, err) require.NotEmpty(t, token) parsed, err := VerifyToken(s, token) require.NoError(t, err) require.Equal(t, username, parsed.Username) ja1, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) token, err = ja1.Sign(&claims) require.NoError(t, err) require.NotEmpty(t, token) _, err = VerifyToken(s, token) require.Error(t, err) _, err = VerifyToken(ja1, token) require.NoError(t, err) } func TestClaims(t *testing.T) { claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute) s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) token, err := s.Sign(claims) require.NoError(t, err) assert.NotEmpty(t, token) 
assert.NotNil(t, claims.Expiry) assert.NotNil(t, claims.IssuedAt) assert.NotNil(t, claims.NotBefore) claims = &Claims{ Permissions: []string{"myperm"}, } claims.SetExpiry(time.Now().Add(1 * time.Minute)) claims.Audience = []string{"testaudience"} _, err = s.Sign(claims) assert.NoError(t, err) assert.NotNil(t, claims.IssuedAt) assert.NotNil(t, claims.NotBefore) assert.True(t, claims.HasAnyAudience([]string{util.GenerateUniqueID(), util.GenerateUniqueID(), "testaudience"})) assert.False(t, claims.HasAnyAudience([]string{util.GenerateUniqueID()})) assert.True(t, claims.HasPerm("myperm")) assert.False(t, claims.HasPerm(util.GenerateUniqueID())) resp, err := claims.GenerateTokenResponse(s) require.NoError(t, err) assert.NotEmpty(t, resp.Token) assert.Equal(t, claims.Expiry.Time().UTC().Format(time.RFC3339), resp.Expiry) claims.SetIssuedAt(time.Now()) claims.SetNotBefore(time.Now().Add(10 * time.Minute)) token, err = s.SignWithParams(claims, util.GenerateUniqueID(), "127.0.0.1", time.Minute) assert.NoError(t, err) _, err = VerifyToken(s, token) assert.ErrorContains(t, err, "nbf") claims = &Claims{} _, err = s.Sign(claims) assert.ErrorContains(t, err, "expiration must be set") claims.SetExpiry(time.Now()) _, err = s.Sign(claims) assert.ErrorContains(t, err, "audience must be set") claims = &Claims{} _, err = s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute) assert.NoError(t, err) } func TestClaimsPermissions(t *testing.T) { c := Claims{ Permissions: []string{"*"}, } assert.True(t, c.HasPerm(util.GenerateUniqueID())) c.Permissions = []string{"list"} assert.False(t, c.HasPerm(util.GenerateUniqueID())) assert.True(t, c.HasPerm("list")) } func TestErrors(t *testing.T) { s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) _, err = VerifyToken(s, util.GenerateUniqueID()) assert.Error(t, err) claims := &Claims{} claims.SetExpiry(time.Now().Add(-1 * time.Minute)) token, err := jwt.Signed(s.Signer()).Claims(claims).Serialize() 
assert.NoError(t, err) _, err = VerifyToken(s, token) assert.ErrorContains(t, err, "exp") claims.SetExpiry(time.Now().Add(2 * time.Minute)) claims.SetIssuedAt(time.Now().Add(1 * time.Minute)) token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize() assert.NoError(t, err) _, err = VerifyToken(s, token) assert.ErrorContains(t, err, "iat") claims.SetIssuedAt(time.Now()) claims.SetNotBefore(time.Now().Add(1 * time.Minute)) token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize() assert.NoError(t, err) _, err = VerifyToken(s, token) assert.ErrorContains(t, err, "nbf") s.SetSigner(&failingJoseSigner{}) claims = NewClaims(util.GenerateUniqueID(), "", time.Minute) _, err = s.Sign(claims) assert.Error(t, err) _, err = claims.GenerateTokenResponse(s) assert.Error(t, err) // Wrong algorithm _, err = NewSigner("PS256", util.GenerateRandomBytes(32)) assert.Error(t, err) } func TestTokenFromRequest(t *testing.T) { claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute) s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) token, err := s.Sign(claims) require.NoError(t, err) assert.NotEmpty(t, token) req, err := http.NewRequest(http.MethodGet, "/", nil) require.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) cookie := TokenFromCookie(req) assert.Equal(t, token, cookie) req, err = http.NewRequest(http.MethodGet, "/", nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) _, err = VerifyRequest(s, req, TokenFromHeader) assert.NoError(t, err) req.Header.Set("Authorization", token) assert.Empty(t, TokenFromHeader(req)) assert.Empty(t, TokenFromCookie(req)) _, err = VerifyRequest(s, req, TokenFromCookie) assert.ErrorContains(t, err, "no token found") } func TestContext(t *testing.T) { claims := &Claims{ Username: util.GenerateUniqueID(), } s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) token, err := s.SignWithParams(claims, 
util.GenerateUniqueID(), "", time.Minute) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "/", nil) require.NoError(t, err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) h := Verify(s, TokenFromHeader) wrapped := h(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token, err := FromContext(r.Context()) assert.Nil(t, err) assert.Equal(t, claims.Username, token.Username) w.WriteHeader(http.StatusOK) })) rr := httptest.NewRecorder() wrapped.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) _, err = FromContext(context.Background()) assert.ErrorContains(t, err, "no token found") ctx := NewContext(context.Background(), &Claims{}, fs.ErrClosed) _, err = FromContext(ctx) assert.Equal(t, fs.ErrClosed, err) ctx = context.WithValue(context.Background(), TokenCtxKey, "1") _, err = FromContext(ctx) assert.ErrorContains(t, err, "invalid type for TokenCtxKey") ctx = context.WithValue(context.Background(), ErrorCtxKey, 2) _, err = FromContext(ctx) assert.ErrorContains(t, err, "invalid type for ErrorCtxKey") claims = NewClaims(util.GenerateUniqueID(), "127.1.1.1", time.Minute) _, err = s.Sign(claims) require.NoError(t, err) ctx = context.WithValue(context.Background(), TokenCtxKey, claims) claimsFromContext, err := FromContext(ctx) assert.NoError(t, err) assert.Equal(t, claims, claimsFromContext) assert.Equal(t, "jwt context value Token", TokenCtxKey.String()) } func TestValidationLeeway(t *testing.T) { s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) require.NoError(t, err) claims := &Claims{} claims.Audience = []string{util.GenerateUniqueID()} claims.SetIssuedAt(time.Now().Add(10 * time.Second)) // issued at in the future claims.SetExpiry(time.Now().Add(10 * time.Second)) token, err := s.Sign(claims) require.NoError(t, err) _, err = VerifyToken(s, token) assert.NoError(t, err) claims = &Claims{} claims.Audience = []string{util.GenerateUniqueID()} claims.SetExpiry(time.Now().Add(-10 * time.Second)) // 
expired token, err = s.Sign(claims) require.NoError(t, err) _, err = VerifyToken(s, token) assert.NoError(t, err) claims = &Claims{} claims.Audience = []string{util.GenerateUniqueID()} claims.SetExpiry(time.Now().Add(30 * time.Second)) claims.SetNotBefore(time.Now().Add(10 * time.Second)) // not before in the future token, err = s.Sign(claims) require.NoError(t, err) _, err = VerifyToken(s, token) assert.NoError(t, err) } ================================================ FILE: internal/kms/basesecret.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package kms

import (
	sdkkms "github.com/sftpgo/sdk/kms"
)

// BaseSecret defines the base struct shared among all the secret providers
type BaseSecret struct {
	Status         sdkkms.SecretStatus `json:"status,omitempty"`
	Payload        string              `json:"payload,omitempty"`
	Key            string              `json:"key,omitempty"`
	AdditionalData string              `json:"additional_data,omitempty"`
	// 1 means encrypted using a master key
	Mode int `json:"mode,omitempty"`
}

// GetStatus returns the secret's status
func (s *BaseSecret) GetStatus() sdkkms.SecretStatus {
	return s.Status
}

// GetPayload returns the secret's payload
func (s *BaseSecret) GetPayload() string {
	return s.Payload
}

// GetKey returns the secret's key
func (s *BaseSecret) GetKey() string {
	return s.Key
}

// GetMode returns the encryption mode
func (s *BaseSecret) GetMode() int {
	return s.Mode
}

// GetAdditionalData returns the secret's additional data
func (s *BaseSecret) GetAdditionalData() string {
	return s.AdditionalData
}

// SetKey sets the secret's key
func (s *BaseSecret) SetKey(value string) {
	s.Key = value
}

// SetAdditionalData sets the secret's additional data
func (s *BaseSecret) SetAdditionalData(value string) {
	s.AdditionalData = value
}

// SetStatus sets the secret's status
func (s *BaseSecret) SetStatus(value sdkkms.SecretStatus) {
	s.Status = value
}

// isEmpty reports whether status, payload, key and additional data are all
// unset. Mode is not taken into account.
func (s *BaseSecret) isEmpty() bool {
	return s.Status == "" && s.Payload == "" && s.Key == "" && s.AdditionalData == ""
}

================================================
FILE: internal/kms/builtin.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package kms

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"io"

	sdkkms "github.com/sftpgo/sdk/kms"

	"github.com/drakkan/sftpgo/v2/internal/util"
)

var (
	errMalformedCiphertext = errors.New("malformed ciphertext")
)

// builtinSecret encrypts secrets locally with AES-256-GCM. A random
// per-secret key is stored hex encoded alongside the hex encoded ciphertext.
type builtinSecret struct {
	BaseSecret
}

func init() {
	RegisterSecretProvider(sdkkms.SchemeBuiltin, sdkkms.SecretStatusAES256GCM, newBuiltinSecret)
}

func newBuiltinSecret(base BaseSecret, _, _ string) SecretProvider {
	return &builtinSecret{
		BaseSecret: base,
	}
}

// Name returns the provider name
func (s *builtinSecret) Name() string {
	return "Builtin"
}

// IsEncrypted reports whether the secret is in the AES-256-GCM status
func (s *builtinSecret) IsEncrypted() bool {
	return s.Status == sdkkms.SecretStatusAES256GCM
}

// deriveKey returns the SHA-256 digest of key, followed by the additional
// data (if any), followed by key again. The digest is the actual AES-256 key.
func (s *builtinSecret) deriveKey(key []byte) []byte {
	material := make([]byte, 0, 2*len(key)+len(s.AdditionalData))
	material = append(material, key...)
	// appending an empty AdditionalData adds nothing to the material
	material = append(material, s.AdditionalData...)
	material = append(material, key...)
	digest := sha256.Sum256(material)
	return digest[:]
}

// newAEAD builds the AES-GCM AEAD keyed with the value derived from key.
func (s *builtinSecret) newAEAD(key []byte) (cipher.AEAD, error) {
	block, err := aes.NewCipher(s.deriveKey(key))
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// aad returns the additional data authenticated by GCM, nil if none is set.
func (s *builtinSecret) aad() []byte {
	if s.AdditionalData == "" {
		return nil
	}
	return []byte(s.AdditionalData)
}

// Encrypt encrypts a plain, non-empty payload using a freshly generated random
// key and nonce. On success the key and the nonce-prefixed ciphertext are
// stored hex encoded and the status becomes AES-256-GCM.
func (s *builtinSecret) Encrypt() error {
	if s.Payload == "" {
		return ErrInvalidSecret
	}
	if s.Status != sdkkms.SecretStatusPlain {
		return ErrWrongSecretStatus
	}

	key := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, key); err != nil {
		return err
	}
	aead, err := s.newAEAD(key)
	if err != nil {
		return err
	}
	nonce := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return err
	}
	// nonce is also the dst slice, so the sealed output is prefixed by it
	sealed := aead.Seal(nonce, nonce, []byte(s.Payload), s.aad())

	s.Key = hex.EncodeToString(key)
	s.Payload = hex.EncodeToString(sealed)
	s.Status = sdkkms.SecretStatusAES256GCM
	return nil
}

// Decrypt reverses Encrypt: on success the payload becomes the plaintext, the
// status becomes plain and both key and additional data are cleared.
func (s *builtinSecret) Decrypt() error {
	if s.Status != sdkkms.SecretStatusAES256GCM {
		return ErrWrongSecretStatus
	}

	sealed, err := hex.DecodeString(s.Payload)
	if err != nil {
		return err
	}
	key, err := hex.DecodeString(s.Key)
	if err != nil {
		return err
	}
	aead, err := s.newAEAD(key)
	if err != nil {
		return err
	}
	nonceSize := aead.NonceSize()
	if len(sealed) < nonceSize {
		return errMalformedCiphertext
	}
	plaintext, err := aead.Open(nil, sealed[:nonceSize], sealed[nonceSize:], s.aad())
	if err != nil {
		return err
	}

	s.Status = sdkkms.SecretStatusPlain
	s.Payload = util.BytesToString(plaintext)
	s.Key = ""
	s.AdditionalData = ""
	return nil
}

// Clone returns a new builtin provider holding a copy of the secret fields
func (s *builtinSecret) Clone() SecretProvider {
	return newBuiltinSecret(BaseSecret{
		Status:         s.Status,
		Payload:        s.Payload,
		Key:            s.Key,
		AdditionalData: s.AdditionalData,
		Mode:           s.Mode,
	}, "", "")
}

================================================
FILE: internal/kms/kms.go
================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package kms provides Key Management Services support package kms import ( "encoding/json" "errors" "strings" "sync" sdkkms "github.com/sftpgo/sdk/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) // SecretProvider defines the interface for a KMS secrets provider type SecretProvider interface { Name() string Encrypt() error Decrypt() error IsEncrypted() bool GetStatus() sdkkms.SecretStatus GetPayload() string GetKey() string GetAdditionalData() string GetMode() int SetKey(string) SetAdditionalData(string) SetStatus(sdkkms.SecretStatus) Clone() SecretProvider } const ( logSender = "kms" ) // Configuration defines the KMS configuration type Configuration struct { Secrets Secrets `json:"secrets" mapstructure:"secrets"` } // Secrets define the KMS configuration for encryption/decryption type Secrets struct { URL string `json:"url" mapstructure:"url"` MasterKeyPath string `json:"master_key_path" mapstructure:"master_key_path"` MasterKeyString string `json:"master_key" mapstructure:"master_key"` masterKey string } type registeredSecretProvider struct { encryptedStatus sdkkms.SecretStatus newFn func(base BaseSecret, url, masterKey string) SecretProvider } var ( // ErrWrongSecretStatus defines the error to return if the secret status is not appropriate // for the request 
operation ErrWrongSecretStatus = errors.New("wrong secret status") // ErrInvalidSecret defines the error to return if a secret is not valid ErrInvalidSecret = errors.New("invalid secret") validSecretStatuses = []string{sdkkms.SecretStatusPlain, sdkkms.SecretStatusAES256GCM, sdkkms.SecretStatusSecretBox, sdkkms.SecretStatusVaultTransit, sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP, sdkkms.SecretStatusAzureKeyVault, sdkkms.SecretStatusOracleKeyVault, sdkkms.SecretStatusRedacted} config Configuration secretProviders = make(map[string]registeredSecretProvider) ) // RegisterSecretProvider register a new secret provider func RegisterSecretProvider(scheme string, encryptedStatus sdkkms.SecretStatus, fn func(base BaseSecret, url, masterKey string) SecretProvider, ) { secretProviders[scheme] = registeredSecretProvider{ encryptedStatus: encryptedStatus, newFn: fn, } } // NewSecret builds a new Secret using the provided arguments func NewSecret(status sdkkms.SecretStatus, payload, key, data string) *Secret { return config.newSecret(status, payload, key, data) } // NewEmptySecret returns an empty secret func NewEmptySecret() *Secret { return NewSecret("", "", "", "") } // NewPlainSecret stores the give payload in a plain text secret func NewPlainSecret(payload string) *Secret { return NewSecret(sdkkms.SecretStatusPlain, strings.TrimSpace(payload), "", "") } // Initialize configures the KMS support func (c *Configuration) Initialize() error { if c.Secrets.MasterKeyPath != "" { mKey, err := util.ReadConfigFromFile(c.Secrets.MasterKeyPath, "") if err != nil { return err } c.Secrets.masterKey = mKey } else if c.Secrets.MasterKeyString != "" { c.Secrets.masterKey = c.Secrets.MasterKeyString } config = *c if config.Secrets.URL == "" { config.Secrets.URL = sdkkms.SchemeLocal + "://" } for k, v := range secretProviders { logger.Info(logSender, "", "secret provider registered for scheme: %q, encrypted status: %q", k, v.encryptedStatus) } return nil } func (c *Configuration) 
newSecret(status sdkkms.SecretStatus, payload, key, data string) *Secret { base := BaseSecret{ Status: status, Key: key, Payload: payload, AdditionalData: data, } return &Secret{ provider: c.getSecretProvider(base), } } func (c *Configuration) getSecretProvider(base BaseSecret) SecretProvider { for k, v := range secretProviders { if strings.HasPrefix(c.Secrets.URL, k) { return v.newFn(base, c.Secrets.URL, c.Secrets.masterKey) } } logger.Warn(logSender, "", "no secret provider registered for URL %v, fallback to local provider", c.Secrets.URL) return NewLocalSecret(base, c.Secrets.URL, c.Secrets.masterKey) } // Secret defines the struct used to store confidential data type Secret struct { sync.RWMutex provider SecretProvider } // MarshalJSON return the JSON encoding of the Secret object func (s *Secret) MarshalJSON() ([]byte, error) { s.RLock() defer s.RUnlock() return json.Marshal(&BaseSecret{ Status: s.provider.GetStatus(), Payload: s.provider.GetPayload(), Key: s.provider.GetKey(), AdditionalData: s.provider.GetAdditionalData(), Mode: s.provider.GetMode(), }) } // UnmarshalJSON parses the JSON-encoded data and stores the result // in the Secret object func (s *Secret) UnmarshalJSON(data []byte) error { s.Lock() defer s.Unlock() baseSecret := BaseSecret{} err := json.Unmarshal(data, &baseSecret) if err != nil { return err } if baseSecret.isEmpty() { s.provider = config.getSecretProvider(baseSecret) return nil } if baseSecret.Status == sdkkms.SecretStatusPlain || baseSecret.Status == sdkkms.SecretStatusRedacted { s.provider = config.getSecretProvider(baseSecret) return nil } for _, v := range secretProviders { if v.encryptedStatus == baseSecret.Status { s.provider = v.newFn(baseSecret, config.Secrets.URL, config.Secrets.masterKey) return nil } } logger.Error(logSender, "", "no provider registered for status %q", baseSecret.Status) return ErrInvalidSecret } // IsEqual returns true if all the secrets fields are equal func (s *Secret) IsEqual(other *Secret) bool { if 
s.GetStatus() != other.GetStatus() { return false } if s.GetPayload() != other.GetPayload() { return false } if s.GetKey() != other.GetKey() { return false } if s.GetAdditionalData() != other.GetAdditionalData() { return false } if s.GetMode() != other.GetMode() { return false } return true } // Clone returns a copy of the secret object func (s *Secret) Clone() *Secret { s.RLock() defer s.RUnlock() return &Secret{ provider: s.provider.Clone(), } } // IsEncrypted returns true if the secret is encrypted // This isn't a pointer receiver because we don't want to pass // a pointer to html template func (s *Secret) IsEncrypted() bool { s.RLock() defer s.RUnlock() return s.provider.IsEncrypted() } // IsPlain returns true if the secret is in plain text func (s *Secret) IsPlain() bool { s.RLock() defer s.RUnlock() return s.provider.GetStatus() == sdkkms.SecretStatusPlain } // IsNotPlainAndNotEmpty returns true if the secret is not plain and not empty. // This is an utility method, we update the secret for an existing user // if it is empty or plain func (s *Secret) IsNotPlainAndNotEmpty() bool { s.RLock() defer s.RUnlock() return !s.IsPlain() && !s.IsEmpty() } // IsRedacted returns true if the secret is redacted func (s *Secret) IsRedacted() bool { s.RLock() defer s.RUnlock() return s.provider.GetStatus() == sdkkms.SecretStatusRedacted } // GetPayload returns the secret payload func (s *Secret) GetPayload() string { s.RLock() defer s.RUnlock() return s.provider.GetPayload() } // GetAdditionalData returns the secret additional data func (s *Secret) GetAdditionalData() string { s.RLock() defer s.RUnlock() return s.provider.GetAdditionalData() } // GetStatus returns the secret status func (s *Secret) GetStatus() sdkkms.SecretStatus { s.RLock() defer s.RUnlock() return s.provider.GetStatus() } // GetKey returns the secret key func (s *Secret) GetKey() string { s.RLock() defer s.RUnlock() return s.provider.GetKey() } // GetMode returns the secret mode func (s *Secret) GetMode() 
int { s.RLock() defer s.RUnlock() return s.provider.GetMode() } // SetAdditionalData sets the given additional data func (s *Secret) SetAdditionalData(value string) { s.Lock() defer s.Unlock() s.provider.SetAdditionalData(value) } // SetStatus sets the status for this secret func (s *Secret) SetStatus(value sdkkms.SecretStatus) { s.Lock() defer s.Unlock() s.provider.SetStatus(value) } // SetKey sets the key for this secret func (s *Secret) SetKey(value string) { s.Lock() defer s.Unlock() s.provider.SetKey(value) } // IsEmpty returns true if all fields are empty func (s *Secret) IsEmpty() bool { s.RLock() defer s.RUnlock() if s.provider.GetStatus() != "" { return false } if s.provider.GetPayload() != "" { return false } if s.provider.GetKey() != "" { return false } if s.provider.GetAdditionalData() != "" { return false } return true } // IsValid returns true if the secret is not empty and valid func (s *Secret) IsValid() bool { s.RLock() defer s.RUnlock() if !s.IsValidInput() { return false } switch s.provider.GetStatus() { case sdkkms.SecretStatusAES256GCM, sdkkms.SecretStatusSecretBox: if len(s.provider.GetKey()) != 64 { return false } case sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP, sdkkms.SecretStatusVaultTransit: key := s.provider.GetKey() if key != "" && len(key) != 64 { return false } } return true } // IsValidInput returns true if the secret is a valid user input func (s *Secret) IsValidInput() bool { s.RLock() defer s.RUnlock() if !isSecretStatusValid(s.provider.GetStatus()) { return false } if s.provider.GetPayload() == "" { return false } return true } // Hide hides info to decrypt data func (s *Secret) Hide() { s.Lock() defer s.Unlock() s.provider.SetKey("") s.provider.SetAdditionalData("") } // Encrypt encrypts a plain text Secret object func (s *Secret) Encrypt() error { s.Lock() defer s.Unlock() return s.provider.Encrypt() } // Decrypt decrypts a Secret object func (s *Secret) Decrypt() error { s.Lock() defer s.Unlock() return s.provider.Decrypt() 
} // TryDecrypt decrypts a Secret object if encrypted. // It returns a nil error if the object is not encrypted func (s *Secret) TryDecrypt() error { s.Lock() defer s.Unlock() if s.provider.IsEncrypted() { return s.provider.Decrypt() } return nil } func isSecretStatusValid(status string) bool { for idx := range validSecretStatuses { if validSecretStatuses[idx] == status { return true } } return false } ================================================ FILE: internal/kms/local.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package kms import ( "context" "crypto/sha256" "encoding/base64" "encoding/hex" "io" sdkkms "github.com/sftpgo/sdk/kms" "gocloud.dev/secrets/localsecrets" "golang.org/x/crypto/hkdf" "github.com/drakkan/sftpgo/v2/internal/util" ) func init() { RegisterSecretProvider(sdkkms.SchemeLocal, sdkkms.SecretStatusSecretBox, NewLocalSecret) } type localSecret struct { BaseSecret masterKey string } // NewLocalSecret returns a SecretProvider that use a locally provided symmetric key func NewLocalSecret(base BaseSecret, _, masterKey string) SecretProvider { return &localSecret{ BaseSecret: base, masterKey: masterKey, } } func (s *localSecret) Name() string { return "Local" } func (s *localSecret) IsEncrypted() bool { return s.Status == sdkkms.SecretStatusSecretBox } func (s *localSecret) Encrypt() error { if s.Status != sdkkms.SecretStatusPlain { return ErrWrongSecretStatus } if s.Payload == "" { return ErrInvalidSecret } secretKey, err := localsecrets.NewRandomKey() if err != nil { return err } key, err := s.deriveKey(secretKey[:], false) if err != nil { return err } keeper := localsecrets.NewKeeper(key) defer keeper.Close() ciphertext, err := keeper.Encrypt(context.Background(), []byte(s.Payload)) if err != nil { return err } s.Key = hex.EncodeToString(secretKey[:]) s.Payload = base64.StdEncoding.EncodeToString(ciphertext) s.Status = sdkkms.SecretStatusSecretBox s.Mode = s.getEncryptionMode() return nil } func (s *localSecret) Decrypt() error { if !s.IsEncrypted() { return ErrWrongSecretStatus } encrypted, err := base64.StdEncoding.DecodeString(s.Payload) if err != nil { return err } secretKey, err := hex.DecodeString(s.Key) if err != nil { return err } key, err := s.deriveKey(secretKey[:], true) if err != nil { return err } keeper := localsecrets.NewKeeper(key) defer keeper.Close() plaintext, err := keeper.Decrypt(context.Background(), encrypted) if err != nil { return err } s.Status = sdkkms.SecretStatusPlain s.Payload = util.BytesToString(plaintext) s.Key = "" 
s.AdditionalData = "" s.Mode = 0 return nil } func (s *localSecret) deriveKey(key []byte, isForDecryption bool) ([32]byte, error) { var masterKey []byte if s.masterKey == "" || (isForDecryption && s.Mode == 0) { var combined []byte combined = append(combined, key...) if s.AdditionalData != "" { combined = append(combined, []byte(s.AdditionalData)...) } combined = append(combined, key...) hash := sha256.Sum256(combined) masterKey = hash[:] } else { masterKey = []byte(s.masterKey) } var derivedKey [32]byte var info []byte if s.AdditionalData != "" { info = []byte(s.AdditionalData) } kdf := hkdf.New(sha256.New, masterKey, key, info) if _, err := io.ReadFull(kdf, derivedKey[:]); err != nil { return derivedKey, err } return derivedKey, nil } func (s *localSecret) getEncryptionMode() int { if s.masterKey == "" { return 0 } return 1 } func (s *localSecret) Clone() SecretProvider { baseSecret := BaseSecret{ Status: s.Status, Payload: s.Payload, Key: s.Key, AdditionalData: s.AdditionalData, Mode: s.Mode, } return NewLocalSecret(baseSecret, "", s.masterKey) } ================================================ FILE: internal/logger/hclog.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package logger import ( "io" "log" "github.com/hashicorp/go-hclog" "github.com/rs/zerolog" ) // HCLogAdapter is an adapter for hclog.Logger type HCLogAdapter struct { hclog.Logger } // Log emits a message and key/value pairs at a provided log level func (l *HCLogAdapter) Log(level hclog.Level, msg string, args ...any) { // Workaround to avoid logging plugin arguments that may contain sensitive data. // Check everytime we update go-plugin library. if msg == "starting plugin" { return } var ev *zerolog.Event switch level { case hclog.Info: ev = logger.Info() case hclog.Warn: ev = logger.Warn() case hclog.Error: ev = logger.Error() default: ev = logger.Debug() } ev.Timestamp().Str("sender", l.Name()) addKeysAndValues(ev, args...) ev.Msg(msg) } // Trace emits a message and key/value pairs at the TRACE level func (l *HCLogAdapter) Trace(msg string, args ...any) { l.Log(hclog.Debug, msg, args...) } // Debug emits a message and key/value pairs at the DEBUG level func (l *HCLogAdapter) Debug(msg string, args ...any) { l.Log(hclog.Debug, msg, args...) } // Info emits a message and key/value pairs at the INFO level func (l *HCLogAdapter) Info(msg string, args ...any) { l.Log(hclog.Info, msg, args...) } // Warn emits a message and key/value pairs at the WARN level func (l *HCLogAdapter) Warn(msg string, args ...any) { l.Log(hclog.Warn, msg, args...) } // Error emits a message and key/value pairs at the ERROR level func (l *HCLogAdapter) Error(msg string, args ...any) { l.Log(hclog.Error, msg, args...) 
} // With creates a sub-logger func (l *HCLogAdapter) With(args ...any) hclog.Logger { return &HCLogAdapter{Logger: l.Logger.With(args...)} } // Named creates a logger that will prepend the name string on the front of all messages func (l *HCLogAdapter) Named(name string) hclog.Logger { return &HCLogAdapter{Logger: l.Logger.Named(name)} } // StandardLogger returns a value that conforms to the stdlib log.Logger interface func (l *HCLogAdapter) StandardLogger(_ *hclog.StandardLoggerOptions) *log.Logger { return log.New(&StdLoggerWrapper{Sender: l.Name()}, "", 0) } // StandardWriter returns a value that conforms to io.Writer, which can be passed into log.SetOutput() func (l *HCLogAdapter) StandardWriter(_ *hclog.StandardLoggerOptions) io.Writer { return &StdLoggerWrapper{Sender: l.Name()} } ================================================ FILE: internal/logger/lego.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package logger import "fmt" const ( legoLogSender = "lego" ) // LegoAdapter is an adapter for lego.StdLogger type LegoAdapter struct { LogToConsole bool } // Fatal emits a log at Error level func (l *LegoAdapter) Fatal(args ...any) { if l.LogToConsole { ErrorToConsole("%s", fmt.Sprint(args...)) return } Log(LevelError, legoLogSender, "", "%s", fmt.Sprint(args...)) } // Fatalln is the same as Fatal func (l *LegoAdapter) Fatalln(args ...any) { l.Fatal(args...) 
} // Fatalf emits a log at Error level func (l *LegoAdapter) Fatalf(format string, args ...any) { if l.LogToConsole { ErrorToConsole(format, args...) return } Log(LevelError, legoLogSender, "", format, args...) } // Print emits a log at Info level func (l *LegoAdapter) Print(args ...any) { if l.LogToConsole { InfoToConsole("%s", fmt.Sprint(args...)) return } Log(LevelInfo, legoLogSender, "", "%s", fmt.Sprint(args...)) } // Println is the same as Print func (l *LegoAdapter) Println(args ...any) { l.Print(args...) } // Printf emits a log at Info level func (l *LegoAdapter) Printf(format string, args ...any) { if l.LogToConsole { InfoToConsole(format, args...) return } Log(LevelInfo, legoLogSender, "", format, args...) } ================================================ FILE: internal/logger/logger.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package logger provides logging capabilities. // It is a wrapper around zerolog for logging and lumberjack for log rotation. // Logs are written to the specified log file. // Logging on the console is provided to print initialization info, errors and warnings. // The package provides a request logger to log the HTTP requests for REST API too. 
// The request logger uses chi.middleware.RequestLogger, // chi.middleware.LogFormatter and chi.middleware.LogEntry to build a structured // logger using zerolog package logger import ( "errors" "fmt" "io/fs" "os" "path/filepath" "time" "github.com/rs/zerolog" lumberjack "gopkg.in/natefinch/lumberjack.v2" ) const ( dateFormat = "2006-01-02T15:04:05.000" // YYYY-MM-DDTHH:MM:SS.ZZZ ) // LogLevel defines log levels. type LogLevel uint8 // defines our own log levels, just in case we'll change logger in future const ( LevelDebug LogLevel = iota LevelInfo LevelWarn LevelError ) var ( logger zerolog.Logger consoleLogger zerolog.Logger rollingLogger *lumberjack.Logger ) func init() { zerolog.TimeFieldFormat = dateFormat } // GetLogger get the configured logger instance func GetLogger() *zerolog.Logger { return &logger } // InitLogger configures the logger using the given parameters func InitLogger(logFilePath string, logMaxSize int, logMaxBackups int, logMaxAge int, logCompress, logUTCTime bool, level zerolog.Level, ) { SetLogTime(logUTCTime) if isLogFilePathValid(logFilePath) { logDir := filepath.Dir(logFilePath) if _, err := os.Stat(logDir); errors.Is(err, fs.ErrNotExist) { err = os.MkdirAll(logDir, os.ModePerm) if err != nil { fmt.Printf("unable to create log dir %q: %v", logDir, err) } } rollingLogger = &lumberjack.Logger{ Filename: logFilePath, MaxSize: logMaxSize, MaxBackups: logMaxBackups, MaxAge: logMaxAge, Compress: logCompress, LocalTime: !logUTCTime, } logger = zerolog.New(rollingLogger) EnableConsoleLogger(level) } else { logger = zerolog.New(&logSyncWrapper{ output: os.Stdout, }) consoleLogger = zerolog.Nop() } logger = logger.Level(level) } // InitStdErrLogger configures the logger to write to stderr func InitStdErrLogger(level zerolog.Level) { logger = zerolog.New(&logSyncWrapper{ output: os.Stderr, }).Level(level) consoleLogger = zerolog.Nop() } // DisableLogger disable the main logger. 
// ConsoleLogger will not be affected func DisableLogger() { logger = zerolog.Nop() rollingLogger = nil } // EnableConsoleLogger enables the console logger func EnableConsoleLogger(level zerolog.Level) { consoleOutput := zerolog.ConsoleWriter{ Out: os.Stdout, TimeFormat: dateFormat, } consoleLogger = zerolog.New(consoleOutput).With().Timestamp().Logger().Level(level) } // RotateLogFile closes the existing log file and immediately create a new one func RotateLogFile() error { if rollingLogger != nil { return rollingLogger.Rotate() } return errors.New("logging to file is disabled") } // SetLogTime sets logging time related setting func SetLogTime(utc bool) { if utc { zerolog.TimestampFunc = func() time.Time { return time.Now().UTC() } } else { zerolog.TimestampFunc = time.Now } } // Log logs at the specified level for the specified sender func Log(level LogLevel, sender string, connectionID string, format string, v ...any) { var ev *zerolog.Event switch level { case LevelDebug: ev = logger.Debug() case LevelInfo: ev = logger.Info() case LevelWarn: ev = logger.Warn() default: ev = logger.Error() } ev.Timestamp().Str("sender", sender) if connectionID != "" { ev.Str("connection_id", connectionID) } ev.Msg(fmt.Sprintf(format, v...)) } // Debug logs at debug level for the specified sender func Debug(sender, connectionID, format string, v ...any) { Log(LevelDebug, sender, connectionID, format, v...) } // Info logs at info level for the specified sender func Info(sender, connectionID, format string, v ...any) { Log(LevelInfo, sender, connectionID, format, v...) } // Warn logs at warn level for the specified sender func Warn(sender, connectionID, format string, v ...any) { Log(LevelWarn, sender, connectionID, format, v...) } // Error logs at error level for the specified sender func Error(sender, connectionID, format string, v ...any) { Log(LevelError, sender, connectionID, format, v...) 
} // DebugToConsole logs at debug level to stdout func DebugToConsole(format string, v ...any) { consoleLogger.Debug().Msg(fmt.Sprintf(format, v...)) } // InfoToConsole logs at info level to stdout func InfoToConsole(format string, v ...any) { consoleLogger.Info().Msg(fmt.Sprintf(format, v...)) } // WarnToConsole logs at info level to stdout func WarnToConsole(format string, v ...any) { consoleLogger.Warn().Msg(fmt.Sprintf(format, v...)) } // ErrorToConsole logs at error level to stdout func ErrorToConsole(format string, v ...any) { consoleLogger.Error().Msg(fmt.Sprintf(format, v...)) } // TransferLog logs uploads or downloads func TransferLog(operation, path string, elapsed int64, size int64, user, connectionID, protocol, localAddr, remoteAddr, ftpMode string, err error, ) { var ev *zerolog.Event if err != nil { ev = logger.Error() } else { ev = logger.Info() } ev. Timestamp(). Str("sender", operation). Str("local_addr", localAddr). Str("remote_addr", remoteAddr). Int64("elapsed_ms", elapsed). Int64("size_bytes", size). Str("username", user). Str("file_path", path). Str("connection_id", connectionID). Str("protocol", protocol) if ftpMode != "" { ev.Str("ftp_mode", ftpMode) } ev.AnErr("error", err).Send() } // CommandLog logs an SFTP/SCP/SSH command func CommandLog(command, path, target, user, fileMode, connectionID, protocol string, uid, gid int, atime, mtime, sshCommand string, size int64, localAddr, remoteAddr string, elapsed int64) { logger.Info(). Timestamp(). Str("sender", command). Str("local_addr", localAddr). Str("remote_addr", remoteAddr). Str("username", user). Str("file_path", path). Str("target_path", target). Str("filemode", fileMode). Int("uid", uid). Int("gid", gid). Str("access_time", atime). Str("modification_time", mtime). Int64("size", size). Int64("elapsed", elapsed). Str("ssh_command", sshCommand). Str("connection_id", connectionID). Str("protocol", protocol). Send() } // ConnectionFailedLog logs failed attempts to initialize a connection. 
// A connection can fail for an authentication error or other errors such as // a client abort or a time out if the login does not happen in two minutes. // These logs are useful for better integration with Fail2ban and similar tools. func ConnectionFailedLog(user, ip, loginType, protocol, errorString string) { logger.Debug(). Timestamp(). Str("sender", "connection_failed"). Str("client_ip", ip). Str("username", user). Str("login_type", loginType). Str("protocol", protocol). Str("error", errorString). Send() } // LoginLog logs successful logins. func LoginLog(user, ip, loginMethod, protocol, connectionID, clientVersion string, encrypted bool, info string) { ev := logger.Info() ev.Timestamp(). Str("sender", "login"). Str("ip", ip). Str("username", user). Str("method", loginMethod). Str("protocol", protocol) if connectionID != "" { ev.Str("connection_id", connectionID) } ev.Str("client", clientVersion). Bool("encrypted", encrypted) if info != "" { ev.Str("info", info) } ev.Send() } func isLogFilePathValid(logFilePath string) bool { cleanInput := filepath.Clean(logFilePath) if cleanInput == "." || cleanInput == ".." { return false } return true } // StdLoggerWrapper is a wrapper for standard logger compatibility type StdLoggerWrapper struct { Sender string } // Write implements the io.Writer interface. This is useful to set as a writer // for the standard library log. func (l *StdLoggerWrapper) Write(p []byte) (n int, err error) { n = len(p) if n > 0 && p[n-1] == '\n' { // Trim CR added by stdlog. 
p = p[0 : n-1] } Log(LevelError, l.Sender, "", "%s", p) return } // LeveledLogger is a logger that accepts a message string and a variadic number of key-value pairs type LeveledLogger struct { Sender string additionalKeyVals []any } func addKeysAndValues(ev *zerolog.Event, keysAndValues ...any) { kvLen := len(keysAndValues) if kvLen%2 != 0 { extra := keysAndValues[kvLen-1] keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra) } for i := 0; i < len(keysAndValues); i += 2 { key, val := keysAndValues[i], keysAndValues[i+1] if keyStr, ok := key.(string); ok && keyStr != "timestamp" { ev.Str(keyStr, fmt.Sprintf("%v", val)) } } } // Error logs at error level for the specified sender func (l *LeveledLogger) Error(msg string, keysAndValues ...any) { ev := logger.Error() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { addKeysAndValues(ev, l.additionalKeyVals...) } addKeysAndValues(ev, keysAndValues...) ev.Msg(msg) } // Info logs at info level for the specified sender func (l *LeveledLogger) Info(msg string, keysAndValues ...any) { ev := logger.Info() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { addKeysAndValues(ev, l.additionalKeyVals...) } addKeysAndValues(ev, keysAndValues...) ev.Msg(msg) } // Debug logs at debug level for the specified sender func (l *LeveledLogger) Debug(msg string, keysAndValues ...any) { ev := logger.Debug() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { addKeysAndValues(ev, l.additionalKeyVals...) } addKeysAndValues(ev, keysAndValues...) ev.Msg(msg) } // Warn logs at warn level for the specified sender func (l *LeveledLogger) Warn(msg string, keysAndValues ...any) { ev := logger.Warn() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { addKeysAndValues(ev, l.additionalKeyVals...) } addKeysAndValues(ev, keysAndValues...) 
ev.Msg(msg) } // Panic logs the panic at error level for the specified sender func (l *LeveledLogger) Panic(msg string, keysAndValues ...any) { l.Error(msg, keysAndValues...) } ================================================ FILE: internal/logger/mail.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package logger import ( "fmt" "github.com/wneessen/go-mail/log" ) const ( mailLogSender = "smtpclient" ) // MailAdapter is an adapter for mail.Logger type MailAdapter struct { ConnectionID string } // Errorf emits a log at Error level func (l *MailAdapter) Errorf(logMsg log.Log) { format := l.getFormatString(&logMsg) ErrorToConsole(format, logMsg.Messages...) Log(LevelError, mailLogSender, l.ConnectionID, format, logMsg.Messages...) } // Warnf emits a log at Warn level func (l *MailAdapter) Warnf(logMsg log.Log) { format := l.getFormatString(&logMsg) WarnToConsole(format, logMsg.Messages...) Log(LevelWarn, mailLogSender, l.ConnectionID, format, logMsg.Messages...) } // Infof emits a log at Info level func (l *MailAdapter) Infof(logMsg log.Log) { format := l.getFormatString(&logMsg) InfoToConsole(format, logMsg.Messages...) Log(LevelInfo, mailLogSender, l.ConnectionID, format, logMsg.Messages...) } // Debugf emits a log at Debug level func (l *MailAdapter) Debugf(logMsg log.Log) { format := l.getFormatString(&logMsg) DebugToConsole(format, logMsg.Messages...) 
Log(LevelDebug, mailLogSender, l.ConnectionID, format, logMsg.Messages...) } func (*MailAdapter) getFormatString(logMsg *log.Log) string { p := "C <-- S:" if logMsg.Direction == log.DirClientToServer { p = "C --> S:" } return fmt.Sprintf("%s %s", p, logMsg.Format) } ================================================ FILE: internal/logger/request_logger.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package logger import ( "crypto/tls" "fmt" "net" "net/http" "time" "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" "github.com/drakkan/sftpgo/v2/internal/metric" ) // StructuredLogger defines a simple wrapper around zerolog logger. // It implements chi.middleware.LogFormatter interface type StructuredLogger struct { Logger *zerolog.Logger } // StructuredLoggerEntry defines a log entry. // It implements chi.middleware.LogEntry interface type StructuredLoggerEntry struct { // The zerolog logger Logger *zerolog.Logger // fields to write in the log fields map[string]any } // NewStructuredLogger returns a chi.middleware.RequestLogger using our StructuredLogger. 
// This structured logger is called by the chi.middleware.Logger handler to log each HTTP request func NewStructuredLogger(logger *zerolog.Logger) func(next http.Handler) http.Handler { return middleware.RequestLogger(&StructuredLogger{logger}) } // NewLogEntry creates a new log entry for an HTTP request func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { scheme := "http" cipherSuite := "" if r.TLS != nil { scheme = "https" cipherSuite = tls.CipherSuiteName(r.TLS.CipherSuite) } fields := map[string]any{ "local_addr": getLocalAddress(r), "remote_addr": r.RemoteAddr, "proto": r.Proto, "method": r.Method, "user_agent": r.UserAgent(), "uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI), "cipher_suite": cipherSuite, } reqID := middleware.GetReqID(r.Context()) if reqID != "" { fields["request_id"] = reqID } return &StructuredLoggerEntry{Logger: l.Logger, fields: fields} } // Write logs a new entry at the end of the HTTP request func (l *StructuredLoggerEntry) Write(status, bytes int, _ http.Header, elapsed time.Duration, _ any) { metric.HTTPRequestServed(status) var ev *zerolog.Event if status >= http.StatusInternalServerError { ev = l.Logger.Error() } else if status >= http.StatusBadRequest { ev = l.Logger.Warn() } else { ev = l.Logger.Debug() } ev. Timestamp(). Str("sender", "httpd"). Fields(l.fields). Int("resp_status", status). Int("resp_size", bytes). Int64("elapsed_ms", elapsed.Nanoseconds()/1000000). Send() } // Panic logs panics func (l *StructuredLoggerEntry) Panic(v any, stack []byte) { l.Logger.Error(). Timestamp(). Str("sender", "httpd"). Fields(l.fields). Str("stack", string(stack)). Str("panic", fmt.Sprintf("%+v", v)). 
Send() } func getLocalAddress(r *http.Request) string { if r == nil { return "" } localAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr) if ok { return localAddr.String() } return "" } ================================================ FILE: internal/logger/slog.go ================================================ // Copyright (C) 2025 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package logger import ( "context" "log/slog" "slices" "github.com/rs/zerolog" ) // slogAdapter is an adapter for slog.Handler type slogAdapter struct { sender string attrs []slog.Attr } // NewSlogAdapter creates a slog.Handler adapter func NewSlogAdapter(sender string, attrs []slog.Attr) *slogAdapter { return &slogAdapter{ sender: sender, attrs: attrs, } } func (l *slogAdapter) Enabled(ctx context.Context, level slog.Level) bool { // Log level is handled by our implementation return true } func (l *slogAdapter) Handle(ctx context.Context, r slog.Record) error { var ev *zerolog.Event switch r.Level { case slog.LevelDebug: ev = logger.Debug() case slog.LevelInfo: ev = logger.Info() case slog.LevelWarn: ev = logger.Warn() case slog.LevelError: ev = logger.Error() default: ev = logger.Debug() } ev.Timestamp() if l.sender != "" { ev.Str("sender", l.sender) } addSlogAttr := func(a slog.Attr) { if a.Key == "time" { return } ev.Any(a.Key, a.Value.Any()) } for _, a := range l.attrs { addSlogAttr(a) } r.Attrs(func(a slog.Attr) bool { addSlogAttr(a) 
return true }) ev.Msg(r.Message) return nil } func (l *slogAdapter) WithAttrs(attrs []slog.Attr) slog.Handler { newHandler := *l newHandler.attrs = slices.Concat(l.attrs, attrs) return &newHandler } func (l *slogAdapter) WithGroup(name string) slog.Handler { newHandler := *l if name != "" { newHandler.sender = name } return &newHandler } ================================================ FILE: internal/logger/sync_wrapper.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package logger import ( "os" "sync" ) type logSyncWrapper struct { sync.Mutex output *os.File } func (l *logSyncWrapper) Write(b []byte) (n int, err error) { l.Lock() defer l.Unlock() return l.output.Write(b) } ================================================ FILE: internal/metric/metric.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !nometrics // Package metric provides Prometheus metrics support package metric import ( "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( loginMethodPublicKey = "publickey" loginMethodKeyboardInteractive = "keyboard-interactive" loginMethodKeyAndPassword = "publickey+password" loginMethodKeyAndKeyboardInt = "publickey+keyboard-interactive" loginMethodTLSCertificate = "TLSCertificate" loginMethodTLSCertificateAndPwd = "TLSCertificate+password" loginMethodIDP = "IDP" ) func init() { version.AddFeature("+metrics") } var ( // dataproviderAvailability is the metric that reports the availability for the configured data provider dataproviderAvailability = promauto.NewGauge(prometheus.GaugeOpts{ Name: "sftpgo_dataprovider_availability", Help: "Availability for the configured data provider, 1 means OK, 0 KO", }) // activeConnections is the metric that reports the total number of active connections activeConnections = promauto.NewGauge(prometheus.GaugeOpts{ Name: "sftpgo_active_connections", Help: "Total number of logged in users", }) // totalUploads is the metric that reports the total number of successful uploads totalUploads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_uploads_total", Help: "The total number of successful uploads", }) // totalDownloads is the metric that reports the total number of successful downloads totalDownloads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_downloads_total", Help: "The total number of successful downloads", }) // totalUploadErrors is the metric that reports the total number of upload errors totalUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: 
"sftpgo_upload_errors_total", Help: "The total number of upload errors", }) // totalDownloadErrors is the metric that reports the total number of download errors totalDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_download_errors_total", Help: "The total number of download errors", }) // totalUploadSize is the metric that reports the total uploads size as bytes totalUploadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_upload_size", Help: "The total upload size as bytes, partial uploads are included", }) // totalDownloadSize is the metric that reports the total downloads size as bytes totalDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_download_size", Help: "The total download size as bytes, partial downloads are included", }) // totalSSHCommands is the metric that reports the total number of executed SSH commands totalSSHCommands = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_ssh_commands_total", Help: "The total number of executed SSH commands", }) // totalSSHCommandErrors is the metric that reports the total number of SSH command errors totalSSHCommandErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_ssh_command_errors_total", Help: "The total number of SSH command errors", }) // totalLoginAttempts is the metric that reports the total number of login attempts totalLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_login_attempts_total", Help: "The total number of login attempts", }) // totalNoAuthTried is te metric that reports the total number of clients disconnected // for inactivity before trying to login totalNoAuthTried = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_no_auth_total", Help: "The total number of clients disconnected for inactivity before trying to login", }) // totalLoginOK is the metric that reports the total number of successful logins totalLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: 
"sftpgo_login_ok_total", Help: "The total number of successful logins", }) // totalLoginFailed is the metric that reports the total number of failed logins totalLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_login_ko_total", Help: "The total number of failed logins", }) // totalPasswordLoginAttempts is the metric that reports the total number of login attempts // using a password totalPasswordLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_password_login_attempts_total", Help: "The total number of login attempts using a password", }) // totalPasswordLoginOK is the metric that reports the total number of successful logins // using a password totalPasswordLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_password_login_ok_total", Help: "The total number of successful logins using a password", }) // totalPasswordLoginFailed is the metric that reports the total number of failed logins // using a password totalPasswordLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_password_login_ko_total", Help: "The total number of failed logins using a password", }) // totalKeyLoginAttempts is the metric that reports the total number of login attempts // using a public key totalKeyLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_public_key_login_attempts_total", Help: "The total number of login attempts using a public key", }) // totalKeyLoginOK is the metric that reports the total number of successful logins // using a public key totalKeyLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_public_key_login_ok_total", Help: "The total number of successful logins using a public key", }) // totalKeyLoginFailed is the metric that reports the total number of failed logins // using a public key totalKeyLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_public_key_login_ko_total", Help: "The total number of failed logins using a public 
key", }) // totalTLSCertLoginAttempts is the metric that reports the total number of login attempts // using a TLS certificate totalTLSCertLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_tls_cert_login_attempts_total", Help: "The total number of login attempts using a TLS certificate", }) // totalTLSCertLoginOK is the metric that reports the total number of successful logins // using a TLS certificate totalTLSCertLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_tls_cert_login_ok_total", Help: "The total number of successful logins using a TLS certificate", }) // totalTLSCertLoginFailed is the metric that reports the total number of failed logins // using a TLS certificate totalTLSCertLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_tls_cert_login_ko_total", Help: "The total number of failed logins using a TLS certificate", }) // totalTLSCertAndPwdLoginAttempts is the metric that reports the total number of login attempts // using a TLS certificate+password totalTLSCertAndPwdLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_tls_cert_and_pwd_login_attempts_total", Help: "The total number of login attempts using a TLS certificate+password", }) // totalTLSCertLoginOK is the metric that reports the total number of successful logins // using a TLS certificate+password totalTLSCertAndPwdLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_tls_cert_and_pwd_login_ok_total", Help: "The total number of successful logins using a TLS certificate+password", }) // totalTLSCertAndPwdLoginFailed is the metric that reports the total number of failed logins // using a TLS certificate+password totalTLSCertAndPwdLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_tls_cert_and_pwd_login_ko_total", Help: "The total number of failed logins using a TLS certificate+password", }) // totalInteractiveLoginAttempts is the metric that reports the total number of login 
attempts // using keyboard interactive authentication totalInteractiveLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_keyboard_interactive_login_attempts_total", Help: "The total number of login attempts using keyboard interactive authentication", }) // totalInteractiveLoginOK is the metric that reports the total number of successful logins // using keyboard interactive authentication totalInteractiveLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_keyboard_interactive_login_ok_total", Help: "The total number of successful logins using keyboard interactive authentication", }) // totalInteractiveLoginFailed is the metric that reports the total number of failed logins // using keyboard interactive authentication totalInteractiveLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_keyboard_interactive_login_ko_total", Help: "The total number of failed logins using keyboard interactive authentication", }) // totalKeyAndPasswordLoginAttempts is the metric that reports the total number of // login attempts using public key + password multi steps auth totalKeyAndPasswordLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_key_and_password_login_attempts_total", Help: "The total number of login attempts using public key + password", }) // totalKeyAndPasswordLoginOK is the metric that reports the total number of // successful logins using public key + password multi steps auth totalKeyAndPasswordLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_key_and_password_login_ok_total", Help: "The total number of successful logins using public key + password", }) // totalKeyAndPasswordLoginFailed is the metric that reports the total number of // failed logins using public key + password multi steps auth totalKeyAndPasswordLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_key_and_password_login_ko_total", Help: "The total number of failed logins using public 
key + password", }) // totalKeyAndKeyIntLoginAttempts is the metric that reports the total number of // login attempts using public key + keyboard interactive multi steps auth totalKeyAndKeyIntLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_key_and_keyboard_int_login_attempts_total", Help: "The total number of login attempts using public key + keyboard interactive", }) // totalKeyAndKeyIntLoginOK is the metric that reports the total number of // successful logins using public key + keyboard interactive multi steps auth totalKeyAndKeyIntLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_key_and_keyboard_int_login_ok_total", Help: "The total number of successful logins using public key + keyboard interactive", }) // totalKeyAndKeyIntLoginFailed is the metric that reports the total number of // failed logins using public key + keyboard interactive multi steps auth totalKeyAndKeyIntLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_key_and_keyboard_int_login_ko_total", Help: "The total number of failed logins using public key + keyboard interactive", }) // totalIDPLoginAttempts is the metric that reports the total number of // login attempts using identity providers totalIDPLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_idp_login_attempts_total", Help: "The total number of login attempts using Identity Providers", }) // totalIDPLoginOK is the metric that reports the total number of // successful logins using identity providers totalIDPLoginOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_idp_login_ok_total", Help: "The total number of successful logins using Identity Providers", }) // totalIDPLoginFailed is the metric that reports the total number of // failed logins using identity providers totalIDPLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_idp_login_ko_total", Help: "The total number of failed logins using Identity Providers", }) 
totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_http_req_total", Help: "The total number of HTTP requests served", }) totalHTTPOK = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_http_req_ok_total", Help: "The total number of HTTP requests served with 2xx status code", }) totalHTTPClientErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_http_client_errors_total", Help: "The total number of HTTP requests served with 4xx status code", }) totalHTTPServerErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_http_server_errors_total", Help: "The total number of HTTP requests served with 5xx status code", }) // totalS3Uploads is the metric that reports the total number of successful S3 uploads totalS3Uploads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_uploads_total", Help: "The total number of successful S3 uploads", }) // totalS3Downloads is the metric that reports the total number of successful S3 downloads totalS3Downloads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_downloads_total", Help: "The total number of successful S3 downloads", }) // totalS3UploadErrors is the metric that reports the total number of S3 upload errors totalS3UploadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_upload_errors_total", Help: "The total number of S3 upload errors", }) // totalS3DownloadErrors is the metric that reports the total number of S3 download errors totalS3DownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_download_errors_total", Help: "The total number of S3 download errors", }) // totalS3UploadSize is the metric that reports the total S3 uploads size as bytes totalS3UploadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_upload_size", Help: "The total S3 upload size as bytes, partial uploads are included", }) // totalS3DownloadSize is the metric that reports the total S3 downloads size as bytes 
totalS3DownloadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_download_size", Help: "The total S3 download size as bytes, partial downloads are included", }) // totalS3ListObjects is the metric that reports the total successful S3 list objects requests totalS3ListObjects = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_list_objects", Help: "The total number of successful S3 list objects requests", }) // totalS3CopyObject is the metric that reports the total successful S3 copy object requests totalS3CopyObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_copy_object", Help: "The total number of successful S3 copy object requests", }) // totalS3DeleteObject is the metric that reports the total successful S3 delete object requests totalS3DeleteObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_delete_object", Help: "The total number of successful S3 delete object requests", }) // totalS3ListObjectsError is the metric that reports the total S3 list objects errors totalS3ListObjectsErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_list_objects_errors", Help: "The total number of S3 list objects errors", }) // totalS3CopyObjectErrors is the metric that reports the total S3 copy object errors totalS3CopyObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_copy_object_errors", Help: "The total number of S3 copy object errors", }) // totalS3DeleteObjectErrors is the metric that reports the total S3 delete object errors totalS3DeleteObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_delete_object_errors", Help: "The total number of S3 delete object errors", }) // totalS3HeadObject is the metric that reports the total successful S3 head object requests totalS3HeadObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_head_object", Help: "The total number of successful S3 head object requests", }) // totalS3HeadObjectErrors 
is the metric that reports the total S3 head object errors totalS3HeadObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_s3_head_object_errors", Help: "The total number of S3 head object errors", }) // totalGCSUploads is the metric that reports the total number of successful GCS uploads totalGCSUploads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_uploads_total", Help: "The total number of successful GCS uploads", }) // totalGCSDownloads is the metric that reports the total number of successful GCS downloads totalGCSDownloads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_downloads_total", Help: "The total number of successful GCS downloads", }) // totalGCSUploadErrors is the metric that reports the total number of GCS upload errors totalGCSUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_upload_errors_total", Help: "The total number of GCS upload errors", }) // totalGCSDownloadErrors is the metric that reports the total number of GCS download errors totalGCSDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_download_errors_total", Help: "The total number of GCS download errors", }) // totalGCSUploadSize is the metric that reports the total GCS uploads size as bytes totalGCSUploadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_upload_size", Help: "The total GCS upload size as bytes, partial uploads are included", }) // totalGCSDownloadSize is the metric that reports the total GCS downloads size as bytes totalGCSDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_download_size", Help: "The total GCS download size as bytes, partial downloads are included", }) // totalGCSListObjects is the metric that reports the total successful GCS list objects requests totalGCSListObjects = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_list_objects", Help: "The total number of successful GCS list objects 
requests", }) // totalGCSCopyObject is the metric that reports the total successful GCS copy object requests totalGCSCopyObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_copy_object", Help: "The total number of successful GCS copy object requests", }) // totalGCSDeleteObject is the metric that reports the total successful GCS delete object requests totalGCSDeleteObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_delete_object", Help: "The total number of successful GCS delete object requests", }) // totalGCSListObjectsError is the metric that reports the total GCS list objects errors totalGCSListObjectsErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_list_objects_errors", Help: "The total number of GCS list objects errors", }) // totalGCSCopyObjectErrors is the metric that reports the total GCS copy object errors totalGCSCopyObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_copy_object_errors", Help: "The total number of GCS copy object errors", }) // totalGCSDeleteObjectErrors is the metric that reports the total GCS delete object errors totalGCSDeleteObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_delete_object_errors", Help: "The total number of GCS delete object errors", }) // totalGCSHeadObject is the metric that reports the total successful GCS head object requests totalGCSHeadObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_head_object", Help: "The total number of successful GCS head object requests", }) // totalGCSHeadObjectErrors is the metric that reports the total GCS head object errors totalGCSHeadObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_gcs_head_object_errors", Help: "The total number of GCS head object errors", }) // totalAZUploads is the metric that reports the total number of successful Azure uploads totalAZUploads = promauto.NewCounter(prometheus.CounterOpts{ Name: 
"sftpgo_az_uploads_total", Help: "The total number of successful Azure uploads", }) // totalAZDownloads is the metric that reports the total number of successful Azure downloads totalAZDownloads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_downloads_total", Help: "The total number of successful Azure downloads", }) // totalAZUploadErrors is the metric that reports the total number of Azure upload errors totalAZUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_upload_errors_total", Help: "The total number of Azure upload errors", }) // totalAZDownloadErrors is the metric that reports the total number of Azure download errors totalAZDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_download_errors_total", Help: "The total number of Azure download errors", }) // totalAZUploadSize is the metric that reports the total Azure uploads size as bytes totalAZUploadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_upload_size", Help: "The total Azure upload size as bytes, partial uploads are included", }) // totalAZDownloadSize is the metric that reports the total Azure downloads size as bytes totalAZDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_download_size", Help: "The total Azure download size as bytes, partial downloads are included", }) // totalAZListObjects is the metric that reports the total successful Azure list objects requests totalAZListObjects = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_list_objects", Help: "The total number of successful Azure list objects requests", }) // totalAZCopyObject is the metric that reports the total successful Azure copy object requests totalAZCopyObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_copy_object", Help: "The total number of successful Azure copy object requests", }) // totalAZDeleteObject is the metric that reports the total successful Azure delete object requests 
totalAZDeleteObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_delete_object", Help: "The total number of successful Azure delete object requests", }) // totalAZListObjectsError is the metric that reports the total Azure list objects errors totalAZListObjectsErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_list_objects_errors", Help: "The total number of Azure list objects errors", }) // totalAZCopyObjectErrors is the metric that reports the total Azure copy object errors totalAZCopyObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_copy_object_errors", Help: "The total number of Azure copy object errors", }) // totalAZDeleteObjectErrors is the metric that reports the total Azure delete object errors totalAZDeleteObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_delete_object_errors", Help: "The total number of Azure delete object errors", }) // totalAZHeadObject is the metric that reports the total successful Azure head object requests totalAZHeadObject = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_head_object", Help: "The total number of successful Azure head object requests", }) // totalAZHeadObjectErrors is the metric that reports the total Azure head object errors totalAZHeadObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_az_head_object_errors", Help: "The total number of Azure head object errors", }) // totalSFTPFsUploads is the metric that reports the total number of successful SFTPFs uploads totalSFTPFsUploads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_sftpfs_uploads_total", Help: "The total number of successful SFTPFs uploads", }) // totalSFTPFsDownloads is the metric that reports the total number of successful SFTPFs downloads totalSFTPFsDownloads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_sftpfs_downloads_total", Help: "The total number of successful SFTPFs downloads", }) // 
totalSFTPFsUploadErrors is the metric that reports the total number of SFTPFs upload errors totalSFTPFsUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_sftpfs_upload_errors_total", Help: "The total number of SFTPFs upload errors", }) // totalSFTPFsDownloadErrors is the metric that reports the total number of SFTPFs download errors totalSFTPFsDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_sftpfs_download_errors_total", Help: "The total number of SFTPFs download errors", }) // totalSFTPFsUploadSize is the metric that reports the total SFTPFs uploads size as bytes totalSFTPFsUploadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_sftpfs_upload_size", Help: "The total SFTPFs upload size as bytes, partial uploads are included", }) // totalSFTPFsDownloadSize is the metric that reports the total SFTPFs downloads size as bytes totalSFTPFsDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_sftpfs_download_size", Help: "The total SFTPFs download size as bytes, partial downloads are included", }) // totalHTTPFsUploads is the metric that reports the total number of successful HTTPFs uploads totalHTTPFsUploads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_httpfs_uploads_total", Help: "The total number of successful HTTPFs uploads", }) // totalHTTPFsDownloads is the metric that reports the total number of successful HTTPFs downloads totalHTTPFsDownloads = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_httpfs_downloads_total", Help: "The total number of successful HTTPFs downloads", }) // totalHTTPFsUploadErrors is the metric that reports the total number of HTTPFs upload errors totalHTTPFsUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_httpfs_upload_errors_total", Help: "The total number of HTTPFs upload errors", }) // totalHTTPFsDownloadErrors is the metric that reports the total number of HTTPFs download errors totalHTTPFsDownloadErrors = 
promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_httpfs_download_errors_total", Help: "The total number of HTTPFs download errors", }) // totalHTTPFsUploadSize is the metric that reports the total HTTPFs uploads size as bytes totalHTTPFsUploadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_httpfs_upload_size", Help: "The total HTTPFs upload size as bytes, partial uploads are included", }) // totalHTTPFsDownloadSize is the metric that reports the total HTTPFs downloads size as bytes totalHTTPFsDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ Name: "sftpgo_httpfs_download_size", Help: "The total HTTPFs download size as bytes, partial downloads are included", }) ) // AddMetricsEndpoint publishes metrics to the specified endpoint func AddMetricsEndpoint(metricsPath string, handler chi.Router) { handler.Handle(metricsPath, promhttp.Handler()) } // TransferCompleted updates metrics after an upload or a download func TransferCompleted(bytesSent, bytesReceived int64, transferKind int, err error, isSFTPFs bool) { if transferKind == 0 { // upload if err == nil { totalUploads.Inc() } else { totalUploadErrors.Inc() } } else { // download if err == nil { totalDownloads.Inc() } else { totalDownloadErrors.Inc() } } if bytesReceived > 0 { totalUploadSize.Add(float64(bytesReceived)) } if bytesSent > 0 { totalDownloadSize.Add(float64(bytesSent)) } if isSFTPFs { sftpFsTransferCompleted(bytesSent, bytesReceived, transferKind, err) } } // S3TransferCompleted updates metrics after an S3 upload or a download func S3TransferCompleted(bytes int64, transferKind int, err error) { if transferKind == 0 { // upload if err == nil { totalS3Uploads.Inc() } else { totalS3UploadErrors.Inc() } totalS3UploadSize.Add(float64(bytes)) } else { // download if err == nil { totalS3Downloads.Inc() } else { totalS3DownloadErrors.Inc() } totalS3DownloadSize.Add(float64(bytes)) } } // S3ListObjectsCompleted updates metrics after an S3 list objects request terminates func 
S3ListObjectsCompleted(err error) { if err == nil { totalS3ListObjects.Inc() } else { totalS3ListObjectsErrors.Inc() } } // S3CopyObjectCompleted updates metrics after an S3 copy object request terminates func S3CopyObjectCompleted(err error) { if err == nil { totalS3CopyObject.Inc() } else { totalS3CopyObjectErrors.Inc() } } // S3DeleteObjectCompleted updates metrics after an S3 delete object request terminates func S3DeleteObjectCompleted(err error) { if err == nil { totalS3DeleteObject.Inc() } else { totalS3DeleteObjectErrors.Inc() } } // S3HeadObjectCompleted updates metrics after a S3 head object request terminates func S3HeadObjectCompleted(err error) { if err == nil { totalS3HeadObject.Inc() } else { totalS3HeadObjectErrors.Inc() } } // GCSTransferCompleted updates metrics after a GCS upload or a download func GCSTransferCompleted(bytes int64, transferKind int, err error) { if transferKind == 0 { // upload if err == nil { totalGCSUploads.Inc() } else { totalGCSUploadErrors.Inc() } totalGCSUploadSize.Add(float64(bytes)) } else { // download if err == nil { totalGCSDownloads.Inc() } else { totalGCSDownloadErrors.Inc() } totalGCSDownloadSize.Add(float64(bytes)) } } // GCSListObjectsCompleted updates metrics after a GCS list objects request terminates func GCSListObjectsCompleted(err error) { if err == nil { totalGCSListObjects.Inc() } else { totalGCSListObjectsErrors.Inc() } } // GCSCopyObjectCompleted updates metrics after a GCS copy object request terminates func GCSCopyObjectCompleted(err error) { if err == nil { totalGCSCopyObject.Inc() } else { totalGCSCopyObjectErrors.Inc() } } // GCSDeleteObjectCompleted updates metrics after a GCS delete object request terminates func GCSDeleteObjectCompleted(err error) { if err == nil { totalGCSDeleteObject.Inc() } else { totalGCSDeleteObjectErrors.Inc() } } // GCSHeadObjectCompleted updates metrics after a GCS head object request terminates func GCSHeadObjectCompleted(err error) { if err == nil { 
totalGCSHeadObject.Inc() } else { totalGCSHeadObjectErrors.Inc() } } // AZTransferCompleted updates metrics after a Azure upload or a download func AZTransferCompleted(bytes int64, transferKind int, err error) { if transferKind == 0 { // upload if err == nil { totalAZUploads.Inc() } else { totalAZUploadErrors.Inc() } totalAZUploadSize.Add(float64(bytes)) } else { // download if err == nil { totalAZDownloads.Inc() } else { totalAZDownloadErrors.Inc() } totalAZDownloadSize.Add(float64(bytes)) } } // AZListObjectsCompleted updates metrics after a Azure list objects request terminates func AZListObjectsCompleted(err error) { if err == nil { totalAZListObjects.Inc() } else { totalAZListObjectsErrors.Inc() } } // AZCopyObjectCompleted updates metrics after a Azure copy object request terminates func AZCopyObjectCompleted(err error) { if err == nil { totalAZCopyObject.Inc() } else { totalAZCopyObjectErrors.Inc() } } // AZDeleteObjectCompleted updates metrics after a Azure delete object request terminates func AZDeleteObjectCompleted(err error) { if err == nil { totalAZDeleteObject.Inc() } else { totalAZDeleteObjectErrors.Inc() } } // AZHeadObjectCompleted updates metrics after a Azure head object request terminates func AZHeadObjectCompleted(err error) { if err == nil { totalAZHeadObject.Inc() } else { totalAZHeadObjectErrors.Inc() } } // sftpFsTransferCompleted updates metrics after an SFTPFs upload or a download func sftpFsTransferCompleted(bytesSent, bytesReceived int64, transferKind int, err error) { if transferKind == 0 { // upload if err == nil { totalSFTPFsUploads.Inc() } else { totalSFTPFsUploadErrors.Inc() } } else { // download if err == nil { totalSFTPFsDownloads.Inc() } else { totalSFTPFsDownloadErrors.Inc() } } if bytesReceived > 0 { totalSFTPFsUploadSize.Add(float64(bytesReceived)) } if bytesSent > 0 { totalSFTPFsDownloadSize.Add(float64(bytesSent)) } } // HTTPFsTransferCompleted updates metrics after an HTTPFs upload or a download func 
HTTPFsTransferCompleted(bytes int64, transferKind int, err error) { if transferKind == 0 { // upload if err == nil { totalHTTPFsUploads.Inc() } else { totalHTTPFsUploadErrors.Inc() } totalHTTPFsUploadSize.Add(float64(bytes)) } else { // download if err == nil { totalHTTPFsDownloads.Inc() } else { totalHTTPFsDownloadErrors.Inc() } totalHTTPFsDownloadSize.Add(float64(bytes)) } } // SSHCommandCompleted update metrics after an SSH command terminates func SSHCommandCompleted(err error) { if err == nil { totalSSHCommands.Inc() } else { totalSSHCommandErrors.Inc() } } // UpdateDataProviderAvailability updates the metric for the data provider availability func UpdateDataProviderAvailability(err error) { if err == nil { dataproviderAvailability.Set(1) } else { dataproviderAvailability.Set(0) } } // AddLoginAttempt increments the metrics for login attempts func AddLoginAttempt(authMethod string) { totalLoginAttempts.Inc() switch authMethod { case loginMethodPublicKey: totalKeyLoginAttempts.Inc() case loginMethodKeyboardInteractive: totalInteractiveLoginAttempts.Inc() case loginMethodKeyAndPassword: totalKeyAndPasswordLoginAttempts.Inc() case loginMethodKeyAndKeyboardInt: totalKeyAndKeyIntLoginAttempts.Inc() case loginMethodTLSCertificate: totalTLSCertLoginAttempts.Inc() case loginMethodTLSCertificateAndPwd: totalTLSCertAndPwdLoginAttempts.Inc() case loginMethodIDP: totalIDPLoginAttempts.Inc() default: totalPasswordLoginAttempts.Inc() } } func incLoginOK(authMethod string) { totalLoginOK.Inc() switch authMethod { case loginMethodPublicKey: totalKeyLoginOK.Inc() case loginMethodKeyboardInteractive: totalInteractiveLoginOK.Inc() case loginMethodKeyAndPassword: totalKeyAndPasswordLoginOK.Inc() case loginMethodKeyAndKeyboardInt: totalKeyAndKeyIntLoginOK.Inc() case loginMethodTLSCertificate: totalTLSCertLoginOK.Inc() case loginMethodTLSCertificateAndPwd: totalTLSCertAndPwdLoginOK.Inc() case loginMethodIDP: totalIDPLoginOK.Inc() default: totalPasswordLoginOK.Inc() } } func 
incLoginFailed(authMethod string) { totalLoginFailed.Inc() switch authMethod { case loginMethodPublicKey: totalKeyLoginFailed.Inc() case loginMethodKeyboardInteractive: totalInteractiveLoginFailed.Inc() case loginMethodKeyAndPassword: totalKeyAndPasswordLoginFailed.Inc() case loginMethodKeyAndKeyboardInt: totalKeyAndKeyIntLoginFailed.Inc() case loginMethodTLSCertificate: totalTLSCertLoginFailed.Inc() case loginMethodTLSCertificateAndPwd: totalTLSCertAndPwdLoginFailed.Inc() case loginMethodIDP: totalIDPLoginFailed.Inc() default: totalPasswordLoginFailed.Inc() } } // AddLoginResult increments the metrics for login results func AddLoginResult(authMethod string, err error) { if err == nil { incLoginOK(authMethod) } else { incLoginFailed(authMethod) } } // AddNoAuthTried increments the metric for clients disconnected // for inactivity before trying to login func AddNoAuthTried() { totalNoAuthTried.Inc() } // HTTPRequestServed increments the metrics for HTTP requests func HTTPRequestServed(status int) { totalHTTPRequests.Inc() if status >= 200 && status < 300 { totalHTTPOK.Inc() } else if status >= 400 && status < 500 { totalHTTPClientErrors.Inc() } else if status >= 500 { totalHTTPServerErrors.Inc() } } // UpdateActiveConnectionsSize sets the metric for active connections func UpdateActiveConnectionsSize(size int) { activeConnections.Set(float64(size)) } ================================================ FILE: internal/metric/metric_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build nometrics package metric import ( "github.com/go-chi/chi/v5" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-metrics") } // AddMetricsEndpoint publishes metrics to the specified endpoint func AddMetricsEndpoint(_ string, _ chi.Router) {} // TransferCompleted updates metrics after an upload or a download func TransferCompleted(_, _ int64, _ int, _ error, _ bool) {} // S3TransferCompleted updates metrics after an S3 upload or a download func S3TransferCompleted(_ int64, _ int, _ error) {} // S3ListObjectsCompleted updates metrics after an S3 list objects request terminates func S3ListObjectsCompleted(_ error) {} // S3CopyObjectCompleted updates metrics after an S3 copy object request terminates func S3CopyObjectCompleted(_ error) {} // S3DeleteObjectCompleted updates metrics after an S3 delete object request terminates func S3DeleteObjectCompleted(_ error) {} // S3HeadBucketCompleted updates metrics after an S3 head bucket request terminates func S3HeadBucketCompleted(_ error) {} // GCSTransferCompleted updates metrics after a GCS upload or a download func GCSTransferCompleted(_ int64, _ int, _ error) {} // GCSListObjectsCompleted updates metrics after a GCS list objects request terminates func GCSListObjectsCompleted(_ error) {} // GCSCopyObjectCompleted updates metrics after a GCS copy object request terminates func GCSCopyObjectCompleted(_ error) {} // GCSDeleteObjectCompleted updates metrics after a GCS delete object request terminates func GCSDeleteObjectCompleted(_ error) {} // GCSHeadBucketCompleted updates metrics after a GCS head bucket request terminates func GCSHeadBucketCompleted(_ error) {} // HTTPFsTransferCompleted updates metrics after an HTTPFs upload or a download func HTTPFsTransferCompleted(_ int64, _ int, _ error) {} // SSHCommandCompleted update metrics after an SSH command 
terminates func SSHCommandCompleted(_ error) {} // UpdateDataProviderAvailability updates the metric for the data provider availability func UpdateDataProviderAvailability(_ error) {} // AddLoginAttempt increments the metrics for login attempts func AddLoginAttempt(_ string) {} // AddLoginResult increments the metrics for login results func AddLoginResult(_ string, _ error) {} // AddNoAuthTried increments the metric for clients disconnected // for inactivity before trying to login func AddNoAuthTried() {} // HTTPRequestServed increments the metrics for HTTP requests func HTTPRequestServed(_ int) {} // UpdateActiveConnectionsSize sets the metric for active connections func UpdateActiveConnectionsSize(_ int) {} ================================================ FILE: internal/mfa/mfa.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package mfa provides supports for Multi-Factor authentication modules package mfa import ( "bytes" "fmt" "image/png" "time" "github.com/pquerna/otp" ) var ( totpConfigs []*TOTPConfig serviceStatus ServiceStatus ) // ServiceStatus defines the service status type ServiceStatus struct { IsActive bool `json:"is_active"` TOTPConfigs []TOTPConfig `json:"totp_configs"` } // GetStatus returns the service status func GetStatus() ServiceStatus { return serviceStatus } // Config defines configuration parameters for Multi-Factor authentication modules type Config struct { // Time-based one time passwords configurations TOTP []TOTPConfig `json:"totp" mapstructure:"totp"` } // Initialize configures the MFA support func (c *Config) Initialize() error { totpConfigs = nil serviceStatus.IsActive = false serviceStatus.TOTPConfigs = nil totp := make(map[string]bool) for _, totpConfig := range c.TOTP { totpConfig := totpConfig //pin if err := totpConfig.validate(); err != nil { totpConfigs = nil return fmt.Errorf("invalid TOTP config %+v: %v", totpConfig, err) } if _, ok := totp[totpConfig.Name]; ok { totpConfigs = nil return fmt.Errorf("totp: duplicate configuration name %q", totpConfig.Name) } totp[totpConfig.Name] = true totpConfigs = append(totpConfigs, &totpConfig) serviceStatus.IsActive = true serviceStatus.TOTPConfigs = append(serviceStatus.TOTPConfigs, totpConfig) } startCleanupTicker(2 * time.Minute) return nil } // GetAvailableTOTPConfigs returns the available TOTP configs func GetAvailableTOTPConfigs() []*TOTPConfig { return totpConfigs } // GetAvailableTOTPConfigNames returns the available TOTP config names func GetAvailableTOTPConfigNames() []string { var result []string for _, c := range totpConfigs { result = append(result, c.Name) } return result } // ValidateTOTPPasscode validates a TOTP passcode using the given secret and configName func ValidateTOTPPasscode(configName, passcode, secret string) (bool, error) { for _, config := range totpConfigs { if config.Name == 
configName { return config.validatePasscode(passcode, secret) } } return false, fmt.Errorf("totp: no configuration %q", configName) } // GenerateTOTPSecret generates a new TOTP secret and QR code for the given username // using the configuration with configName func GenerateTOTPSecret(configName, username string) (string, *otp.Key, []byte, error) { for _, config := range totpConfigs { if config.Name == configName { key, qrCode, err := config.generate(username, 200, 200) return configName, key, qrCode, err } } return "", nil, nil, fmt.Errorf("totp: no configuration %q", configName) } // GenerateQRCodeFromURL generates a QR code from a TOTP URL func GenerateQRCodeFromURL(url string, width, height int) ([]byte, error) { key, err := otp.NewKeyFromURL(url) if err != nil { return nil, err } var buf bytes.Buffer img, err := key.Image(width, height) if err != nil { return nil, err } err = png.Encode(&buf, img) return buf.Bytes(), err } // the ticker cannot be started/stopped from multiple goroutines func startCleanupTicker(duration time.Duration) { stopCleanupTicker() cleanupTicker = time.NewTicker(duration) cleanupDone = make(chan bool) go func() { for { select { case <-cleanupDone: return case <-cleanupTicker.C: cleanupUsedPasscodes() } } }() } func stopCleanupTicker() { if cleanupTicker != nil { cleanupTicker.Stop() cleanupDone <- true cleanupTicker = nil } } ================================================ FILE: internal/mfa/mfa_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package mfa

import (
	"testing"
	"time"

	"github.com/pquerna/otp"
	"github.com/pquerna/otp/totp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMFAConfig checks Config.Initialize validation one field at a time, then
// generates a secret and validates passcodes against the loaded configurations.
// Each step mutates shared package state, so the statement order matters.
func TestMFAConfig(t *testing.T) {
	config := Config{
		TOTP: []TOTPConfig{
			{},
		},
	}
	configName1 := "config1"
	configName2 := "config2"
	configName3 := "config3"
	// empty config: the name is missing
	err := config.Initialize()
	assert.Error(t, err)
	config.TOTP[0].Name = configName1
	// the issuer is still missing
	err = config.Initialize()
	assert.Error(t, err)
	config.TOTP[0].Issuer = "issuer"
	// the algo is still missing
	err = config.Initialize()
	assert.Error(t, err)
	config.TOTP[0].Algo = TOTPAlgoSHA1
	err = config.Initialize()
	assert.NoError(t, err)
	// a second config reusing configName1 must be rejected as a duplicate
	config.TOTP = append(config.TOTP, TOTPConfig{
		Name:   configName1,
		Issuer: "SFTPGo",
		Algo:   TOTPAlgoSHA512,
	})
	err = config.Initialize()
	assert.Error(t, err)
	config.TOTP[1].Name = configName2
	err = config.Initialize()
	assert.NoError(t, err)
	assert.Len(t, GetAvailableTOTPConfigs(), 2)
	assert.Len(t, GetAvailableTOTPConfigNames(), 2)
	config.TOTP = append(config.TOTP, TOTPConfig{
		Name:   configName3,
		Issuer: "SFTPGo",
		Algo:   TOTPAlgoSHA256,
	})
	err = config.Initialize()
	assert.NoError(t, err)
	assert.Len(t, GetAvailableTOTPConfigs(), 3)
	if assert.Len(t, GetAvailableTOTPConfigNames(), 3) {
		assert.Contains(t, GetAvailableTOTPConfigNames(), configName1)
		assert.Contains(t, GetAvailableTOTPConfigNames(), configName2)
		assert.Contains(t, GetAvailableTOTPConfigNames(), configName3)
	}
	// the status lists the configurations in declaration order
	status := GetStatus()
	assert.True(t, status.IsActive)
	if assert.Len(t, status.TOTPConfigs, 3) {
		assert.Equal(t, configName1, status.TOTPConfigs[0].Name)
		assert.Equal(t, configName2, status.TOTPConfigs[1].Name)
		assert.Equal(t, configName3, status.TOTPConfigs[2].Name)
	}
	// now generate some secrets and validate some passcodes
	// an unknown (empty) configuration name is an error
	_, _, _, err = GenerateTOTPSecret("", "") //nolint:dogsled
	assert.Error(t, err)
	match, err := ValidateTOTPPasscode("", "", "")
	assert.Error(t, err)
	assert.False(t, match)
	cfgName, key, _, err := GenerateTOTPSecret(configName1, "user1")
	assert.NoError(t, err)
	assert.NotEmpty(t, key.Secret())
	assert.Equal(t, configName1, cfgName)
	passcode, err := generatePasscode(key.Secret(), otp.AlgorithmSHA1)
	assert.NoError(t, err)
	match, err = ValidateTOTPPasscode(configName1, passcode, key.Secret())
	assert.NoError(t, err)
	assert.True(t, match)
	// the same passcode cannot be used twice
	match, err = ValidateTOTPPasscode(configName1, passcode, key.Secret())
	assert.ErrorIs(t, err, errPasscodeUsed)
	assert.False(t, match)
	passcode, err = generatePasscode(key.Secret(), otp.AlgorithmSHA256)
	assert.NoError(t, err)
	// config1 uses the sha1 algo, so a sha256 passcode does not match
	match, err = ValidateTOTPPasscode(configName1, passcode, key.Secret())
	assert.NoError(t, err)
	assert.False(t, match)
	// config3 uses the expected algo
	match, err = ValidateTOTPPasscode(configName3, passcode, key.Secret())
	assert.NoError(t, err)
	assert.True(t, match)

	// Initialize started the cleanup ticker, stop it
	stopCleanupTicker()
}

// TestGenerateQRCodeFromURL checks that GenerateQRCodeFromURL rejects an
// invalid URL and an image size that is too small, and that it reproduces the
// QR code built by TOTPConfig.generate for the same key and size.
func TestGenerateQRCodeFromURL(t *testing.T) {
	_, err := GenerateQRCodeFromURL("http://foo\x7f.cloud", 200, 200)
	assert.Error(t, err)

	config := TOTPConfig{
		Name:   "config name",
		Issuer: "SFTPGo",
		Algo:   TOTPAlgoSHA256,
	}
	key, qrCode, err := config.generate("a", 150, 150)
	require.NoError(t, err)
	qrCode1, err := GenerateQRCodeFromURL(key.URL(), 150, 150)
	require.NoError(t, err)
	assert.Equal(t, qrCode, qrCode1)
	_, err = GenerateQRCodeFromURL(key.URL(), 10, 10)
	assert.Error(t, err)
}

// TestCleanupPasscodes stores an already expired used passcode and waits for
// the cleanup ticker to remove it.
func TestCleanupPasscodes(t *testing.T) {
	usedPasscodes.Store("key", time.Now().Add(-24*time.Hour).UTC())
	startCleanupTicker(30 * time.Millisecond)
	assert.Eventually(t, func() bool {
		_, ok := usedPasscodes.Load("key")
		return !ok
	}, 1000*time.Millisecond, 100*time.Millisecond)
	stopCleanupTicker()
}

// TestTOTPGenerateErrors checks the error paths of TOTPConfig.generate. The
// unexported algo is set directly because validate is not called here.
func TestTOTPGenerateErrors(t *testing.T) {
	config := TOTPConfig{
		Name:   "name",
		Issuer: "",
		algo:   otp.AlgorithmSHA1,
	}
	// issuer cannot be empty
	_, _, err := config.generate("username", 200, 200) //nolint:dogsled
	assert.Error(t, err)
	config.Issuer = "issuer"
	// we cannot encode an image smaller than 45x45
	_, _, err = config.generate("username", 30, 30) //nolint:dogsled
	assert.Error(t, err)
}

// generatePasscode returns the current six digit TOTP passcode for secret
// using the given algorithm and the same period/skew used by validatePasscode.
func generatePasscode(secret string, algo otp.Algorithm) (string, error) {
	return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: algo,
	})
}

================================================
FILE: internal/mfa/totp.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package mfa import ( "bytes" "errors" "fmt" "image/png" "sync" "time" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" ) // TOTPHMacAlgo is the enumerable for the possible HMAC algorithms for Time-based one time passwords type TOTPHMacAlgo = string // supported TOTP HMAC algorithms const ( TOTPAlgoSHA1 TOTPHMacAlgo = "sha1" TOTPAlgoSHA256 TOTPHMacAlgo = "sha256" TOTPAlgoSHA512 TOTPHMacAlgo = "sha512" ) var ( cleanupTicker *time.Ticker cleanupDone chan bool usedPasscodes sync.Map errPasscodeUsed = errors.New("this passcode was already used") ) // TOTPConfig defines the configuration for a Time-based one time password type TOTPConfig struct { Name string `json:"name" mapstructure:"name"` Issuer string `json:"issuer" mapstructure:"issuer"` Algo TOTPHMacAlgo `json:"algo" mapstructure:"algo"` algo otp.Algorithm } func (c *TOTPConfig) validate() error { if c.Name == "" { return errors.New("totp: name is mandatory") } if c.Issuer == "" { return errors.New("totp: issuer is mandatory") } switch c.Algo { case TOTPAlgoSHA1: c.algo = otp.AlgorithmSHA1 case TOTPAlgoSHA256: c.algo = otp.AlgorithmSHA256 case TOTPAlgoSHA512: c.algo = otp.AlgorithmSHA512 default: return fmt.Errorf("unsupported totp algo %q", c.Algo) } return nil } // validatePasscode validates a TOTP passcode func (c *TOTPConfig) validatePasscode(passcode, secret string) (bool, error) { key := fmt.Sprintf("%v_%v", secret, passcode) if _, ok := usedPasscodes.Load(key); ok { return false, errPasscodeUsed } match, err := totp.ValidateCustom(passcode, secret, time.Now().UTC(), totp.ValidateOpts{ Period: 30, Skew: 1, Digits: otp.DigitsSix, Algorithm: c.algo, }) if match && err == nil { usedPasscodes.Store(key, time.Now().Add(1*time.Minute).UTC()) } return match, err } // generate generates a new TOTP secret and QR code for the given username func (c *TOTPConfig) generate(username string, qrCodeWidth, qrCodeHeight int) (*otp.Key, []byte, error) { key, err := totp.Generate(totp.GenerateOpts{ Issuer: c.Issuer, 
AccountName: username, Digits: otp.DigitsSix, Algorithm: c.algo, }) if err != nil { return nil, nil, err } var buf bytes.Buffer img, err := key.Image(qrCodeWidth, qrCodeHeight) if err != nil { return nil, nil, err } err = png.Encode(&buf, img) return key, buf.Bytes(), err } func cleanupUsedPasscodes() { usedPasscodes.Range(func(key, value any) bool { exp, ok := value.(time.Time) if !ok || exp.Before(time.Now().UTC()) { usedPasscodes.Delete(key) } return true }) } ================================================ FILE: internal/plugin/auth.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package plugin

import (
	"errors"
	"fmt"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-plugin"
	"github.com/sftpgo/sdk/plugin/auth"

	"github.com/drakkan/sftpgo/v2/internal/logger"
)

// Supported auth scopes. They are bit flags and can be combined.
const (
	AuthScopePassword            = 1
	AuthScopePublicKey           = 2
	AuthScopeKeyboardInteractive = 4
	AuthScopeTLSCertificate      = 8
)

// KeyboardAuthRequest defines the request for a keyboard interactive authentication step
type KeyboardAuthRequest struct {
	RequestID string   `json:"request_id"`
	Step      int      `json:"step"`
	Username  string   `json:"username,omitempty"`
	IP        string   `json:"ip,omitempty"`
	Password  string   `json:"password,omitempty"`
	Answers   []string `json:"answers,omitempty"`
	Questions []string `json:"questions,omitempty"`
}

// KeyboardAuthResponse defines the response for a keyboard interactive authentication step
type KeyboardAuthResponse struct {
	Instruction string   `json:"instruction"`
	Questions   []string `json:"questions"`
	Echos       []bool   `json:"echos"`
	AuthResult  int      `json:"auth_result"`
	CheckPwd    int      `json:"check_password"`
}

// Validate returns an error if the KeyboardAuthResponse is invalid: a valid
// response contains at least one question and exactly one echo flag for each
// question.
func (r *KeyboardAuthResponse) Validate() error {
	if len(r.Questions) == 0 {
		return errors.New("interactive auth error: response does not contain questions")
	}
	if len(r.Questions) != len(r.Echos) {
		return fmt.Errorf("interactive auth error: response questions don't match echos: %v %v",
			len(r.Questions), len(r.Echos))
	}
	return nil
}

// AuthConfig defines configuration parameters for auth plugins
type AuthConfig struct {
	// Scope defines the scope for the authentication plugin.
	// - 1 means passwords only
	// - 2 means public keys only
	// - 4 means keyboard interactive only
	// - 8 means TLS certificates only
	// you can combine the scopes, for example 3 means password and public key, 5 password and keyboard
	// interactive and so on
	Scope int `json:"scope" mapstructure:"scope"`
}

// validate returns an error unless Scope is a non-empty combination of the
// supported AuthScope* flags, that is a value between 1 and the sum of all
// the flags.
func (c *AuthConfig) validate() error {
	authScopeMax := AuthScopePassword + AuthScopePublicKey + AuthScopeKeyboardInteractive + AuthScopeTLSCertificate
	// Scope is a signed int read from the configuration. Negative values must be
	// rejected too; otherwise they pass validation and are OR-ed into the
	// manager's combined auth scopes.
	if c.Scope <= 0 || c.Scope > authScopeMax {
		return fmt.Errorf("invalid auth scope: %v", c.Scope)
	}
	return nil
}

// authPlugin wraps an auth plugin process and the Authenticator dispensed from it.
type authPlugin struct {
	config  Config
	service auth.Authenticator
	client  *plugin.Client
}

// newAuthPlugin creates and initializes the auth plugin described by config.
// The error from initialize is logged and returned.
func newAuthPlugin(config Config) (*authPlugin, error) {
	p := &authPlugin{
		config: config,
	}
	if err := p.initialize(); err != nil {
		logger.Warn(logSender, "", "unable to create auth plugin: %v, config %+v", err, config)
		return nil, err
	}
	return p, nil
}

// initialize validates the auth options, starts the plugin process over gRPC
// and dispenses the Authenticator implementation.
//
// The client is created with Managed: false and is stored in p.client only on
// success. On the error paths it is therefore killed explicitly; otherwise
// cleanup() could never reach it and the plugin process would be left running.
func (p *authPlugin) initialize() error {
	killProcess(p.config.Cmd)
	logger.Debug(logSender, "", "create new auth plugin %q", p.config.Cmd)
	if err := p.config.AuthOptions.validate(); err != nil {
		return fmt.Errorf("invalid options for auth plugin %q: %v", p.config.Cmd, err)
	}
	secureConfig, err := p.config.getSecureConfig()
	if err != nil {
		return err
	}
	client := plugin.NewClient(&plugin.ClientConfig{
		HandshakeConfig: auth.Handshake,
		Plugins:         auth.PluginMap,
		Cmd:             p.config.getCommand(),
		SkipHostEnv:     true,
		AllowedProtocols: []plugin.Protocol{
			plugin.ProtocolGRPC,
		},
		AutoMTLS:     p.config.AutoMTLS,
		SecureConfig: secureConfig,
		Managed:      false,
		Logger: &logger.HCLogAdapter{
			Logger: hclog.New(&hclog.LoggerOptions{
				Name:        fmt.Sprintf("%v.%v", logSender, auth.PluginName),
				Level:       pluginsLogLevel,
				DisableTime: true,
			}),
		},
	})
	rpcClient, err := client.Client()
	if err != nil {
		logger.Debug(logSender, "", "unable to get rpc client for auth plugin %q: %v", p.config.Cmd, err)
		client.Kill()
		return err
	}
	raw, err := rpcClient.Dispense(auth.PluginName)
	if err != nil {
		logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v",
			auth.PluginName, p.config.Cmd, err)
		client.Kill()
		return err
	}

	p.service = raw.(auth.Authenticator)
	p.client = client

	return nil
}

// exited returns the result of the plugin client's Exited method.
func (p *authPlugin) exited() bool {
	return p.client.Exited()
}

// cleanup kills the plugin client.
func (p *authPlugin) cleanup() {
	p.client.Kill()
}

// checkUserAndPass forwards a password authentication request to the plugin.
func (p *authPlugin) checkUserAndPass(username, password, ip, protocol string, userAsJSON []byte) ([]byte, error) {
	return p.service.CheckUserAndPass(username, password, ip, protocol, userAsJSON)
}

// checkUserAndTLSCertificate forwards a TLS certificate authentication request to the plugin.
func (p *authPlugin) checkUserAndTLSCertificate(username, tlsCert, ip, protocol string, userAsJSON []byte) ([]byte, error) {
	return p.service.CheckUserAndTLSCert(username, tlsCert, ip, protocol, userAsJSON)
}

// checkUserAndPublicKey forwards a public key authentication request to the plugin.
func (p *authPlugin) checkUserAndPublicKey(username, pubKey, ip, protocol string, userAsJSON []byte) ([]byte, error) {
	return p.service.CheckUserAndPublicKey(username, pubKey, ip, protocol, userAsJSON)
}

// checkUserAndKeyboardInteractive forwards a keyboard interactive authentication request to the plugin.
func (p *authPlugin) checkUserAndKeyboardInteractive(username, ip, protocol string, userAsJSON []byte) ([]byte, error) {
	return p.service.CheckUserAndKeyboardInteractive(username, ip, protocol, userAsJSON)
}

// sendKeyboardIteractiveRequest sends a keyboard interactive step to the plugin
// and maps the returned values into a KeyboardAuthResponse.
func (p *authPlugin) sendKeyboardIteractiveRequest(req *KeyboardAuthRequest) (*KeyboardAuthResponse, error) {
	instructions, questions, echos, authResult, checkPassword, err := p.service.SendKeyboardAuthRequest(
		req.RequestID, req.Username, req.Password, req.IP, req.Answers, req.Questions, int32(req.Step))
	if err != nil {
		return nil, err
	}
	return &KeyboardAuthResponse{
		Instruction: instructions,
		Questions:   questions,
		Echos:       echos,
		AuthResult:  authResult,
		CheckPwd:    checkPassword,
	}, nil
}

================================================
FILE: internal/plugin/ipfilter.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package plugin import ( "fmt" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "github.com/sftpgo/sdk/plugin/ipfilter" "github.com/drakkan/sftpgo/v2/internal/logger" ) type ipFilterPlugin struct { config Config filter ipfilter.Filter client *plugin.Client } func newIPFilterPlugin(config Config) (*ipFilterPlugin, error) { p := &ipFilterPlugin{ config: config, } if err := p.initialize(); err != nil { logger.Warn(logSender, "", "unable to create IP filter plugin: %v, config %+v", err, config) return nil, err } return p, nil } func (p *ipFilterPlugin) exited() bool { return p.client.Exited() } func (p *ipFilterPlugin) cleanup() { p.client.Kill() } func (p *ipFilterPlugin) initialize() error { logger.Debug(logSender, "", "create new IP filter plugin %q", p.config.Cmd) killProcess(p.config.Cmd) secureConfig, err := p.config.getSecureConfig() if err != nil { return err } client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: ipfilter.Handshake, Plugins: ipfilter.PluginMap, Cmd: p.config.getCommand(), SkipHostEnv: true, AllowedProtocols: []plugin.Protocol{ plugin.ProtocolGRPC, }, AutoMTLS: p.config.AutoMTLS, SecureConfig: secureConfig, Managed: false, Logger: &logger.HCLogAdapter{ Logger: hclog.New(&hclog.LoggerOptions{ Name: fmt.Sprintf("%v.%v", logSender, ipfilter.PluginName), Level: pluginsLogLevel, DisableTime: true, }), }, }) rpcClient, err := client.Client() if err != nil { logger.Debug(logSender, "", "unable to get rpc client for plugin %q: %v", p.config.Cmd, err) return err } raw, err := rpcClient.Dispense(ipfilter.PluginName) if err != nil { logger.Debug(logSender, "", 
"unable to get plugin %v from rpc client for command %q: %v", ipfilter.PluginName, p.config.Cmd, err) return err } p.client = client p.filter = raw.(ipfilter.Filter) return nil } ================================================ FILE: internal/plugin/kms.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package plugin import ( "fmt" "path/filepath" "slices" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" sdkkms "github.com/sftpgo/sdk/kms" kmsplugin "github.com/sftpgo/sdk/plugin/kms" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" ) var ( validKMSSchemes = []string{sdkkms.SchemeAWS, sdkkms.SchemeGCP, sdkkms.SchemeVaultTransit, sdkkms.SchemeAzureKeyVault, sdkkms.SchemeOracleKeyVault} validKMSEncryptedStatuses = []string{sdkkms.SecretStatusVaultTransit, sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP, sdkkms.SecretStatusAzureKeyVault, sdkkms.SecretStatusOracleKeyVault} ) // KMSConfig defines configuration parameters for kms plugins type KMSConfig struct { Scheme string `json:"scheme" mapstructure:"scheme"` EncryptedStatus string `json:"encrypted_status" mapstructure:"encrypted_status"` } func (c *KMSConfig) validate() error { if !slices.Contains(validKMSSchemes, c.Scheme) { return fmt.Errorf("invalid kms scheme: %v", c.Scheme) } if !slices.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) { return 
fmt.Errorf("invalid kms encrypted status: %v", c.EncryptedStatus) } return nil } type kmsPlugin struct { config Config service kmsplugin.Service client *plugin.Client } func newKMSPlugin(config Config) (*kmsPlugin, error) { p := &kmsPlugin{ config: config, } if err := p.initialize(); err != nil { logger.Warn(logSender, "", "unable to create kms plugin: %v, config %+v", err, config) return nil, err } return p, nil } func (p *kmsPlugin) initialize() error { killProcess(p.config.Cmd) logger.Debug(logSender, "", "create new kms plugin %q", p.config.Cmd) if err := p.config.KMSOptions.validate(); err != nil { return fmt.Errorf("invalid options for kms plugin %q: %v", p.config.Cmd, err) } secureConfig, err := p.config.getSecureConfig() if err != nil { return err } client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: kmsplugin.Handshake, Plugins: kmsplugin.PluginMap, Cmd: p.config.getCommand(), SkipHostEnv: true, AllowedProtocols: []plugin.Protocol{ plugin.ProtocolGRPC, }, AutoMTLS: p.config.AutoMTLS, SecureConfig: secureConfig, Managed: false, Logger: &logger.HCLogAdapter{ Logger: hclog.New(&hclog.LoggerOptions{ Name: fmt.Sprintf("%v.%v", logSender, kmsplugin.PluginName), Level: pluginsLogLevel, DisableTime: true, }), }, }) rpcClient, err := client.Client() if err != nil { logger.Debug(logSender, "", "unable to get rpc client for kms plugin %q: %v", p.config.Cmd, err) return err } raw, err := rpcClient.Dispense(kmsplugin.PluginName) if err != nil { logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v", kmsplugin.PluginName, p.config.Cmd, err) return err } p.client = client p.service = raw.(kmsplugin.Service) return nil } func (p *kmsPlugin) exited() bool { return p.client.Exited() } func (p *kmsPlugin) cleanup() { p.client.Kill() } func (p *kmsPlugin) Encrypt(secret kms.BaseSecret, url string, masterKey string) (string, string, int32, error) { return p.service.Encrypt(secret.Payload, secret.AdditionalData, url, masterKey) } 
func (p *kmsPlugin) Decrypt(secret kms.BaseSecret, url string, masterKey string) (string, error) { return p.service.Decrypt(secret.Payload, secret.Key, secret.AdditionalData, secret.Mode, url, masterKey) } type kmsPluginSecretProvider struct { kms.BaseSecret URL string MasterKey string config *Config } func (s *kmsPluginSecretProvider) Name() string { return fmt.Sprintf("KMSPlugin_%v_%v_%v", filepath.Base(s.config.Cmd), s.config.KMSOptions.Scheme, s.config.kmsID) } func (s *kmsPluginSecretProvider) IsEncrypted() bool { return s.Status == s.config.KMSOptions.EncryptedStatus } func (s *kmsPluginSecretProvider) Encrypt() error { if s.Status != sdkkms.SecretStatusPlain { return kms.ErrWrongSecretStatus } if s.Payload == "" { return kms.ErrInvalidSecret } payload, key, mode, err := Handler.kmsEncrypt(s.BaseSecret, s.URL, s.MasterKey, s.config.kmsID) if err != nil { return err } s.Status = s.config.KMSOptions.EncryptedStatus s.Payload = payload s.Key = key s.Mode = int(mode) return nil } func (s *kmsPluginSecretProvider) Decrypt() error { if !s.IsEncrypted() { return kms.ErrWrongSecretStatus } payload, err := Handler.kmsDecrypt(s.BaseSecret, s.URL, s.MasterKey, s.config.kmsID) if err != nil { return err } s.Status = sdkkms.SecretStatusPlain s.Payload = payload s.Key = "" s.AdditionalData = "" s.Mode = 0 return nil } func (s *kmsPluginSecretProvider) Clone() kms.SecretProvider { baseSecret := kms.BaseSecret{ Status: s.Status, Payload: s.Payload, Key: s.Key, AdditionalData: s.AdditionalData, Mode: s.Mode, } return s.config.newKMSPluginSecretProvider(baseSecret, s.URL, s.MasterKey) } ================================================ FILE: internal/plugin/notifier.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package plugin import ( "fmt" "slices" "sync" "time" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/logger" ) // NotifierConfig defines configuration parameters for notifiers plugins type NotifierConfig struct { FsEvents []string `json:"fs_events" mapstructure:"fs_events"` ProviderEvents []string `json:"provider_events" mapstructure:"provider_events"` ProviderObjects []string `json:"provider_objects" mapstructure:"provider_objects"` LogEvents []int `json:"log_events" mapstructure:"log_events"` RetryMaxTime int `json:"retry_max_time" mapstructure:"retry_max_time"` RetryQueueMaxSize int `json:"retry_queue_max_size" mapstructure:"retry_queue_max_size"` } func (c *NotifierConfig) hasActions() bool { if len(c.FsEvents) > 0 { return true } if len(c.ProviderEvents) > 0 && len(c.ProviderObjects) > 0 { return true } if len(c.LogEvents) > 0 { return true } return false } type notifierPlugin struct { config Config notifier notifier.Notifier client *plugin.Client mu sync.RWMutex fsEvents []*notifier.FsEvent providerEvents []*notifier.ProviderEvent logEvents []*notifier.LogEvent } func newNotifierPlugin(config Config) (*notifierPlugin, error) { p := ¬ifierPlugin{ config: config, } if err := p.initialize(); err != nil { logger.Warn(logSender, "", "unable to create notifier plugin: %v, config %+v", err, config) return nil, err } return p, nil } func (p *notifierPlugin) exited() bool { return p.client.Exited() } func (p *notifierPlugin) cleanup() { p.client.Kill() } func (p *notifierPlugin) initialize() error { 
killProcess(p.config.Cmd) logger.Debug(logSender, "", "create new notifier plugin %q", p.config.Cmd) if !p.config.NotifierOptions.hasActions() { return fmt.Errorf("no actions defined for the notifier plugin %q", p.config.Cmd) } secureConfig, err := p.config.getSecureConfig() if err != nil { return err } client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: notifier.Handshake, Plugins: notifier.PluginMap, Cmd: p.config.getCommand(), SkipHostEnv: true, AllowedProtocols: []plugin.Protocol{ plugin.ProtocolGRPC, }, AutoMTLS: p.config.AutoMTLS, SecureConfig: secureConfig, Managed: false, Logger: &logger.HCLogAdapter{ Logger: hclog.New(&hclog.LoggerOptions{ Name: fmt.Sprintf("%s.%s", logSender, notifier.PluginName), Level: pluginsLogLevel, DisableTime: true, }), }, }) rpcClient, err := client.Client() if err != nil { logger.Debug(logSender, "", "unable to get rpc client for plugin %q: %v", p.config.Cmd, err) return err } raw, err := rpcClient.Dispense(notifier.PluginName) if err != nil { logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v", notifier.PluginName, p.config.Cmd, err) return err } p.client = client p.notifier = raw.(notifier.Notifier) return nil } func (p *notifierPlugin) queueSize() int { p.mu.RLock() defer p.mu.RUnlock() return len(p.providerEvents) + len(p.fsEvents) + len(p.logEvents) } func (p *notifierPlugin) queueFsEvent(ev *notifier.FsEvent) { p.mu.Lock() defer p.mu.Unlock() p.fsEvents = append(p.fsEvents, ev) } func (p *notifierPlugin) queueProviderEvent(ev *notifier.ProviderEvent) { p.mu.Lock() defer p.mu.Unlock() p.providerEvents = append(p.providerEvents, ev) } func (p *notifierPlugin) queueLogEvent(ev *notifier.LogEvent) { p.mu.Lock() defer p.mu.Unlock() p.logEvents = append(p.logEvents, ev) } func (p *notifierPlugin) canQueueEvent(timestamp int64) bool { if p.config.NotifierOptions.RetryMaxTime == 0 { return false } if time.Now().After(time.Unix(0, 
timestamp).Add(time.Duration(p.config.NotifierOptions.RetryMaxTime) * time.Second)) { logger.Warn(logSender, "", "dropping too late event for plugin %v, event timestamp: %v", p.config.Cmd, time.Unix(0, timestamp)) return false } if p.config.NotifierOptions.RetryQueueMaxSize > 0 { return p.queueSize() < p.config.NotifierOptions.RetryQueueMaxSize } return true } func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) { if !slices.Contains(p.config.NotifierOptions.FsEvents, event.Action) { return } p.sendFsEvent(event) } func (p *notifierPlugin) notifyProviderAction(event *notifier.ProviderEvent, object Renderer) { if !slices.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) || !slices.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) { return } p.sendProviderEvent(event, object) } func (p *notifierPlugin) notifyLogEvent(event *notifier.LogEvent) { p.sendLogEvent(event) } func (p *notifierPlugin) sendFsEvent(ev *notifier.FsEvent) { go func(event *notifier.FsEvent) { Handler.addTask() defer Handler.removeTask() if err := p.notifier.NotifyFsEvent(event); err != nil { logger.Warn(logSender, "", "unable to send fs action notification to plugin %v: %v", p.config.Cmd, err) if p.canQueueEvent(event.Timestamp) { p.queueFsEvent(event) } } }(ev) } func (p *notifierPlugin) sendProviderEvent(ev *notifier.ProviderEvent, object Renderer) { go func(event *notifier.ProviderEvent) { Handler.addTask() defer Handler.removeTask() if object != nil { objectAsJSON, err := object.RenderAsJSON(event.Action != "delete") if err != nil { logger.Error(logSender, "", "unable to render user as json for action %q: %v", event.Action, err) } else { event.ObjectData = objectAsJSON } } if err := p.notifier.NotifyProviderEvent(event); err != nil { logger.Warn(logSender, "", "unable to send user action notification to plugin %v: %v", p.config.Cmd, err) if p.canQueueEvent(event.Timestamp) { p.queueProviderEvent(event) } } }(ev) } func (p *notifierPlugin) 
sendLogEvent(ev *notifier.LogEvent) { go func(event *notifier.LogEvent) { Handler.addTask() defer Handler.removeTask() if err := p.notifier.NotifyLogEvent(event); err != nil { logger.Warn(logSender, "", "unable to send log event to plugin %v: %v", p.config.Cmd, err) if p.canQueueEvent(event.Timestamp) { p.queueLogEvent(event) } } }(ev) } func (p *notifierPlugin) sendQueuedEvents() { queueSize := p.queueSize() if queueSize == 0 { return } p.mu.Lock() defer p.mu.Unlock() logger.Debug(logSender, "", "send queued events for notifier %q, events size: %v", p.config.Cmd, queueSize) for _, ev := range p.fsEvents { p.sendFsEvent(ev) } p.fsEvents = nil for _, ev := range p.providerEvents { p.sendProviderEvent(ev, nil) } p.providerEvents = nil for _, ev := range p.logEvents { p.sendLogEvent(ev) } p.logEvents = nil logger.Debug(logSender, "", "%d queued events sent for notifier %q,", queueSize, p.config.Cmd) } ================================================ FILE: internal/plugin/plugin.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package plugin provides support for the SFTPGo plugin system package plugin import ( "crypto/sha256" "crypto/x509" "encoding/hex" "errors" "fmt" "os" "os/exec" "path/filepath" "slices" "strings" "sync" "sync/atomic" "time" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "github.com/sftpgo/sdk/plugin/auth" "github.com/sftpgo/sdk/plugin/eventsearcher" "github.com/sftpgo/sdk/plugin/ipfilter" kmsplugin "github.com/sftpgo/sdk/plugin/kms" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( logSender = "plugins" ) var ( // Handler defines the plugins manager Handler Manager pluginsLogLevel = hclog.Debug // ErrNoSearcher defines the error to return for events searches if no plugin is configured ErrNoSearcher = errors.New("no events searcher plugin defined") ) // Renderer defines the interface for generic objects rendering type Renderer interface { RenderAsJSON(reload bool) ([]byte, error) } // Config defines a plugin configuration type Config struct { // Plugin type Type string `json:"type" mapstructure:"type"` // NotifierOptions defines options for notifiers plugins NotifierOptions NotifierConfig `json:"notifier_options" mapstructure:"notifier_options"` // KMSOptions defines options for a KMS plugin KMSOptions KMSConfig `json:"kms_options" mapstructure:"kms_options"` // AuthOptions defines options for authentication plugins AuthOptions AuthConfig `json:"auth_options" mapstructure:"auth_options"` // Path to the plugin executable Cmd string `json:"cmd" mapstructure:"cmd"` // Args to pass to the plugin executable Args []string `json:"args" mapstructure:"args"` // SHA256 checksum for the plugin executable. 
// If not empty it will be used to verify the integrity of the executable SHA256Sum string `json:"sha256sum" mapstructure:"sha256sum"` // If enabled the client and the server automatically negotiate mTLS for // transport authentication. This ensures that only the original client will // be allowed to connect to the server, and all other connections will be // rejected. The client will also refuse to connect to any server that isn't // the original instance started by the client. AutoMTLS bool `json:"auto_mtls" mapstructure:"auto_mtls"` // EnvPrefix defines the prefix for env vars to pass from the SFTPGo process // environment to the plugin. Set to "none" to not pass any environment // variable, set to "*" to pass all environment variables. If empty, the // prefix is returned as the plugin name in uppercase with "-" replaced with // "_" and a trailing "_". For example if the plugin name is // sftpgo-plugin-eventsearch the prefix will be SFTPGO_PLUGIN_EVENTSEARCH_ EnvPrefix string `json:"env_prefix" mapstructure:"env_prefix"` // Additional environment variable names to pass from the SFTPGo process // environment to the plugin. EnvVars []string `json:"env_vars" mapstructure:"env_vars"` // unique identifier for kms plugins kmsID int } func (c *Config) getSecureConfig() (*plugin.SecureConfig, error) { if c.SHA256Sum != "" { checksum, err := hex.DecodeString(c.SHA256Sum) if err != nil { return nil, fmt.Errorf("invalid sha256 hash %q: %w", c.SHA256Sum, err) } return &plugin.SecureConfig{ Checksum: checksum, Hash: sha256.New(), }, nil } return nil, nil } func (c *Config) getEnvVarPrefix() string { if c.EnvPrefix == "none" { return "" } if c.EnvPrefix != "" { return c.EnvPrefix } baseName := filepath.Base(c.Cmd) name := strings.TrimSuffix(baseName, filepath.Ext(baseName)) prefix := strings.ToUpper(name) + "_" return strings.ReplaceAll(prefix, "-", "_") } func (c *Config) getCommand() *exec.Cmd { cmd := exec.Command(c.Cmd, c.Args...) 
cmd.Env = []string{} if envVarPrefix := c.getEnvVarPrefix(); envVarPrefix != "" { if envVarPrefix == "*" { logger.Debug(logSender, "", "sharing all the environment variables with plugin %q", c.Cmd) cmd.Env = append(cmd.Env, os.Environ()...) return cmd } logger.Debug(logSender, "", "adding env vars with prefix %q for plugin %q", envVarPrefix, c.Cmd) for _, val := range os.Environ() { if strings.HasPrefix(val, envVarPrefix) { cmd.Env = append(cmd.Env, val) } } } logger.Debug(logSender, "", "additional env vars for plugin %q: %+v", c.Cmd, c.EnvVars) for _, key := range c.EnvVars { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, os.Getenv(key))) } return cmd } func (c *Config) newKMSPluginSecretProvider(base kms.BaseSecret, url, masterKey string) kms.SecretProvider { return &kmsPluginSecretProvider{ BaseSecret: base, URL: url, MasterKey: masterKey, config: c, } } // Manager handles enabled plugins type Manager struct { closed atomic.Bool done chan bool // List of configured plugins Configs []Config `json:"plugins" mapstructure:"plugins"` notifLock sync.RWMutex notifiers []*notifierPlugin kmsLock sync.RWMutex kms []*kmsPlugin authLock sync.RWMutex auths []*authPlugin searcherLock sync.RWMutex searcher *searcherPlugin ipFilterLock sync.RWMutex filter *ipFilterPlugin authScopes int hasSearcher bool hasNotifiers bool hasAuths bool hasIPFilter bool concurrencyGuard chan struct{} } // Initialize initializes the configured plugins func Initialize(configs []Config, logLevel string) error { logger.Debug(logSender, "", "initialize") Handler = Manager{ Configs: configs, done: make(chan bool), authScopes: -1, concurrencyGuard: make(chan struct{}, 250), } Handler.closed.Store(false) setLogLevel(logLevel) if len(configs) == 0 { return nil } if err := Handler.validateConfigs(); err != nil { return err } if err := initializePlugins(); err != nil { return err } startCheckTicker() return nil } func initializePlugins() error { kmsID := 0 for idx, config := range Handler.Configs { 
switch config.Type { case notifier.PluginName: plugin, err := newNotifierPlugin(config) if err != nil { return err } Handler.notifiers = append(Handler.notifiers, plugin) case kmsplugin.PluginName: plugin, err := newKMSPlugin(config) if err != nil { return err } Handler.kms = append(Handler.kms, plugin) Handler.Configs[idx].kmsID = kmsID kmsID++ kms.RegisterSecretProvider(config.KMSOptions.Scheme, config.KMSOptions.EncryptedStatus, Handler.Configs[idx].newKMSPluginSecretProvider) logger.Info(logSender, "", "registered secret provider for scheme %q, encrypted status %q", config.KMSOptions.Scheme, config.KMSOptions.EncryptedStatus) case auth.PluginName: plugin, err := newAuthPlugin(config) if err != nil { return err } Handler.auths = append(Handler.auths, plugin) if Handler.authScopes == -1 { Handler.authScopes = config.AuthOptions.Scope } else { Handler.authScopes |= config.AuthOptions.Scope } case eventsearcher.PluginName: plugin, err := newSearcherPlugin(config) if err != nil { return err } Handler.searcher = plugin case ipfilter.PluginName: plugin, err := newIPFilterPlugin(config) if err != nil { return err } Handler.filter = plugin default: return fmt.Errorf("unsupported plugin type: %v", config.Type) } } return nil } func (m *Manager) validateConfigs() error { kmsSchemes := make(map[string]bool) kmsEncryptions := make(map[string]bool) m.hasSearcher = false m.hasNotifiers = false m.hasAuths = false m.hasIPFilter = false for _, config := range m.Configs { switch config.Type { case kmsplugin.PluginName: if _, ok := kmsSchemes[config.KMSOptions.Scheme]; ok { return fmt.Errorf("invalid KMS configuration, duplicated scheme %q", config.KMSOptions.Scheme) } if _, ok := kmsEncryptions[config.KMSOptions.EncryptedStatus]; ok { return fmt.Errorf("invalid KMS configuration, duplicated encrypted status %q", config.KMSOptions.EncryptedStatus) } kmsSchemes[config.KMSOptions.Scheme] = true kmsEncryptions[config.KMSOptions.EncryptedStatus] = true case eventsearcher.PluginName: 
if m.hasSearcher { return errors.New("only one eventsearcher plugin can be defined") } m.hasSearcher = true case notifier.PluginName: m.hasNotifiers = true case auth.PluginName: m.hasAuths = true case ipfilter.PluginName: m.hasIPFilter = true } } return nil } // HasAuthenticators returns true if there is at least an auth plugin func (m *Manager) HasAuthenticators() bool { return m.hasAuths } // HasNotifiers returns true if there is at least a notifier plugin func (m *Manager) HasNotifiers() bool { return m.hasNotifiers } // NotifyFsEvent sends the fs event notifications using any defined notifier plugins func (m *Manager) NotifyFsEvent(event *notifier.FsEvent) { m.notifLock.RLock() defer m.notifLock.RUnlock() for _, n := range m.notifiers { n.notifyFsAction(event) } } // NotifyProviderEvent sends the provider event notifications using any defined notifier plugins func (m *Manager) NotifyProviderEvent(event *notifier.ProviderEvent, object Renderer) { m.notifLock.RLock() defer m.notifLock.RUnlock() for _, n := range m.notifiers { n.notifyProviderAction(event, object) } } // NotifyLogEvent sends the log event notifications using any defined notifier plugins func (m *Manager) NotifyLogEvent(event notifier.LogEventType, protocol, username, ip, role string, err error) { if !m.hasNotifiers { return } m.notifLock.RLock() defer m.notifLock.RUnlock() var e *notifier.LogEvent for _, n := range m.notifiers { if slices.Contains(n.config.NotifierOptions.LogEvents, int(event)) { if e == nil { message := "" if err != nil { message = strings.Trim(err.Error(), "\x00") } e = ¬ifier.LogEvent{ Timestamp: time.Now().UnixNano(), Event: event, Protocol: protocol, Username: username, IP: ip, Message: message, Role: role, } } n.notifyLogEvent(e) } } } // HasSearcher returns true if an event searcher plugin is defined func (m *Manager) HasSearcher() bool { return m.hasSearcher } // SearchFsEvents returns the filesystem events matching the specified filters func (m *Manager) 
SearchFsEvents(searchFilters *eventsearcher.FsEventSearch) ([]byte, error) { if !m.hasSearcher { return nil, ErrNoSearcher } m.searcherLock.RLock() plugin := m.searcher m.searcherLock.RUnlock() return plugin.searchear.SearchFsEvents(searchFilters) } // SearchProviderEvents returns the provider events matching the specified filters func (m *Manager) SearchProviderEvents(searchFilters *eventsearcher.ProviderEventSearch) ([]byte, error) { if !m.hasSearcher { return nil, ErrNoSearcher } m.searcherLock.RLock() plugin := m.searcher m.searcherLock.RUnlock() return plugin.searchear.SearchProviderEvents(searchFilters) } // SearchLogEvents returns the log events matching the specified filters func (m *Manager) SearchLogEvents(searchFilters *eventsearcher.LogEventSearch) ([]byte, error) { if !m.hasSearcher { return nil, ErrNoSearcher } m.searcherLock.RLock() plugin := m.searcher m.searcherLock.RUnlock() return plugin.searchear.SearchLogEvents(searchFilters) } // IsIPBanned returns true if the IP filter plugin does not allow the specified ip. 
// If no IP filter plugin is defined this method returns false func (m *Manager) IsIPBanned(ip, protocol string) bool { if !m.hasIPFilter { return false } m.ipFilterLock.RLock() plugin := m.filter m.ipFilterLock.RUnlock() if plugin.exited() { logger.Warn(logSender, "", "ip filter plugin is not active, cannot check ip %q", ip) return false } return plugin.filter.CheckIP(ip, protocol) != nil } // ReloadFilter sends a reload request to the IP filter plugin func (m *Manager) ReloadFilter() { if !m.hasIPFilter { return } m.ipFilterLock.RLock() plugin := m.filter m.ipFilterLock.RUnlock() if err := plugin.filter.Reload(); err != nil { logger.Error(logSender, "", "unable to reload IP filter plugin: %v", err) } } func (m *Manager) kmsEncrypt(secret kms.BaseSecret, url string, masterKey string, kmsID int) (string, string, int32, error) { m.kmsLock.RLock() plugin := m.kms[kmsID] m.kmsLock.RUnlock() return plugin.Encrypt(secret, url, masterKey) } func (m *Manager) kmsDecrypt(secret kms.BaseSecret, url string, masterKey string, kmsID int) (string, error) { m.kmsLock.RLock() plugin := m.kms[kmsID] m.kmsLock.RUnlock() return plugin.Decrypt(secret, url, masterKey) } // HasAuthScope returns true if there is an auth plugin that support the specified scope func (m *Manager) HasAuthScope(scope int) bool { if m.authScopes == -1 { return false } return m.authScopes&scope != 0 } // Authenticate tries to authenticate the specified user using an external plugin func (m *Manager) Authenticate(username, password, ip, protocol string, pkey string, tlsCert *x509.Certificate, authScope int, userAsJSON []byte, ) ([]byte, error) { switch authScope { case AuthScopePassword: return m.checkUserAndPass(username, password, ip, protocol, userAsJSON) case AuthScopePublicKey: return m.checkUserAndPublicKey(username, pkey, ip, protocol, userAsJSON) case AuthScopeKeyboardInteractive: return m.checkUserAndKeyboardInteractive(username, ip, protocol, userAsJSON) case AuthScopeTLSCertificate: cert, err := 
util.EncodeTLSCertToPem(tlsCert) if err != nil { logger.Error(logSender, "", "unable to encode tls certificate to pem: %v", err) return nil, fmt.Errorf("unable to encode tls cert to pem: %w", err) } return m.checkUserAndTLSCert(username, cert, ip, protocol, userAsJSON) default: return nil, fmt.Errorf("unsupported auth scope: %v", authScope) } } // ExecuteKeyboardInteractiveStep executes a keyboard interactive step func (m *Manager) ExecuteKeyboardInteractiveStep(req *KeyboardAuthRequest) (*KeyboardAuthResponse, error) { var plugin *authPlugin m.authLock.Lock() for _, p := range m.auths { if p.config.AuthOptions.Scope&AuthScopePassword != 0 { plugin = p break } } m.authLock.Unlock() if plugin == nil { return nil, errors.New("no auth plugin configured for keyaboard interactive authentication step") } return plugin.sendKeyboardIteractiveRequest(req) } func (m *Manager) checkUserAndPass(username, password, ip, protocol string, userAsJSON []byte) ([]byte, error) { var plugin *authPlugin m.authLock.Lock() for _, p := range m.auths { if p.config.AuthOptions.Scope&AuthScopePassword != 0 { plugin = p break } } m.authLock.Unlock() if plugin == nil { return nil, errors.New("no auth plugin configured for password checking") } return plugin.checkUserAndPass(username, password, ip, protocol, userAsJSON) } func (m *Manager) checkUserAndPublicKey(username, pubKey, ip, protocol string, userAsJSON []byte) ([]byte, error) { var plugin *authPlugin m.authLock.Lock() for _, p := range m.auths { if p.config.AuthOptions.Scope&AuthScopePublicKey != 0 { plugin = p break } } m.authLock.Unlock() if plugin == nil { return nil, errors.New("no auth plugin configured for public key checking") } return plugin.checkUserAndPublicKey(username, pubKey, ip, protocol, userAsJSON) } func (m *Manager) checkUserAndTLSCert(username, tlsCert, ip, protocol string, userAsJSON []byte) ([]byte, error) { var plugin *authPlugin m.authLock.Lock() for _, p := range m.auths { if 
p.config.AuthOptions.Scope&AuthScopeTLSCertificate != 0 { plugin = p break } } m.authLock.Unlock() if plugin == nil { return nil, errors.New("no auth plugin configured for TLS certificate checking") } return plugin.checkUserAndTLSCertificate(username, tlsCert, ip, protocol, userAsJSON) } func (m *Manager) checkUserAndKeyboardInteractive(username, ip, protocol string, userAsJSON []byte) ([]byte, error) { var plugin *authPlugin m.authLock.Lock() for _, p := range m.auths { if p.config.AuthOptions.Scope&AuthScopeKeyboardInteractive != 0 { plugin = p break } } m.authLock.Unlock() if plugin == nil { return nil, errors.New("no auth plugin configured for keyboard interactive checking") } return plugin.checkUserAndKeyboardInteractive(username, ip, protocol, userAsJSON) } func (m *Manager) checkCrashedPlugins() { m.notifLock.RLock() for idx, n := range m.notifiers { if n.exited() { defer func(cfg Config, index int) { Handler.restartNotifierPlugin(cfg, index) }(n.config, idx) } else { n.sendQueuedEvents() } } m.notifLock.RUnlock() m.kmsLock.RLock() for idx, k := range m.kms { if k.exited() { defer func(cfg Config, index int) { Handler.restartKMSPlugin(cfg, index) }(k.config, idx) } } m.kmsLock.RUnlock() m.authLock.RLock() for idx, a := range m.auths { if a.exited() { defer func(cfg Config, index int) { Handler.restartAuthPlugin(cfg, index) }(a.config, idx) } } m.authLock.RUnlock() if m.hasSearcher { m.searcherLock.RLock() if m.searcher.exited() { defer func(cfg Config) { Handler.restartSearcherPlugin(cfg) }(m.searcher.config) } m.searcherLock.RUnlock() } if m.hasIPFilter { m.ipFilterLock.RLock() if m.filter.exited() { defer func(cfg Config) { Handler.restartIPFilterPlugin(cfg) }(m.filter.config) } m.ipFilterLock.RUnlock() } } func (m *Manager) restartNotifierPlugin(config Config, idx int) { if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed notifier plugin %q, idx: %v", config.Cmd, idx) plugin, err := newNotifierPlugin(config) if err != nil { 
logger.Error(logSender, "", "unable to restart notifier plugin %q, err: %v", config.Cmd, err) return } m.notifLock.Lock() plugin.fsEvents = m.notifiers[idx].fsEvents plugin.providerEvents = m.notifiers[idx].providerEvents plugin.logEvents = m.notifiers[idx].logEvents m.notifiers[idx] = plugin m.notifLock.Unlock() plugin.sendQueuedEvents() } func (m *Manager) restartKMSPlugin(config Config, idx int) { if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed kms plugin %q, idx: %v", config.Cmd, idx) plugin, err := newKMSPlugin(config) if err != nil { logger.Error(logSender, "", "unable to restart kms plugin %q, err: %v", config.Cmd, err) return } m.kmsLock.Lock() m.kms[idx] = plugin m.kmsLock.Unlock() } func (m *Manager) restartAuthPlugin(config Config, idx int) { if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed auth plugin %q, idx: %v", config.Cmd, idx) plugin, err := newAuthPlugin(config) if err != nil { logger.Error(logSender, "", "unable to restart auth plugin %q, err: %v", config.Cmd, err) return } m.authLock.Lock() m.auths[idx] = plugin m.authLock.Unlock() } func (m *Manager) restartSearcherPlugin(config Config) { if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed searcher plugin %q", config.Cmd) plugin, err := newSearcherPlugin(config) if err != nil { logger.Error(logSender, "", "unable to restart searcher plugin %q, err: %v", config.Cmd, err) return } m.searcherLock.Lock() m.searcher = plugin m.searcherLock.Unlock() } func (m *Manager) restartIPFilterPlugin(config Config) { if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed IP filter plugin %q", config.Cmd) plugin, err := newIPFilterPlugin(config) if err != nil { logger.Error(logSender, "", "unable to restart IP filter plugin %q, err: %v", config.Cmd, err) return } m.ipFilterLock.Lock() m.filter = plugin m.ipFilterLock.Unlock() } func (m *Manager) addTask() { m.concurrencyGuard <- struct{}{} } 
func (m *Manager) removeTask() { <-m.concurrencyGuard } // Cleanup releases all the active plugins func (m *Manager) Cleanup() { if m.closed.Swap(true) { return } logger.Debug(logSender, "", "cleanup") close(m.done) m.notifLock.Lock() for _, n := range m.notifiers { logger.Debug(logSender, "", "cleanup notifier plugin %v", n.config.Cmd) n.cleanup() } m.notifLock.Unlock() m.kmsLock.Lock() for _, k := range m.kms { logger.Debug(logSender, "", "cleanup kms plugin %v", k.config.Cmd) k.cleanup() } m.kmsLock.Unlock() m.authLock.Lock() for _, a := range m.auths { logger.Debug(logSender, "", "cleanup auth plugin %v", a.config.Cmd) a.cleanup() } m.authLock.Unlock() if m.hasSearcher { m.searcherLock.Lock() logger.Debug(logSender, "", "cleanup searcher plugin %v", m.searcher.config.Cmd) m.searcher.cleanup() m.searcherLock.Unlock() } if m.hasIPFilter { m.ipFilterLock.Lock() logger.Debug(logSender, "", "cleanup IP filter plugin %v", m.filter.config.Cmd) m.filter.cleanup() m.ipFilterLock.Unlock() } } func setLogLevel(logLevel string) { switch logLevel { case "info": pluginsLogLevel = hclog.Info case "warn": pluginsLogLevel = hclog.Warn case "error": pluginsLogLevel = hclog.Error default: pluginsLogLevel = hclog.Debug } } func startCheckTicker() { logger.Debug(logSender, "", "start plugins checker") go func() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-Handler.done: logger.Debug(logSender, "", "handler done, stop plugins checker") return case <-ticker.C: Handler.checkCrashedPlugins() } } }() } ================================================ FILE: internal/plugin/searcher.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package plugin import ( "fmt" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "github.com/sftpgo/sdk/plugin/eventsearcher" "github.com/drakkan/sftpgo/v2/internal/logger" ) type searcherPlugin struct { config Config searchear eventsearcher.Searcher client *plugin.Client } func newSearcherPlugin(config Config) (*searcherPlugin, error) { p := &searcherPlugin{ config: config, } if err := p.initialize(); err != nil { logger.Warn(logSender, "", "unable to create events searcher plugin: %v, config %+v", err, config) return nil, err } return p, nil } func (p *searcherPlugin) exited() bool { return p.client.Exited() } func (p *searcherPlugin) cleanup() { p.client.Kill() } func (p *searcherPlugin) initialize() error { killProcess(p.config.Cmd) logger.Debug(logSender, "", "create new searcher plugin %q", p.config.Cmd) secureConfig, err := p.config.getSecureConfig() if err != nil { return err } client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: eventsearcher.Handshake, Plugins: eventsearcher.PluginMap, Cmd: p.config.getCommand(), SkipHostEnv: true, AllowedProtocols: []plugin.Protocol{ plugin.ProtocolGRPC, }, AutoMTLS: p.config.AutoMTLS, SecureConfig: secureConfig, Managed: false, Logger: &logger.HCLogAdapter{ Logger: hclog.New(&hclog.LoggerOptions{ Name: fmt.Sprintf("%v.%v", logSender, eventsearcher.PluginName), Level: pluginsLogLevel, DisableTime: true, }), }, }) rpcClient, err := client.Client() if err != nil { logger.Debug(logSender, "", "unable to get rpc client for plugin %q: %v", p.config.Cmd, err) return err } raw, err := rpcClient.Dispense(eventsearcher.PluginName) if 
err != nil { logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v", eventsearcher.PluginName, p.config.Cmd, err) return err } p.client = client p.searchear = raw.(eventsearcher.Searcher) return nil } ================================================ FILE: internal/plugin/util.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package plugin import ( "github.com/shirou/gopsutil/v3/process" "github.com/drakkan/sftpgo/v2/internal/logger" ) func killProcess(processPath string) { procs, err := process.Processes() if err != nil { return } for _, p := range procs { cmdLine, err := p.Exe() if err == nil { if cmdLine == processPath { err = p.Kill() logger.Debug(logSender, "", "killed process %v, pid %v, err %v", cmdLine, p.Pid, err) return } } } logger.Debug(logSender, "", "no match for plugin process %v", processPath) } ================================================ FILE: internal/service/service.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package service allows to start and stop the SFTPGo service package service import ( "errors" "fmt" "os" "path/filepath" "github.com/rs/zerolog" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( logSender = "service" ) var ( graceTime int ) // Service defines the SFTPGo service type Service struct { ConfigDir string ConfigFile string LogFilePath string LogMaxSize int LogMaxBackups int LogMaxAge int PortableMode int PortableUser dataprovider.User LogCompress bool LogLevel string LogUTCTime bool LoadDataClean bool LoadDataFrom string LoadDataMode int LoadDataQuotaScan int Shutdown chan bool Error error } func (s *Service) initLogger() { var logLevel zerolog.Level switch s.LogLevel { case "info": logLevel = zerolog.InfoLevel case "warn": logLevel = zerolog.WarnLevel case "error": logLevel = zerolog.ErrorLevel default: logLevel = zerolog.DebugLevel } if !filepath.IsAbs(s.LogFilePath) && util.IsFileInputValid(s.LogFilePath) { s.LogFilePath = filepath.Join(s.ConfigDir, s.LogFilePath) } logger.InitLogger(s.LogFilePath, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogCompress, s.LogUTCTime, logLevel) if s.PortableMode == 1 { logger.EnableConsoleLogger(logLevel) if s.LogFilePath == "" { logger.DisableLogger() } } } // Start 
initializes and starts the service func (s *Service) Start() error { s.initLogger() logger.Info(logSender, "", "starting SFTPGo %s, config dir: %s, config file: %s, log max size: %d log max backups: %d "+ "log max age: %d log level: %s, log compress: %t, log utc time: %t, load data from: %q, grace time: %d secs", version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel, s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime) // in portable mode we don't read configuration from file if s.PortableMode != 1 { err := config.LoadConfig(s.ConfigDir, s.ConfigFile) if err != nil { logger.Error(logSender, "", "error loading configuration: %v", err) return err } } if !config.HasServicesToStart() { const infoString = "no service configured, nothing to do" logger.Info(logSender, "", infoString) logger.InfoToConsole(infoString) return errors.New(infoString) } if err := s.initializeServices(); err != nil { return err } s.startServices() go common.Config.ExecuteStartupHook() //nolint:errcheck return nil } func (s *Service) initializeServices() error { providerConf := config.GetProviderConf() kmsConfig := config.GetKMSConfig() err := kmsConfig.Initialize() if err != nil { logger.Error(logSender, "", "unable to initialize KMS: %v", err) logger.ErrorToConsole("unable to initialize KMS: %v", err) return err } // We may have KMS plugins and their schema needs to be registered before // initializing the data provider which may contain KMS secrets. 
if err := plugin.Initialize(config.GetPluginsConfig(), s.LogLevel); err != nil { logger.Error(logSender, "", "unable to initialize plugin system: %v", err) logger.ErrorToConsole("unable to initialize plugin system: %v", err) return err } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.Error(logSender, "", "unable to initialize MFA: %v", err) logger.ErrorToConsole("unable to initialize MFA: %v", err) return err } err = dataprovider.Initialize(providerConf, s.ConfigDir, s.PortableMode == 0) if err != nil { logger.Error(logSender, "", "error initializing data provider: %v", err) logger.ErrorToConsole("error initializing data provider: %v", err) return err } smtpConfig := config.GetSMTPConfig() err = smtpConfig.Initialize(s.ConfigDir, s.PortableMode != 1) if err != nil { logger.Error(logSender, "", "unable to initialize SMTP configuration: %v", err) logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err) return err } err = common.Initialize(config.GetCommonConfig(), providerConf.GetShared()) if err != nil { logger.Error(logSender, "", "%v", err) logger.ErrorToConsole("%v", err) return err } if s.PortableMode == 1 { // create the user for portable mode err = dataprovider.AddUser(&s.PortableUser, dataprovider.ActionExecutorSystem, "", "") if err != nil { logger.ErrorToConsole("error adding portable user: %v", err) return err } } else { acmeConfig := config.GetACMEConfig() err = acme.Initialize(acmeConfig, s.ConfigDir, true) if err != nil { logger.Error(logSender, "", "error initializing ACME configuration: %v", err) logger.ErrorToConsole("error initializing ACME configuration: %v", err) return err } } httpConfig := config.GetHTTPConfig() err = httpConfig.Initialize(s.ConfigDir) if err != nil { logger.Error(logSender, "", "error initializing http client: %v", err) logger.ErrorToConsole("error initializing http client: %v", err) return err } commandConfig := config.GetCommandConfig() if err := 
commandConfig.Initialize(); err != nil { logger.Error(logSender, "", "error initializing commands configuration: %v", err) logger.ErrorToConsole("error initializing commands configuration: %v", err) return err } return nil } func (s *Service) startServices() { err := s.LoadInitialData() if err != nil { logger.Error(logSender, "", "unable to load initial data: %v", err) logger.ErrorToConsole("unable to load initial data: %v", err) } sftpdConf := config.GetSFTPDConfig() ftpdConf := config.GetFTPDConfig() httpdConf := config.GetHTTPDConfig() webDavDConf := config.GetWebDAVDConfig() telemetryConf := config.GetTelemetryConfig() if sftpdConf.ShouldBind() { go func() { redactedConf := sftpdConf redactedConf.KeyboardInteractiveHook = util.GetRedactedURL(sftpdConf.KeyboardInteractiveHook) logger.Info(logSender, "", "initializing SFTP server with config %+v", redactedConf) if err := sftpdConf.Initialize(s.ConfigDir); err != nil { logger.Error(logSender, "", "could not start SFTP server: %v", err) logger.ErrorToConsole("could not start SFTP server: %v", err) s.Error = err } s.Shutdown <- true }() } else { logger.Info(logSender, "", "SFTP server not started, disabled in config file") } if httpdConf.ShouldBind() { go func() { providerConf := config.GetProviderConf() if err := httpdConf.Initialize(s.ConfigDir, providerConf.GetShared()); err != nil { logger.Error(logSender, "", "could not start HTTP server: %v", err) logger.ErrorToConsole("could not start HTTP server: %v", err) s.Error = err } s.Shutdown <- true }() } else { logger.Info(logSender, "", "HTTP server not started, disabled in config file") if s.PortableMode != 1 { logger.InfoToConsole("HTTP server not started, disabled in config file") } } if ftpdConf.ShouldBind() { go func() { if err := ftpdConf.Initialize(s.ConfigDir); err != nil { logger.Error(logSender, "", "could not start FTP server: %v", err) logger.ErrorToConsole("could not start FTP server: %v", err) s.Error = err } s.Shutdown <- true }() } else { 
logger.Info(logSender, "", "FTP server not started, disabled in config file") } if webDavDConf.ShouldBind() { go func() { if err := webDavDConf.Initialize(s.ConfigDir); err != nil { logger.Error(logSender, "", "could not start WebDAV server: %v", err) logger.ErrorToConsole("could not start WebDAV server: %v", err) s.Error = err } s.Shutdown <- true }() } else { logger.Info(logSender, "", "WebDAV server not started, disabled in config file") } if telemetryConf.ShouldBind() { go func() { if err := telemetryConf.Initialize(s.ConfigDir); err != nil { logger.Error(logSender, "", "could not start telemetry server: %v", err) logger.ErrorToConsole("could not start telemetry server: %v", err) s.Error = err } s.Shutdown <- true }() } else { logger.Info(logSender, "", "telemetry server not started, disabled in config file") if s.PortableMode != 1 { logger.InfoToConsole("telemetry server not started, disabled in config file") } } } // Wait blocks until the service exits func (s *Service) Wait() { if s.PortableMode != 1 { registerSignals() } <-s.Shutdown } // Stop terminates the service unblocking the Wait method func (s *Service) Stop() { close(s.Shutdown) logger.Debug(logSender, "", "Service stopped") } // LoadInitialData if a data file is set func (s *Service) LoadInitialData() error { if s.LoadDataFrom == "" { return nil } if !filepath.IsAbs(s.LoadDataFrom) { return fmt.Errorf("invalid input_file %q, it must be an absolute path", s.LoadDataFrom) } if s.LoadDataMode < 0 || s.LoadDataMode > 1 { return fmt.Errorf("invalid loaddata-mode %v", s.LoadDataMode) } if s.LoadDataQuotaScan < 0 || s.LoadDataQuotaScan > 2 { return fmt.Errorf("invalid loaddata-scan %v", s.LoadDataQuotaScan) } info, err := os.Stat(s.LoadDataFrom) if err != nil { return fmt.Errorf("unable to stat file %q: %w", s.LoadDataFrom, err) } if info.Size() > httpd.MaxRestoreSize { return fmt.Errorf("unable to restore input file %q size too big: %d/%d bytes", s.LoadDataFrom, info.Size(), httpd.MaxRestoreSize) } 
content, err := os.ReadFile(s.LoadDataFrom) if err != nil { return fmt.Errorf("unable to read input file %q: %w", s.LoadDataFrom, err) } dump, err := dataprovider.ParseDumpData(content) if err != nil { return fmt.Errorf("unable to parse file to restore %q: %w", s.LoadDataFrom, err) } err = s.restoreDump(&dump) if err != nil { return err } logger.Info(logSender, "", "data loaded from file %q mode: %v", s.LoadDataFrom, s.LoadDataMode) logger.InfoToConsole("data loaded from file %q mode: %v", s.LoadDataFrom, s.LoadDataMode) if s.LoadDataClean { err = os.Remove(s.LoadDataFrom) if err == nil { logger.Info(logSender, "", "file %q deleted after successful load", s.LoadDataFrom) logger.InfoToConsole("file %q deleted after successful load", s.LoadDataFrom) } else { logger.Warn(logSender, "", "unable to delete file %q after successful load: %v", s.LoadDataFrom, err) logger.WarnToConsole("unable to delete file %q after successful load: %v", s.LoadDataFrom, err) } } return nil } func (s *Service) restoreDump(dump *dataprovider.BackupData) error { err := httpd.RestoreConfigs(dump.Configs, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore configs from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreIPListEntries(dump.IPLists, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore IP list entries from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreRoles(dump.Roles, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore roles from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreFolders(dump.Folders, s.LoadDataFrom, s.LoadDataMode, s.LoadDataQuotaScan, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore folders from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreGroups(dump.Groups, s.LoadDataFrom, 
s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore groups from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreUsers(dump.Users, s.LoadDataFrom, s.LoadDataMode, s.LoadDataQuotaScan, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore users from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreAdmins(dump.Admins, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore admins from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreAPIKeys(dump.APIKeys, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore API keys from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreShares(dump.Shares, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore API keys from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreEventActions(dump.EventActions, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") if err != nil { return fmt.Errorf("unable to restore event actions from file %q: %v", s.LoadDataFrom, err) } err = httpd.RestoreEventRules(dump.EventRules, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "", dump.Version) if err != nil { return fmt.Errorf("unable to restore event rules from file %q: %v", s.LoadDataFrom, err) } return nil } // SetGraceTime sets the grace time func SetGraceTime(val int) { graceTime = val } ================================================ FILE: internal/service/service_portable.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !noportable package service import ( "fmt" "math/rand" "slices" "strings" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) // StartPortableMode starts the service in portable mode func (s *Service) StartPortableMode(sftpdPort, ftpPort, webdavPort, httpPort int, enabledSSHCommands []string, ftpsCert, ftpsKey, webDavCert, webDavKey, httpsCert, httpsKey string) error { if s.PortableMode != 1 { return fmt.Errorf("service is not configured for portable mode") } err := config.LoadConfig(s.ConfigDir, s.ConfigFile) if err != nil { fmt.Printf("error loading configuration file: %v using defaults\n", err) } kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { return err } printablePassword := s.configurePortableUser() dataProviderConf := config.GetProviderConf() dataProviderConf.Driver = dataprovider.MemoryDataProviderName dataProviderConf.Name = "" config.SetProviderConf(dataProviderConf) httpdConf := config.GetHTTPDConfig() for idx := range httpdConf.Bindings { httpdConf.Bindings[idx].Port = 0 } config.SetHTTPDConfig(httpdConf) telemetryConf := config.GetTelemetryConfig() telemetryConf.BindPort = 0 config.SetTelemetryConfig(telemetryConf) configurePortableSFTPService(sftpdPort, 
enabledSSHCommands) configurePortableFTPService(ftpPort, ftpsCert, ftpsKey) configurePortableWebDAVService(webdavPort, webDavCert, webDavKey) configurePortableHTTPService(httpPort, httpsCert, httpsKey) err = s.Start() if err != nil { return err } if httpPort >= 0 { admin := &dataprovider.Admin{ Username: util.GenerateUniqueID(), Password: util.GenerateUniqueID(), Status: 0, Permissions: []string{dataprovider.PermAdminAny}, } if err := dataprovider.AddAdmin(admin, dataprovider.ActionExecutorSystem, "", ""); err != nil { return err } } logger.InfoToConsole("Portable mode ready, user: %q, password: %q, public keys: %v, directory: %q, "+ "permissions: %+v, file patterns filters: %+v %v", s.PortableUser.Username, printablePassword, s.PortableUser.PublicKeys, s.getPortableDirToServe(), s.PortableUser.Permissions, s.PortableUser.Filters.FilePatterns, s.getServiceOptionalInfoString()) return nil } func (s *Service) getServiceOptionalInfoString() string { var info strings.Builder if config.GetSFTPDConfig().Bindings[0].IsValid() { fmt.Fprintf(&info, "SFTP port: %d ", config.GetSFTPDConfig().Bindings[0].Port) } if config.GetFTPDConfig().Bindings[0].IsValid() { fmt.Fprintf(&info, "FTP port: %d ", config.GetFTPDConfig().Bindings[0].Port) } if config.GetWebDAVDConfig().Bindings[0].IsValid() { scheme := "http" if config.GetWebDAVDConfig().CertificateFile != "" && config.GetWebDAVDConfig().CertificateKeyFile != "" { scheme = "https" } fmt.Fprintf(&info, "WebDAV URL: %v://:%v/ ", scheme, config.GetWebDAVDConfig().Bindings[0].Port) } if config.GetHTTPDConfig().Bindings[0].IsValid() { scheme := "http" if config.GetHTTPDConfig().CertificateFile != "" && config.GetHTTPDConfig().CertificateKeyFile != "" { scheme = "https" } fmt.Fprintf(&info, "WebClient URL: %s://:%d/ ", scheme, config.GetHTTPDConfig().Bindings[0].Port) } return info.String() } func (s *Service) getPortableDirToServe() string { switch s.PortableUser.FsConfig.Provider { case sdk.S3FilesystemProvider: return 
s.PortableUser.FsConfig.S3Config.KeyPrefix case sdk.GCSFilesystemProvider: return s.PortableUser.FsConfig.GCSConfig.KeyPrefix case sdk.AzureBlobFilesystemProvider: return s.PortableUser.FsConfig.AzBlobConfig.KeyPrefix case sdk.SFTPFilesystemProvider: return s.PortableUser.FsConfig.SFTPConfig.Prefix case sdk.HTTPFilesystemProvider: return "/" default: return s.PortableUser.HomeDir } } // configures the portable user and return the printable password if any func (s *Service) configurePortableUser() string { if s.PortableUser.Username == "" { s.PortableUser.Username = "user" } printablePassword := "" if s.PortableUser.Password != "" { printablePassword = "[redacted]" } if len(s.PortableUser.PublicKeys) == 0 && s.PortableUser.Password == "" { s.PortableUser.Password = util.GenerateUniqueID() printablePassword = s.PortableUser.Password } s.PortableUser.Filters.WebClient = []string{sdk.WebClientSharesDisabled, sdk.WebClientInfoChangeDisabled, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientPasswordChangeDisabled, sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientMFADisabled, sdk.WebClientPasswordResetDisabled, sdk.WebClientTLSCertChangeDisabled, } if !s.PortableUser.HasAnyPerm([]string{dataprovider.PermUpload, dataprovider.PermOverwrite}, "/") { s.PortableUser.Filters.WebClient = append(s.PortableUser.Filters.WebClient, sdk.WebClientWriteDisabled) } s.configurePortableSecrets() return printablePassword } func (s *Service) configurePortableSecrets() { // we created the user before to initialize the KMS so we need to create the secret here switch s.PortableUser.FsConfig.Provider { case sdk.S3FilesystemProvider: payload := s.PortableUser.FsConfig.S3Config.AccessSecret.GetPayload() s.PortableUser.FsConfig.S3Config.AccessSecret = getSecretFromString(payload) case sdk.GCSFilesystemProvider: payload := s.PortableUser.FsConfig.GCSConfig.Credentials.GetPayload() s.PortableUser.FsConfig.GCSConfig.Credentials = getSecretFromString(payload) case sdk.AzureBlobFilesystemProvider: 
payload := s.PortableUser.FsConfig.AzBlobConfig.AccountKey.GetPayload() s.PortableUser.FsConfig.AzBlobConfig.AccountKey = getSecretFromString(payload) payload = s.PortableUser.FsConfig.AzBlobConfig.SASURL.GetPayload() s.PortableUser.FsConfig.AzBlobConfig.SASURL = getSecretFromString(payload) case sdk.CryptedFilesystemProvider: payload := s.PortableUser.FsConfig.CryptConfig.Passphrase.GetPayload() s.PortableUser.FsConfig.CryptConfig.Passphrase = getSecretFromString(payload) case sdk.SFTPFilesystemProvider: payload := s.PortableUser.FsConfig.SFTPConfig.Password.GetPayload() s.PortableUser.FsConfig.SFTPConfig.Password = getSecretFromString(payload) payload = s.PortableUser.FsConfig.SFTPConfig.PrivateKey.GetPayload() s.PortableUser.FsConfig.SFTPConfig.PrivateKey = getSecretFromString(payload) payload = s.PortableUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload() s.PortableUser.FsConfig.SFTPConfig.KeyPassphrase = getSecretFromString(payload) case sdk.HTTPFilesystemProvider: payload := s.PortableUser.FsConfig.HTTPConfig.Password.GetPayload() s.PortableUser.FsConfig.HTTPConfig.Password = getSecretFromString(payload) payload = s.PortableUser.FsConfig.HTTPConfig.APIKey.GetPayload() s.PortableUser.FsConfig.HTTPConfig.APIKey = getSecretFromString(payload) } } func getSecretFromString(payload string) *kms.Secret { if payload != "" { return kms.NewPlainSecret(payload) } return kms.NewEmptySecret() } func configurePortableSFTPService(port int, enabledSSHCommands []string) { sftpdConf := config.GetSFTPDConfig() if len(sftpdConf.Bindings) == 0 { sftpdConf.Bindings = append(sftpdConf.Bindings, sftpd.Binding{}) } if port > 0 { sftpdConf.Bindings[0].Port = port } else if port == 0 { // dynamic ports starts from 49152 sftpdConf.Bindings[0].Port = 49152 + rand.Intn(15000) } else { sftpdConf.Bindings[0].Port = 0 } if slices.Contains(enabledSSHCommands, "*") { sftpdConf.EnabledSSHCommands = sftpd.GetSupportedSSHCommands() } else { sftpdConf.EnabledSSHCommands = enabledSSHCommands } 
config.SetSFTPDConfig(sftpdConf) } func configurePortableFTPService(port int, cert, key string) { ftpConf := config.GetFTPDConfig() if len(ftpConf.Bindings) == 0 { ftpConf.Bindings = append(ftpConf.Bindings, ftpd.Binding{}) } if port > 0 { ftpConf.Bindings[0].Port = port } else if port == 0 { ftpConf.Bindings[0].Port = 49152 + rand.Intn(15000) } else { ftpConf.Bindings[0].Port = 0 } ftpConf.Bindings[0].CertificateFile = cert ftpConf.Bindings[0].CertificateKeyFile = key config.SetFTPDConfig(ftpConf) } func configurePortableWebDAVService(port int, cert, key string) { webDavConf := config.GetWebDAVDConfig() if len(webDavConf.Bindings) == 0 { webDavConf.Bindings = append(webDavConf.Bindings, webdavd.Binding{}) } if port > 0 { webDavConf.Bindings[0].Port = port } else if port == 0 { webDavConf.Bindings[0].Port = 49152 + rand.Intn(15000) } else { webDavConf.Bindings[0].Port = 0 } webDavConf.Bindings[0].CertificateFile = cert webDavConf.Bindings[0].CertificateKeyFile = key if cert != "" && key != "" { webDavConf.Bindings[0].EnableHTTPS = true } config.SetWebDAVDConfig(webDavConf) } func configurePortableHTTPService(port int, cert, key string) { httpdConf := config.GetHTTPDConfig() if len(httpdConf.Bindings) == 0 { httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{}) } if port > 0 { httpdConf.Bindings[0].Port = port } else if port == 0 { httpdConf.Bindings[0].Port = 49152 + rand.Intn(15000) } else { httpdConf.Bindings[0].Port = 0 } httpdConf.Bindings[0].CertificateFile = cert httpdConf.Bindings[0].CertificateKeyFile = key if cert != "" && key != "" { httpdConf.Bindings[0].EnableHTTPS = true } httpdConf.Bindings[0].EnableWebAdmin = false httpdConf.Bindings[0].EnableWebClient = true httpdConf.Bindings[0].EnableRESTAPI = false httpdConf.Bindings[0].RenderOpenAPI = false config.SetHTTPDConfig(httpdConf) } ================================================ FILE: internal/service/service_windows.go ================================================ // Copyright (C) 2019 
// Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package service

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"golang.org/x/sys/windows/svc"
	"golang.org/x/sys/windows/svc/eventlog"
	"golang.org/x/sys/windows/svc/mgr"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/ftpd"
	"github.com/drakkan/sftpgo/v2/internal/httpd"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/sftpd"
	"github.com/drakkan/sftpgo/v2/internal/telemetry"
	"github.com/drakkan/sftpgo/v2/internal/webdavd"
)

const (
	// serviceName is the name used to register and look up the service
	// with the Service Control Manager (SCM) and the event log.
	serviceName = "SFTPGo"
	// serviceDesc is the description shown for the installed service.
	serviceDesc = "Full-featured and highly configurable file transfer server"
	// rotateLogCmd is a custom control code: RotateLogFile sends it and
	// Execute handles it by rotating the log file. The Win32 ControlService
	// API reserves codes 128-255 for user-defined controls.
	rotateLogCmd = svc.Cmd(128)
	// acceptRotateLog is advertised in Execute's accepted controls.
	// NOTE(review): svc.Accepted(128) is the 0x80 bit, which x/sys/windows/svc
	// names AcceptSessionChange; user-defined control codes presumably need no
	// accept bit. Confirm before changing.
	acceptRotateLog = svc.Accepted(rotateLogCmd)
)

// Status defines service status
type Status uint8

// Supported values for service status
const (
	StatusUnknown Status = iota
	StatusRunning
	StatusStopped
	StatusPaused
	StatusStartPending
	StatusPausePending
	StatusContinuePending
	StatusStopPending
)

// WindowsService runs Service under the Windows Service Control Manager and
// provides the install/uninstall/start/stop/reload/status helpers for it.
type WindowsService struct {
	Service Service
	// isInteractive is set by RunService when the process is not running as a Windows service.
	isInteractive bool
}

// String returns a lowercase, human-readable name for the status.
// StatusUnknown and any unlisted value yield "unknown".
func (s Status) String() string {
	switch s {
	case StatusRunning:
		return "running"
	case StatusStopped:
		return "stopped"
	case StatusStartPending:
		return "start pending"
	case StatusPausePending:
		return "pause pending"
	case StatusPaused:
		return "paused"
	case StatusContinuePending:
		return "continue pending"
	case StatusStopPending:
		return "stop pending"
	default:
		return "unknown"
	}
}

// handleExit blocks until the wrapped Service finishes (s.Service.Wait).
// If Execute reported a requested stop on wasStopped, it only logs and returns.
// Otherwise the service ended on its own, so the process exits:
// status 0 when s.Service.Error is nil, status 1 otherwise.
func (s *WindowsService) handleExit(wasStopped chan bool) {
	s.Service.Wait()

	select {
	case <-wasStopped:
		// the service was stopped nothing to do
		logger.Info(logSender, "", "Windows Service was stopped")
		return
	default:
		// the server failed while running, we must be sure to exit the process.
		// The defined recovery action will be executed.
		logger.Info(logSender, "", "Service wait ended, error: %v", s.Service.Error)
		if s.Service.Error == nil {
			os.Exit(0)
		} else {
			os.Exit(1)
		}
	}
}

// Execute is the SCM entry point invoked through svc.Run (see RunService).
//
// Startup:
//   - reports StartPending and starts the wrapped Service in a goroutine;
//   - immediately reports Running, accepting only Stop and Shutdown;
//   - once Start succeeds, the goroutine reports Running again and also accepts
//     ParamChange and the rotate-log control;
//   - if Start fails, the goroutine records the error in s.Service.Error and
//     sends on s.Service.Shutdown (presumably this unblocks Service.Wait — confirm).
//
// It then serves control requests until Stop or Shutdown:
//   - Interrogate echoes the current status;
//   - Stop/Shutdown stops the service, cleans up plugins and waits up to graceTime for transfers;
//   - ParamChange reloads the provider config, certificate managers, common configs and sftpd;
//   - rotateLogCmd rotates the log file.
//
// It always returns (false, 0).
func (s *WindowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
	changes <- svc.Status{State: svc.StartPending}

	go func() {
		if err := s.Service.Start(); err != nil {
			logger.Error(logSender, "", "Windows service failed to start, error: %v", err)
			s.Service.Error = err
			s.Service.Shutdown <- true
			return
		}
		logger.Info(logSender, "", "Windows service started")
		cmdsAccepted := svc.AcceptStop | svc.AcceptShutdown | svc.AcceptParamChange | acceptRotateLog
		changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
	}()

	// buffered so the stop handler below can signal without blocking
	wasStopped := make(chan bool, 1)

	go s.handleExit(wasStopped)

	changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
loop:
	for {
		c := <-r
		switch c.Cmd {
		case svc.Interrogate:
			logger.Debug(logSender, "", "Received service interrogate request, current status: %v", c.CurrentStatus)
			changes <- c.CurrentStatus
		case svc.Stop, svc.Shutdown:
			logger.Debug(logSender, "", "Received service stop request")
			changes <- svc.Status{State: svc.StopPending}
			// signal before stopping so handleExit, once Wait returns, sees a
			// requested stop rather than a failure
			wasStopped <- true
			s.Service.Stop()
			plugin.Handler.Cleanup()
			common.WaitForTransfers(graceTime)
			break loop
		case svc.ParamChange:
			logger.Debug(logSender, "", "Received reload request")
			err := dataprovider.ReloadConfig()
			if err != nil {
				logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err)
			}
			err = httpd.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading cert manager: %v", err)
			}
			err = ftpd.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading FTPD cert manager: %v", err)
			}
			err = webdavd.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading WebDAV cert manager: %v", err)
			}
			err = telemetry.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err)
			}
			err = common.Reload()
			if err != nil {
				logger.Warn(logSender, "", "error reloading common configs: %v", err)
			}
			err = sftpd.Reload()
			if err != nil {
				logger.Warn(logSender, "", "error reloading sftpd revoked certificates: %v", err)
			}
		case rotateLogCmd:
			logger.Debug(logSender, "", "Received log file rotation request")
			err := logger.RotateLogFile()
			if err != nil {
				logger.Warn(logSender, "", "error rotating log file: %v", err)
			}
		default:
			continue loop
		}
	}

	return false, 0
}

// RunService changes the working directory to the executable's directory.
// When the process is running as a Windows service, it hands control to the
// SCM via svc.Run (which calls Execute). Otherwise (interactive session) it
// calls s.Start, which asks the SCM to start the installed service.
func (s *WindowsService) RunService() error {
	exePath, err := s.getExePath()
	if err != nil {
		return err
	}

	isService, err := svc.IsWindowsService()
	if err != nil {
		return err
	}

	s.isInteractive = !isService
	dir := filepath.Dir(exePath)
	if err = os.Chdir(dir); err != nil {
		return err
	}
	if s.isInteractive {
		return s.Start()
	}
	return svc.Run(serviceName, s)
}

// Start asks the SCM to start the installed service.
func (s *WindowsService) Start() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	err = service.Start()
	if err != nil {
		return fmt.Errorf("could not start service: %v", err)
	}
	return nil
}

// Reload sends a ParamChange control to the running service; Execute handles it as a configuration reload.
func (s *WindowsService) Reload() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	_, err = service.Control(svc.ParamChange)
	if err != nil {
		return fmt.Errorf("could not send control=%d: %v", svc.ParamChange, err)
	}
	return nil
}

// RotateLogFile sends the custom rotateLogCmd control to the running service; Execute rotates the log file.
func (s *WindowsService) RotateLogFile() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	_, err = service.Control(rotateLogCmd)
	if err != nil {
		return fmt.Errorf("could not send control=%d: %v", rotateLogCmd, err)
	}
	return nil
}

// Install registers the current executable as an automatically started
// service; args are passed to CreateService as the service arguments.
//
// It also registers an event log source (an "exists" error is tolerated) and
// sets recovery actions: restart after 5s, 60s and 90s, with the failure count
// reset after 300 (seconds, per the x/sys mgr docs).
// If either of these two steps fails, the just-created service is deleted
// (best effort; the delete error is ignored). Installing an already existing
// service is an error.
func (s *WindowsService) Install(args ...string) error {
	exePath, err := s.getExePath()
	if err != nil {
		return err
	}
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err == nil {
		service.Close()
		return fmt.Errorf("service %s already exists", serviceName)
	}
	config := mgr.Config{
		DisplayName: serviceName,
		Description: serviceDesc,
		StartType:   mgr.StartAutomatic}
	service, err = m.CreateService(serviceName, exePath, config, args...)
	if err != nil {
		return err
	}
	defer service.Close()
	err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
	if err != nil {
		if !strings.Contains(err.Error(), "exists") {
			service.Delete()
			return fmt.Errorf("SetupEventLogSource() failed: %s", err)
		}
	}
	recoveryActions := []mgr.RecoveryAction{
		{
			Type:  mgr.ServiceRestart,
			Delay: 5 * time.Second,
		},
		{
			Type:  mgr.ServiceRestart,
			Delay: 60 * time.Second,
		},
		{
			Type:  mgr.ServiceRestart,
			Delay: 90 * time.Second,
		},
	}
	err = service.SetRecoveryActions(recoveryActions, 300)
	if err != nil {
		service.Delete()
		return fmt.Errorf("unable to set recovery actions: %v", err)
	}
	return nil
}

// Uninstall deletes the service and removes its event log source.
// Any OpenService failure is reported as "not installed".
func (s *WindowsService) Uninstall() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("service %s is not installed", serviceName)
	}
	defer service.Close()
	err = service.Delete()
	if err != nil {
		return err
	}
	err = eventlog.Remove(serviceName)
	if err != nil {
		return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
	}
	return nil
}

// Stop sends a Stop control, then polls the service state every 300ms until
// it reports Stopped. It returns an error if that does not happen within about 10 seconds.
func (s *WindowsService) Stop() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	status, err := service.Control(svc.Stop)
	if err != nil {
		return fmt.Errorf("could not send control=%d: %v", svc.Stop, err)
	}
	timeout := time.Now().Add(10 * time.Second)
	for status.State != svc.Stopped {
		if timeout.Before(time.Now()) {
			return fmt.Errorf("timeout waiting for service to go to state=%d", svc.Stopped)
		}
		time.Sleep(300 * time.Millisecond)
		status, err = service.Query()
		if err != nil {
			return fmt.Errorf("could not retrieve service status: %v", err)
		}
	}
	return nil
}

// Status queries the SCM and maps the service state to a Status value.
// An unrecognized state yields StatusUnknown with an error.
func (s *WindowsService) Status() (Status, error) {
	m, err := mgr.Connect()
	if err != nil {
		return StatusUnknown, err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return StatusUnknown, fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	status, err := service.Query()
	if err != nil {
		return StatusUnknown, fmt.Errorf("could not query service status: %v", err)
	}
	switch status.State {
	case svc.StartPending:
		return StatusStartPending, nil
	case svc.Running:
		return StatusRunning, nil
	case svc.PausePending:
		return StatusPausePending, nil
	case svc.Paused:
		return StatusPaused, nil
	case svc.ContinuePending:
		return StatusContinuePending, nil
	case svc.StopPending:
		return StatusStopPending, nil
	case svc.Stopped:
		return StatusStopped, nil
	default:
		return StatusUnknown, fmt.Errorf("unknown status %v", status)
	}
}

// getExePath returns the path of the running executable (os.Executable).
func (s *WindowsService) getExePath() (string, error) {
	return os.Executable()
}

================================================
FILE: internal/service/signals_unix.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it
under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !windows package service import ( "os" "os/signal" "syscall" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/telemetry" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) func registerSignals() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGUSR1) go func() { for sig := range c { switch sig { case syscall.SIGHUP: handleSIGHUP() case syscall.SIGUSR1: handleSIGUSR1() case syscall.SIGINT, syscall.SIGTERM: handleInterrupt() } } }() } func handleSIGHUP() { logger.Debug(logSender, "", "Received reload request") err := dataprovider.ReloadConfig() if err != nil { logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err) } err = httpd.ReloadCertificateMgr() if err != nil { logger.Warn(logSender, "", "error reloading cert manager: %v", err) } err = ftpd.ReloadCertificateMgr() if err != nil { logger.Warn(logSender, "", "error reloading FTPD cert manager: %v", err) } err = webdavd.ReloadCertificateMgr() if err != nil { logger.Warn(logSender, "", "error reloading WebDAV cert manager: %v", err) } err = telemetry.ReloadCertificateMgr() if err != nil { logger.Warn(logSender, "", "error 
reloading telemetry cert manager: %v", err) } err = common.Reload() if err != nil { logger.Warn(logSender, "", "error reloading common configs: %v", err) } err = sftpd.Reload() if err != nil { logger.Warn(logSender, "", "error reloading sftpd revoked certificates: %v", err) } } func handleSIGUSR1() { logger.Debug(logSender, "", "Received log file rotation request") err := logger.RotateLogFile() if err != nil { logger.Warn(logSender, "", "error rotating log file: %v", err) } } func handleInterrupt() { logger.Debug(logSender, "", "Received interrupt request") plugin.Handler.Cleanup() common.WaitForTransfers(graceTime) os.Exit(0) } ================================================ FILE: internal/service/signals_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package service import ( "os" "os/signal" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" ) func registerSignals() { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { for range c { logger.Debug(logSender, "", "Received interrupt request") plugin.Handler.Cleanup() common.WaitForTransfers(graceTime) os.Exit(0) } }() } ================================================ FILE: internal/sftpd/cryptfs_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package sftpd_test import ( "crypto/sha256" "fmt" "net/http" "os" "path" "path/filepath" "testing" "time" "github.com/minio/sio" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( testPassphrase = "test passphrase" ) func TestBasicSFTPCryptoHandling(t *testing.T) { usePubKey := false u := getTestUserWithCryptFs(usePubKey) u.QuotaSize = 6553600 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) encryptedFileSize, err := getEncryptedFileSize(testFileSize) assert.NoError(t, err) expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client) assert.Error(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) initialHash, err := computeHashForFile(sha256.New(), testFilePath) assert.NoError(t, err) downloadedFileHash, err := computeHashForFile(sha256.New(), localDownloadPath) assert.NoError(t, err) assert.Equal(t, initialHash, downloadedFileHash) info, err := os.Stat(filepath.Join(user.HomeDir, testFileName)) if assert.NoError(t, err) { assert.Equal(t, encryptedFileSize, info.Size()) } user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, 
user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) result, err := client.ReadDir(".") assert.NoError(t, err) if assert.Len(t, result, 1) { assert.Equal(t, testFileSize, result[0].Size()) } info, err = client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, testFileSize, info.Size()) } err = client.Remove(testFileName) assert.NoError(t, err) _, err = client.Lstat(testFileName) assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize-encryptedFileSize, user.UsedQuotaSize) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestOpenReadWriteCryptoFs(t *testing.T) { // read and write is not supported on crypto fs usePubKey := false u := getTestUserWithCryptFs(usePubKey) u.QuotaSize = 6553600 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("sample test data") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) buffer := make([]byte, 128) _, err = sftpFile.ReadAt(buffer, 1) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") } err = sftpFile.Close() assert.NoError(t, err) } } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestEmptyFile(t *testing.T) { usePubKey := true u := getTestUserWithCryptFs(usePubKey) user, _, err := httpdtest.AddUser(u, 
http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) err = sftpFile.Close() assert.NoError(t, err) } info, err := client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, int64(0), info.Size()) } localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, 0, client) assert.NoError(t, err) encryptedFileSize, err := getEncryptedFileSize(0) assert.NoError(t, err) info, err = os.Stat(filepath.Join(user.HomeDir, testFileName)) if assert.NoError(t, err) { assert.Equal(t, encryptedFileSize, info.Size()) } err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestUploadResumeCryptFs(t *testing.T) { // resuming uploads is not supported usePubKey := true u := getTestUserWithCryptFs(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) appendDataSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = appendToTestFile(testFilePath, appendDataSize) assert.NoError(t, err) err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize, false, client) if assert.Error(t, err) { assert.Contains(t, 
err.Error(), "SSH_FX_OP_UNSUPPORTED") } } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestQuotaFileReplaceCryptFs(t *testing.T) { usePubKey := false u := getTestUserWithCryptFs(usePubKey) u.QuotaFiles = 1000 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) testFileSize := int64(65535) testFilePath := filepath.Join(homeBasePath, testFileName) encryptedFileSize, err := getEncryptedFileSize(testFileSize) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { //nolint:dupl defer conn.Close() defer client.Close() expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) // now replace the same file, the quota must not change err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) // now create a symlink, replace it with a file and check the quota // replacing a symlink is like uploading a new file err = client.Symlink(testFileName, testFileName+".link") //nolint:goconst assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) expectedQuotaFiles = expectedQuotaFiles + 1 expectedQuotaSize = expectedQuotaSize + encryptedFileSize err = sftpUploadFile(testFilePath, testFileName+".link", testFileSize, client) 
assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) } // now set a quota size restriction and upload the same file, upload should fail for space limit exceeded user.QuotaSize = encryptedFileSize*2 - 1 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.Error(t, err, "quota size exceeded, file upload must fail") err = client.Remove(testFileName) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestQuotaScanCryptFs(t *testing.T) { usePubKey := false user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated) assert.NoError(t, err) testFileSize := int64(65535) encryptedFileSize, err := getEncryptedFileSize(testFileSize) assert.NoError(t, err) expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) // create user with the same home dir, so there is at least an untracked file user, _, err = httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) 
assert.NoError(t, err) _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) assert.NoError(t, err) assert.Eventually(t, func() bool { scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) if err == nil { return len(scans) == 0 } return false }, 1*time.Second, 50*time.Millisecond) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestGetMimeTypeCryptFs(t *testing.T) { usePubKey := true user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("some UTF-8 text so we should get a text/plain mime type") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) err = sftpFile.Close() assert.NoError(t, err) } } user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(testPassphrase) fs, err := user.GetFilesystem("connID") if assert.NoError(t, err) { assert.True(t, vfs.IsCryptOsFs(fs)) mime, err := fs.GetMimeType(filepath.Join(user.GetHomeDir(), testFileName)) assert.NoError(t, err) assert.Equal(t, "text/plain; charset=utf-8", mime) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestTruncate(t *testing.T) { // truncate is not supported usePubKey := true user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { 
defer conn.Close() defer client.Close() f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE) if assert.NoError(t, err) { err = f.Truncate(0) assert.NoError(t, err) err = f.Truncate(1) assert.Error(t, err) } err = f.Close() assert.NoError(t, err) err = client.Truncate(testFileName, 0) assert.Error(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSCPBasicHandlingCryptoFs(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true u := getTestUserWithCryptFs(usePubKey) u.QuotaSize = 6553600 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(131074) encryptedFileSize, err := getEncryptedFileSize(testFileSize) assert.NoError(t, err) expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) localPath := filepath.Join(homeBasePath, "scp_download.dat") // test to download a missing file err = scpDownload(localPath, remoteDownPath, false, false) assert.Error(t, err, "downloading a missing file via scp must fail") err = scpUpload(testFilePath, remoteUpPath, false, false) assert.NoError(t, err) err = scpDownload(localPath, remoteDownPath, false, false) assert.NoError(t, err) fi, err := os.Stat(localPath) if assert.NoError(t, err) { assert.Equal(t, testFileSize, fi.Size()) } fi, err = os.Stat(filepath.Join(user.GetHomeDir(), testFileName)) if assert.NoError(t, err) { assert.Equal(t, encryptedFileSize, fi.Size()) } err = os.Remove(localPath) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, 
http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) // now overwrite the existing file err = scpUpload(testFilePath, remoteUpPath, false, false) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } func TestSCPRecursiveCryptFs(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true u := getTestUserWithCryptFs(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testBaseDirName := "atestdir" testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName) testBaseDirDownName := "test_dir_down" //nolint:goconst testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName) testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName) testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName) testFileSize := int64(131074) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFilePath1, testFileSize) assert.NoError(t, err) remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName)) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") err = scpUpload(testBaseDirPath, remoteUpPath, true, false) assert.NoError(t, err) // overwrite existing dir err = scpUpload(testBaseDirPath, remoteUpPath, true, false) assert.NoError(t, err) err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true) assert.NoError(t, err) // test download without passing -r 
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false) assert.Error(t, err, "recursive download without -r must fail") fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName)) if assert.NoError(t, err) { assert.Equal(t, testFileSize, fi.Size()) } fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName)) if assert.NoError(t, err) { assert.Equal(t, testFileSize, fi.Size()) } // upload to a non existent dir remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir") err = scpUpload(testBaseDirPath, remoteUpPath, true, false) assert.Error(t, err, "uploading via scp to a non existent dir must fail") err = os.RemoveAll(testBaseDirPath) assert.NoError(t, err) err = os.RemoveAll(testBaseDirDownPath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func getEncryptedFileSize(size int64) (int64, error) { encSize, err := sio.EncryptedSize(uint64(size)) return int64(encSize) + 33, err } func getTestUserWithCryptFs(usePubKey bool) dataprovider.User { u := getTestUser(usePubKey) u.FsConfig.Provider = sdk.CryptedFilesystemProvider u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(testPassphrase) return u } ================================================ FILE: internal/sftpd/handler.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package sftpd import ( "io" "net" "os" "path" "strings" "time" "github.com/pkg/sftp" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // Connection details for an authenticated user type Connection struct { *common.BaseConnection // client's version string ClientVersion string // Remote address for this connection RemoteAddr net.Addr LocalAddr net.Addr channel io.ReadWriteCloser command string } // GetClientVersion returns the connected client's version func (c *Connection) GetClientVersion() string { return c.ClientVersion } // GetLocalAddress returns local connection address func (c *Connection) GetLocalAddress() string { if c.LocalAddr == nil { return "" } return c.LocalAddr.String() } // GetRemoteAddress returns the connected client's address func (c *Connection) GetRemoteAddress() string { if c.RemoteAddr == nil { return "" } return c.RemoteAddr.String() } // GetCommand returns the SSH command, if any func (c *Connection) GetCommand() string { return c.command } // Fileread creates a reader for a file on the system and returns the reader back. 
func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { c.UpdateLastActivity() updateRequestPaths(request) if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { c.Log(logger.LevelInfo, "denying file read due to transfer count limits") return nil, c.GetPermissionDeniedError() } transferQuota := c.GetTransferQuota() if !transferQuota.HasDownloadSpace() { c.Log(logger.LevelInfo, "denying file read due to quota limits") return nil, c.GetReadQuotaExceededError() } if ok, policy := c.User.IsFileAllowed(request.Filepath); !ok { c.Log(logger.LevelWarn, "reading file %q is not allowed", request.Filepath) return nil, c.GetErrorForDeniedFile(policy) } fs, p, err := c.GetFsAndResolvedPath(request.Filepath) if err != nil { return nil, err } if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreDownload, p, request.Filepath, 0, 0); err != nil { c.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", request.Filepath, err) return nil, c.GetPermissionDeniedError() } file, r, cancelFn, err := fs.Open(p, 0) if err != nil { c.Log(logger.LevelError, "could not open file %q for reading: %+v", p, err) return nil, c.GetFsError(fs, err) } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota) t := newTransfer(baseTransfer, nil, r, nil) return t, nil } // OpenFile implements OpenFileWriter interface func (c *Connection) OpenFile(request *sftp.Request) (sftp.WriterAtReaderAt, error) { return c.handleFilewrite(request) } // Filewrite handles the write actions for a file on the system. 
func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { return c.handleFilewrite(request) } func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) { //nolint:gocyclo c.UpdateLastActivity() updateRequestPaths(request) if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { c.Log(logger.LevelInfo, "denying file write due to transfer count limits") return nil, c.GetPermissionDeniedError() } if ok, _ := c.User.IsFileAllowed(request.Filepath); !ok { c.Log(logger.LevelWarn, "writing file %q is not allowed", request.Filepath) return nil, c.GetPermissionDeniedError() } fs, p, err := c.GetFsAndResolvedPath(request.Filepath) if err != nil { return nil, err } filePath := p if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { filePath = fs.GetAtomicUploadPath(p) } var errForRead error if !vfs.HasOpenRWSupport(fs) && request.Pflags().Read { // read and write mode is only supported for local filesystem errForRead = sftp.ErrSSHFxOpUnsupported } if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) { // we can try to read only for local fs here, see above. 
// os.ErrPermission will become sftp.ErrSSHFxPermissionDenied when sent to // the client errForRead = os.ErrPermission } stat, statErr := fs.Lstat(p) if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } return c.handleSFTPUploadToNewFile(fs, request.Pflags(), p, filePath, request.Filepath, errForRead) } if statErr != nil { c.Log(logger.LevelError, "error performing file stat %q: %+v", p, statErr) return nil, c.GetFsError(fs, statErr) } // This happen if we upload a file that has the same name of an existing directory if stat.IsDir() { c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", p) return nil, sftp.ErrSSHFxOpUnsupported } if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } return c.handleSFTPUploadToExistingFile(fs, request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead) } // Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading // or writing to those files. func (c *Connection) Filecmd(request *sftp.Request) error { c.UpdateLastActivity() updateRequestPaths(request) switch request.Method { case "Setstat": return c.handleSFTPSetstat(request) case "Rename": if err := c.Rename(request.Filepath, request.Target); err != nil { return err } case "Rmdir": return c.RemoveDir(request.Filepath) case "Mkdir": err := c.CreateDir(request.Filepath, true) if err != nil { return err } case "Symlink": if err := c.CreateSymlink(request.Filepath, request.Target); err != nil { return err } case "Remove": return c.handleSFTPRemove(request) default: return sftp.ErrSSHFxOpUnsupported } return sftp.ErrSSHFxOk } // Filelist is the handler for SFTP filesystem list calls. This will handle calls to list the contents of // a directory as well as perform file/folder stat calls. 
func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { c.UpdateLastActivity() updateRequestPaths(request) switch request.Method { case "List": lister, err := c.ListDir(request.Filepath) if err != nil { return nil, err } modTime := time.Unix(0, 0) if request.Filepath != "/" { lister.Prepend(vfs.NewFileInfo("..", true, 0, modTime, false)) } lister.Prepend(vfs.NewFileInfo(".", true, 0, modTime, false)) return lister, nil case "Stat": if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } s, err := c.DoStat(request.Filepath, 0, true) if err != nil { return nil, err } return listerAt([]os.FileInfo{s}), nil default: return nil, sftp.ErrSSHFxOpUnsupported } } // Readlink implements the ReadlinkFileLister interface func (c *Connection) Readlink(filePath string) (string, error) { filePath = util.CleanPath(filePath) if err := c.canReadLink(filePath); err != nil { return "", err } fs, p, err := c.GetFsAndResolvedPath(filePath) if err != nil { return "", err } s, err := fs.Readlink(p) if err != nil { c.Log(logger.LevelDebug, "error running readlink on path %q: %+v", p, err) return "", c.GetFsError(fs, err) } if err := c.canReadLink(s); err != nil { return "", err } return s, nil } // Lstat implements LstatFileLister interface func (c *Connection) Lstat(request *sftp.Request) (sftp.ListerAt, error) { c.UpdateLastActivity() updateRequestPaths(request) if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } s, err := c.DoStat(request.Filepath, 1, true) if err != nil { return nil, err } return listerAt([]os.FileInfo{s}), nil } // RealPath implements the RealPathFileLister interface func (c *Connection) RealPath(p string) (string, error) { if c.User.Filters.StartDirectory == "" { p = util.CleanPath(p) } else { p = util.CleanPathWithBase(c.User.Filters.StartDirectory, p) } if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(p)) 
{ return "", sftp.ErrSSHFxPermissionDenied } fs, fsPath, err := c.GetFsAndResolvedPath(p) if err != nil { return "", err } if realPather, ok := fs.(vfs.FsRealPather); ok { realPath, err := realPather.RealPath(fsPath) if err != nil { return "", c.GetFsError(fs, err) } return realPath, nil } return p, nil } // StatVFS implements StatVFSFileCmder interface func (c *Connection) StatVFS(r *sftp.Request) (*sftp.StatVFS, error) { c.UpdateLastActivity() updateRequestPaths(r) // we are assuming that r.Filepath is a dir, this could be wrong but should // not produce any side effect here. // we don't consider c.User.Filters.MaxUploadFileSize, we return disk stats here // not the limit for a single file upload quotaResult, _ := c.HasSpace(true, true, path.Join(r.Filepath, "fakefile.txt")) fs, p, err := c.GetFsAndResolvedPath(r.Filepath) if err != nil { return nil, err } if !quotaResult.HasSpace { return c.getStatVFSFromQuotaResult(fs, p, quotaResult) } if quotaResult.QuotaSize == 0 && quotaResult.QuotaFiles == 0 { // no quota restrictions statvfs, err := fs.GetAvailableDiskSize(p) if err == vfs.ErrStorageSizeUnavailable { return c.getStatVFSFromQuotaResult(fs, p, quotaResult) } return statvfs, err } // there is free space but some limits are configured return c.getStatVFSFromQuotaResult(fs, p, quotaResult) } func (c *Connection) canReadLink(name string) error { if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) { return sftp.ErrSSHFxPermissionDenied } ok, policy := c.User.IsFileAllowed(name) if !ok && policy == sdk.DenyPolicyHide { return sftp.ErrSSHFxNoSuchFile } return nil } func (c *Connection) handleSFTPSetstat(request *sftp.Request) error { attrs := common.StatAttributes{ Flags: 0, } if request.Attributes() != nil { if request.AttrFlags().Permissions { attrs.Flags |= common.StatAttrPerms attrs.Mode = request.Attributes().FileMode() } if request.AttrFlags().UidGid { attrs.Flags |= common.StatAttrUIDGID attrs.UID = int(request.Attributes().UID) attrs.GID = 
int(request.Attributes().GID) } if request.AttrFlags().Acmodtime { attrs.Flags |= common.StatAttrTimes attrs.Atime = time.Unix(int64(request.Attributes().Atime), 0) attrs.Mtime = time.Unix(int64(request.Attributes().Mtime), 0) } if request.AttrFlags().Size { attrs.Flags |= common.StatAttrSize attrs.Size = int64(request.Attributes().Size) } } return c.SetStat(request.Filepath, &attrs) } func (c *Connection) handleSFTPRemove(request *sftp.Request) error { fs, fsPath, err := c.GetFsAndResolvedPath(request.Filepath) if err != nil { return err } var fi os.FileInfo if fi, err = fs.Lstat(fsPath); err != nil { c.Log(logger.LevelDebug, "failed to remove file %q: stat error: %+v", fsPath, err) return c.GetFsError(fs, err) } if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { c.Log(logger.LevelDebug, "cannot remove %q is not a file/symlink", fsPath) return sftp.ErrSSHFxFailure } return c.RemoveFile(fs, fsPath, request.Filepath, fi) } func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { diskQuota, transferQuota := c.HasSpace(true, false, requestPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, c.GetQuotaExceededError() } if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, 0, 0); err != nil { c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) return nil, c.GetPermissionDeniedError() } osFlags := getOSOpenFlags(pflags) file, w, cancelFn, err := fs.Create(filePath, osFlags, c.GetCreateChecks(requestPath, true, false)) if err != nil { c.Log(logger.LevelError, "error creating file %q, os flags %d, pflags %+v: %+v", resolvedPath, osFlags, pflags, err) return nil, c.GetFsError(fs, err) } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an 
error only for resume maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil } func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath string, fileSize int64, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { var err error diskQuota, transferQuota := c.HasSpace(false, false, requestPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, c.GetQuotaExceededError() } osFlags := getOSOpenFlags(pflags) minWriteOffset := int64(0) isTruncate := osFlags&os.O_TRUNC != 0 // for upload resumes OpenSSH sets the APPEND flag while WinSCP does not set it, // so we suppose this is an upload resume if the TRUNCATE flag is not set isResume := !isTruncate // if there is a size limit the remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before. // For Cloud FS GetMaxWriteSize will return unsupported operation maxWriteSize, err := c.GetMaxWriteSize(diskQuota, isResume, fileSize, vfs.IsUploadResumeSupported(fs, fileSize)) if err != nil { c.Log(logger.LevelDebug, "unable to get max write size for file %q is resume? 
%t: %v", requestPath, isResume, err) return nil, err } if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, osFlags); err != nil { c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) return nil, c.GetPermissionDeniedError() } if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { _, _, err = fs.Rename(resolvedPath, filePath, 0) if err != nil { c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v", resolvedPath, filePath, err) return nil, c.GetFsError(fs, err) } } file, w, cancelFn, err := fs.Create(filePath, osFlags, c.GetCreateChecks(requestPath, false, isResume)) if err != nil { c.Log(logger.LevelError, "error opening existing file, os flags %v, pflags: %+v, source: %q, err: %+v", osFlags, pflags, filePath, err) return nil, c.GetFsError(fs, err) } initialSize := int64(0) truncatedSize := int64(0) // bytes truncated and not included in quota if isResume { c.Log(logger.LevelDebug, "resuming upload requested, file path %q initial size: %d, has append flag %t", filePath, fileSize, pflags.Append) // enforce min write offset only if the client passed the APPEND flag or the filesystem // supports emulated resume if pflags.Append || !fs.IsUploadResumeSupported() { minWriteOffset = fileSize } initialSize = fileSize } else { if isTruncate && vfs.HasTruncateSupport(fs) { c.updateQuotaAfterTruncate(requestPath, fileSize) } else { initialSize = fileSize truncatedSize = fileSize } } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil } // Disconnect disconnects the client by closing the channel func (c 
*Connection) Disconnect() error { if c.channel == nil { c.Log(logger.LevelWarn, "cannot disconnect a nil channel") return nil } return c.channel.Close() } func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResult vfs.QuotaCheckResult) (*sftp.StatVFS, error) { s, err := fs.GetAvailableDiskSize(name) if err == nil { if quotaResult.QuotaSize == 0 || quotaResult.QuotaSize > int64(s.TotalSpace()) { quotaResult.QuotaSize = int64(s.TotalSpace()) } if quotaResult.QuotaFiles == 0 || quotaResult.QuotaFiles > int(s.Files) { quotaResult.QuotaFiles = int(s.Files) } } else if err != vfs.ErrStorageSizeUnavailable { return nil, err } // if we are unable to get quota size or quota files we add some arbitrary values if quotaResult.QuotaSize == 0 { quotaResult.QuotaSize = quotaResult.UsedSize + 8*1024*1024*1024*1024 // 8TB } if quotaResult.QuotaFiles == 0 { quotaResult.QuotaFiles = quotaResult.UsedFiles + 1000000 // 1 million } bsize := uint64(4096) for bsize > uint64(quotaResult.QuotaSize) { bsize /= 4 } blocks := uint64(quotaResult.QuotaSize) / bsize bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize files := uint64(quotaResult.QuotaFiles) ffree := uint64(quotaResult.QuotaFiles - quotaResult.UsedFiles) if !quotaResult.HasSpace { bfree = 0 ffree = 0 } return &sftp.StatVFS{ Bsize: bsize, Frsize: bsize, Blocks: blocks, Bfree: bfree, Bavail: bfree, Files: files, Ffree: ffree, Favail: ffree, Namemax: 255, }, nil } func (c *Connection) updateQuotaAfterTruncate(requestPath string, fileSize int64) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false) return } dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) { var osFlags int if requestFlags.Read && requestFlags.Write { osFlags |= os.O_RDWR } else if requestFlags.Write { osFlags |= os.O_WRONLY } // 
we ignore Append flag since pkg/sftp use WriteAt that cannot work with os.O_APPEND /*if requestFlags.Append { osFlags |= os.O_APPEND }*/ if requestFlags.Creat { osFlags |= os.O_CREATE } if requestFlags.Trunc { osFlags |= os.O_TRUNC } if requestFlags.Excl { osFlags |= os.O_EXCL } return osFlags } func updateRequestPaths(request *sftp.Request) { if request.Method == "Symlink" { request.Filepath = path.Clean(strings.ReplaceAll(request.Filepath, "\\", "/")) } else { request.Filepath = util.CleanPath(request.Filepath) } if request.Target != "" { request.Target = util.CleanPath(request.Target) } } ================================================ FILE: internal/sftpd/httpfs_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package sftpd_test

import (
	"fmt"
	"io/fs"
	"math"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"testing"
	"time"

	"github.com/sftpgo/sdk"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/httpdtest"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	// TCP port of the HTTPFs test server started by startHTTPFs
	httpFsPort = 12345
	// username set in the HTTPFs configuration of the test users and folders
	defaultHTTPFsUsername = "httpfs_user"
)

var (
	// UNIX socket of the second HTTPFs test server, started only on non-Windows systems
	httpFsSocketPath = filepath.Join(os.TempDir(), "httpfs.sock")
)

// TestBasicHTTPFsHandling exercises common SFTP operations (upload, stat,
// readdir, mkdir, rename, chmod, chtimes, statvfs, download, remove, truncate)
// for a user backed by the HTTPFs test server. It also checks that the used
// quota follows each change, including after quota scans.
func TestBasicHTTPFsHandling(t *testing.T) {
	usePubKey := true
	u := getTestUserWithHTTPFs(usePubKey)
	u.QuotaSize = 6553600
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		// the test file is uploaded successfully twice: to the root and into dirName
		expectedQuotaSize := user.UsedQuotaSize + testFileSize*2
		expectedQuotaFiles := user.UsedQuotaFiles + 2
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// uploading into a missing directory must fail
		err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client)
		assert.Error(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		info, err := client.Stat(testFileName)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, info.Size())
		}
		contents, err := client.ReadDir("/")
		assert.NoError(t, err)
		if assert.Len(t, contents, 1) {
			assert.Equal(t, testFileName, contents[0].Name())
		}
		dirName := "test dirname"
		err = client.Mkdir(dirName)
		assert.NoError(t, err)
		contents, err = client.ReadDir(".")
		assert.NoError(t, err)
		assert.Len(t, contents, 2)
		contents, err = client.ReadDir(dirName)
		assert.NoError(t, err)
		assert.Len(t, contents, 0)
		err = sftpUploadFile(testFilePath, path.Join(dirName, testFileName), testFileSize, client)
		assert.NoError(t, err)
		contents, err = client.ReadDir(dirName)
		assert.NoError(t, err)
		assert.Len(t, contents, 1)
		dirRenamed := dirName + "_renamed"
		err = client.Rename(dirName, dirRenamed)
		assert.NoError(t, err)
		info, err = client.Stat(dirRenamed)
		if assert.NoError(t, err) {
			assert.True(t, info.IsDir())
		}
		// modes 0666 and 0444 work on Windows too
		newPerm := os.FileMode(0444)
		err = client.Chmod(testFileName, newPerm)
		assert.NoError(t, err)
		info, err = client.Stat(testFileName)
		assert.NoError(t, err)
		assert.Equal(t, newPerm, info.Mode().Perm())
		newPerm = os.FileMode(0666)
		err = client.Chmod(testFileName, newPerm)
		assert.NoError(t, err)
		info, err = client.Stat(testFileName)
		assert.NoError(t, err)
		assert.Equal(t, newPerm, info.Mode().Perm())
		// chtimes: the modification time must round-trip within one second
		acmodTime := time.Now().Add(-36 * time.Hour)
		err = client.Chtimes(testFileName, acmodTime, acmodTime)
		assert.NoError(t, err)
		info, err = client.Stat(testFileName)
		if assert.NoError(t, err) {
			diff := math.Abs(info.ModTime().Sub(acmodTime).Seconds())
			assert.LessOrEqual(t, diff, float64(1))
		}
		_, err = client.StatVFS("/")
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		// execute a quota scan and wait until no scan is reported as running
		_, err = httpdtest.StartQuotaScan(user, http.StatusAccepted)
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			scans, _, err := httpdtest.GetQuotaScans(http.StatusOK)
			if err == nil {
				return len(scans) == 0
			}
			return false
		}, 1*time.Second, 50*time.Millisecond)
		// removing the root copy must release one file and testFileSize bytes
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		_, err = client.Lstat(testFileName)
		assert.Error(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize)
		// truncate the copy inside the renamed directory to 100 bytes; the
		// assertions below imply the user started with no used size
		err = client.Truncate(path.Join(dirRenamed, testFileName), 100)
		assert.NoError(t, err)
		info, err = client.Stat(path.Join(dirRenamed, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, int64(100), info.Size())
		}
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles)
		assert.Equal(t, int64(100), user.UsedQuotaSize)
		// update quota: a new scan must agree with the incrementally tracked values
		_, err = httpdtest.StartQuotaScan(user, http.StatusAccepted)
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			scans, _, err := httpdtest.GetQuotaScans(http.StatusOK)
			if err == nil {
				return len(scans) == 0
			}
			return false
		}, 1*time.Second, 50*time.Millisecond)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles)
		assert.Equal(t, int64(100), user.UsedQuotaSize)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHTTPFsVirtualFolder maps an HTTPFs-backed virtual folder into a local
// user and checks upload, stat and download through that folder.
func TestHTTPFsVirtualFolder(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	folderName := "httpfsfolder"
	vdirPath := "/vdir/http fs"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
	})
	f := vfs.BaseVirtualFolder{
		Name: folderName,
		FsConfig: vfs.Filesystem{
			Provider: sdk.HTTPFilesystemProvider,
			HTTPConfig: vfs.HTTPFsConfig{
				BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
					Endpoint:          fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort),
					Username:          defaultHTTPFsUsername,
					EqualityCheckMode: 1,
				},
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client)
		assert.NoError(t, err)
		_, err = client.Stat(path.Join(vdirPath, testFileName))
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
}

// TestHTTPFsWalk checks HTTPFs Walk on an empty root and then on a tree of
// 10 files plus 10 directories with 5 files each (1 + 10 + 10 + 50 = 71 paths).
// The tree is created directly on local disk under basePath.
// NOTE(review): this assumes the HTTPFs test server serves this user from
// os.TempDir()/httpfs/<username>; confirm against httpdtest.
func TestHTTPFsWalk(t *testing.T) {
	user := getTestUserWithHTTPFs(false)
	user.FsConfig.HTTPConfig.EqualityCheckMode = 1
	httpFs, err := user.GetFilesystem("")
	require.NoError(t, err)
	basePath := filepath.Join(os.TempDir(), "httpfs", user.FsConfig.HTTPConfig.Username)
	err = os.RemoveAll(basePath)
	assert.NoError(t, err)
	var walkedPaths []string
	err = httpFs.Walk("/", func(walkedPath string, _ fs.FileInfo, err error) error {
		if err != nil {
			return err
		}
		walkedPaths = append(walkedPaths, httpFs.GetRelativePath(walkedPath))
		return nil
	})
	require.NoError(t, err)
	require.Len(t, walkedPaths, 1)
	require.Contains(t, walkedPaths, "/")
	// now add some files/folders
	for i := 0; i < 10; i++ {
		err = os.WriteFile(filepath.Join(basePath, fmt.Sprintf("file%d", i)), nil, os.ModePerm)
		assert.NoError(t, err)
		err = os.Mkdir(filepath.Join(basePath, fmt.Sprintf("dir%d", i)), os.ModePerm)
		assert.NoError(t, err)
		for j := 0; j < 5; j++ {
			err = os.WriteFile(filepath.Join(basePath, fmt.Sprintf("dir%d", i), fmt.Sprintf("subfile%d", j)), nil, os.ModePerm)
			assert.NoError(t, err)
		}
	}
	walkedPaths = nil
	err = httpFs.Walk("/", func(walkedPath string, _ fs.FileInfo, err error) error {
		if err != nil {
			return err
		}
		walkedPaths = append(walkedPaths, httpFs.GetRelativePath(walkedPath))
		return nil
	})
	require.NoError(t, err)
	require.Len(t, walkedPaths, 71)
	require.Contains(t, walkedPaths, "/")
	for i := 0; i < 10; i++ {
		require.Contains(t, walkedPaths, path.Join("/", fmt.Sprintf("file%d", i)))
		require.Contains(t, walkedPaths, path.Join("/", fmt.Sprintf("dir%d", i)))
		for j := 0; j < 5; j++ {
			require.Contains(t, walkedPaths, path.Join("/", fmt.Sprintf("dir%d", i), fmt.Sprintf("subfile%d", j)))
		}
	}
	err = os.RemoveAll(basePath)
	assert.NoError(t, err)
}

// TestHTTPFsOverUNIXSocket runs basic SFTP operations for a user whose HTTPFs
// endpoint is reached through the UNIX socket test server (skipped on Windows).
func TestHTTPFsOverUNIXSocket(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("UNIX domain sockets are not supported on Windows")
	}
	// the socket server is started asynchronously: wait for the socket file
	assert.Eventually(t, func() bool {
		_, err := os.Stat(httpFsSocketPath)
		return err == nil
	}, 1*time.Second, 50*time.Millisecond)
	usePubKey := true
	u := getTestUserWithHTTPFs(usePubKey)
	u.FsConfig.HTTPConfig.Endpoint = fmt.Sprintf("http://unix?socket_path=%s&api_prefix=%s", url.QueryEscape(httpFsSocketPath),
		url.QueryEscape("/api/v1"))
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = client.Mkdir(testFileName)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testFileName)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// getTestUserWithHTTPFs returns a test user whose filesystem is the HTTPFs
// test server reached over TCP on httpFsPort.
func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User {
	u := getTestUser(usePubKey)
	u.FsConfig.Provider = sdk.HTTPFilesystemProvider
	u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
			Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort),
			Username: defaultHTTPFsUsername,
		},
	}
	return u
}

// startHTTPFs starts the HTTPFs test servers in background goroutines: one on
// httpFsSocketPath (non-Windows only) and one over TCP on httpFsPort. The
// process exits if either server fails to start. It then calls
// waitTCPListening for the TCP port (presumably blocking until it accepts
// connections; confirm against the helper).
func startHTTPFs() {
	if runtime.GOOS != osWindows {
		go func() {
			if err := httpdtest.StartTestHTTPFsOverUnixSocket(httpFsSocketPath); err != nil {
				logger.ErrorToConsole("could not start HTTPfs test server over UNIX socket: %v", err)
				os.Exit(1)
			}
		}()
	}
	go func() {
		if err := httpdtest.StartTestHTTPFs(httpFsPort, nil); err != nil {
			logger.ErrorToConsole("could not start HTTPfs test server: %v", err)
			os.Exit(1)
		}
	}()
	waitTCPListening(fmt.Sprintf(":%d", httpFsPort))
}

================================================
FILE: internal/sftpd/internal_test.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
package sftpd import ( "bytes" "context" "errors" "fmt" "io" "io/fs" "net" "os" "path/filepath" "runtime" "slices" "testing" "time" "github.com/eikenb/pipeat" "github.com/pkg/sftp" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( osWindows = "windows" ) var ( configDir = filepath.Join(".", "..", "..") ) type MockChannel struct { Buffer *bytes.Buffer StdErrBuffer *bytes.Buffer ReadError error WriteError error ShortWriteErr bool } func (c *MockChannel) Read(data []byte) (int, error) { if c.ReadError != nil { return 0, c.ReadError } return c.Buffer.Read(data) } func (c *MockChannel) Write(data []byte) (int, error) { if c.WriteError != nil { return 0, c.WriteError } if c.ShortWriteErr { return 0, nil } return c.Buffer.Write(data) } func (c *MockChannel) Close() error { return nil } func (c *MockChannel) CloseWrite() error { return nil } func (c *MockChannel) SendRequest(_ string, _ bool, _ []byte) (bool, error) { return true, nil } func (c *MockChannel) Stderr() io.ReadWriter { return c.StdErrBuffer } // MockOsFs mockable OsFs type MockOsFs struct { vfs.Fs err error statErr error isAtomicUploadSupported bool } // Name returns the name for the Fs implementation func (fs MockOsFs) Name() string { return "mockOsFs" } // IsUploadResumeSupported returns true if resuming uploads is supported func (MockOsFs) IsUploadResumeSupported() bool { return false } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (MockOsFs) IsConditionalUploadResumeSupported(_ int64) bool { return false } // IsAtomicUploadSupported returns true if atomic upload is supported func (fs MockOsFs) 
IsAtomicUploadSupported() bool { return fs.isAtomicUploadSupported } // Stat returns a FileInfo describing the named file func (fs MockOsFs) Stat(name string) (os.FileInfo, error) { if fs.statErr != nil { return nil, fs.statErr } return os.Stat(name) } // Lstat returns a FileInfo describing the named file func (fs MockOsFs) Lstat(name string) (os.FileInfo, error) { if fs.statErr != nil { return nil, fs.statErr } return os.Lstat(name) } // Remove removes the named file or (empty) directory. func (fs MockOsFs) Remove(name string, _ bool) error { if fs.err != nil { return fs.err } return os.Remove(name) } // Rename renames (moves) source to target func (fs MockOsFs) Rename(source, target string, _ int) (int, int64, error) { if fs.err != nil { return -1, -1, fs.err } err := os.Rename(source, target) return -1, -1, err } func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs { return &MockOsFs{ Fs: vfs.NewOsFs(connectionID, rootDir, "", nil), err: err, statErr: statErr, isAtomicUploadSupported: atomicUpload, } } func TestRemoveNonexistentQuotaScan(t *testing.T) { assert.False(t, common.QuotaScans.RemoveUserQuotaScan("username")) } func TestGetOSOpenFlags(t *testing.T) { var flags sftp.FileOpenFlags flags.Write = true flags.Excl = true osFlags := getOSOpenFlags(flags) assert.NotEqual(t, 0, osFlags&os.O_WRONLY) assert.NotEqual(t, 0, osFlags&os.O_EXCL) flags.Append = true // append flag should be ignored to allow resume assert.NotEqual(t, 0, osFlags&os.O_WRONLY) assert.NotEqual(t, 0, osFlags&os.O_EXCL) } func TestUploadResumeInvalidOffset(t *testing.T) { testfile := "testfile" //nolint:goconst file, err := os.Create(testfile) assert.NoError(t, err) user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "testuser", }, } fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, 
common.TransferUpload, 10, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "upload with invalid offset must fail") if assert.Error(t, transfer.ErrTransfer) { assert.EqualError(t, err, transfer.ErrTransfer.Error()) assert.Contains(t, transfer.ErrTransfer.Error(), "invalid write offset") } err = transfer.Close() if assert.Error(t, err) { assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) } err = os.Remove(testfile) assert.NoError(t, err) } func TestReadWriteErrors(t *testing.T) { testfile := "testfile" file, err := os.Create(testfile) assert.NoError(t, err) user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "testuser", }, } fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) err = file.Close() assert.NoError(t, err) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "writing to closed file must fail") buf := make([]byte, 32768) _, err = transfer.ReadAt(buf, 0) assert.Error(t, err, "reading from a closed file must fail") err = transfer.Close() assert.Error(t, err) r, _, err := pipeat.Pipe() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer = newTransfer(baseTransfer, nil, vfs.NewPipeReader(r), nil) err = transfer.Close() assert.NoError(t, err) _, err = transfer.ReadAt(buf, 0) assert.Error(t, err, "reading from a closed pipe must fail") r, w, err := pipeat.Pipe() assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), 
testfile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer = newTransfer(baseTransfer, pipeWriter, nil, nil) err = r.Close() assert.NoError(t, err) errFake := fmt.Errorf("fake upload error") go func() { time.Sleep(100 * time.Millisecond) pipeWriter.Done(errFake) }() err = transfer.closeIO() assert.EqualError(t, err, errFake.Error()) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "writing to closed pipe must fail") err = transfer.BaseTransfer.Close() assert.EqualError(t, err, errFake.Error()) err = os.Remove(testfile) assert.NoError(t, err) assert.Len(t, conn.GetTransfers(), 0) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestUnsupportedListOP(t *testing.T) { conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", dataprovider.User{}) sftpConn := Connection{ BaseConnection: conn, } request := sftp.NewRequest("Unsupported", "") _, err := sftpConn.Filelist(request) assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) } func TestTransferCancelFn(t *testing.T) { testfile := "testfile" file, err := os.Create(testfile) assert.NoError(t, err) isCancelled := false cancelFn := func() { isCancelled = true } user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "testuser", }, } fs := vfs.NewOsFs("", os.TempDir(), "", nil) conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) errFake := errors.New("fake error, this will trigger cancelFn") transfer.TransferError(errFake) err = transfer.Close() if assert.Error(t, err) { assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) } if assert.Error(t, transfer.ErrTransfer) { assert.EqualError(t, transfer.ErrTransfer, errFake.Error()) } assert.True(t, isCancelled, "cancelFn not called!") err = 
os.Remove(testfile) assert.NoError(t, err) } func TestUploadFiles(t *testing.T) { common.Config.UploadMode = common.UploadModeAtomic fs := vfs.NewOsFs("123", os.TempDir(), "", nil) u := dataprovider.User{} c := Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), } var flags sftp.FileOpenFlags flags.Write = true flags.Trunc = true _, err := c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid") common.Config.UploadMode = common.UploadModeStandard _, err = c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid") missingFile := "missing/relative/file.txt" if runtime.GOOS == osWindows { missingFile = "missing\\relative\\file.txt" } _, err = c.handleSFTPUploadToNewFile(fs, flags, ".", missingFile, "/missing", nil) assert.Error(t, err, "upload new file in missing path must fail") fs = newMockOsFs(nil, nil, false, "123", os.TempDir()) f, err := os.CreateTemp("", "temp") assert.NoError(t, err) err = f.Close() assert.NoError(t, err) tr, err := c.handleSFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name(), nil) if assert.NoError(t, err) { transfer := tr.(*transfer) transfers := c.GetTransfers() if assert.Equal(t, 1, len(transfers)) { assert.Equal(t, transfers[0].ID, transfer.GetID()) assert.Equal(t, int64(123), transfer.InitialSize) err = transfer.Close() assert.NoError(t, err) assert.Equal(t, 0, len(c.GetTransfers())) } } err = os.Remove(f.Name()) assert.NoError(t, err) common.Config.UploadMode = common.UploadModeAtomicWithResume } func TestWithInvalidHome(t *testing.T) { u := dataprovider.User{} u.HomeDir = "home_rel_path" //nolint:goconst _, err := loginUser(&u, dataprovider.LoginMethodPassword, "", nil) assert.Error(t, err, "login a user 
with an invalid home_dir must fail") u.HomeDir = os.TempDir() fs, err := u.GetFilesystem("123") assert.NoError(t, err) c := Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), } resolved, err := fs.ResolvePath("../upper_path") assert.NoError(t, err) assert.Equal(t, filepath.Join(u.HomeDir, "upper_path"), resolved) _, err = c.StatVFS(&sftp.Request{ Method: "StatVFS", Filepath: "../unresolvable-path", }) assert.Error(t, err) } func TestResolveWithRootDir(t *testing.T) { u := dataprovider.User{} if runtime.GOOS == osWindows { u.HomeDir = "C:\\" } else { u.HomeDir = "/" } fs, err := u.GetFilesystem("") assert.NoError(t, err) rel, err := filepath.Rel(u.HomeDir, os.TempDir()) assert.NoError(t, err) p, err := fs.ResolvePath(rel) assert.NoError(t, err, "path %v", p) } func TestSFTPGetUsedQuota(t *testing.T) { u := dataprovider.User{} u.HomeDir = "home_rel_path" u.Username = "test_invalid_user" u.QuotaSize = 4096 u.QuotaFiles = 1 u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} connection := Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), } quotaResult, _ := connection.HasSpace(false, false, "/") assert.False(t, quotaResult.HasSpace) } func TestSupportedSSHCommands(t *testing.T) { cmds := GetSupportedSSHCommands() assert.Equal(t, len(supportedSSHCommands), len(cmds)) for _, c := range cmds { assert.True(t, slices.Contains(supportedSSHCommands, c)) } } func TestSSHCommandPath(t *testing.T) { buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), ReadError: nil, } connection := &Connection{ channel: &mockSSHChannel, BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, "", "", dataprovider.User{}), } sshCommand := sshCommand{ command: "test", connection: connection, args: []string{}, } assert.Equal(t, "", sshCommand.getDestPath()) 
sshCommand.args = []string{"-t", "/tmp/../path"} assert.Equal(t, "/path", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "/tmp/"} assert.Equal(t, "/tmp/", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "tmp/"} assert.Equal(t, "/tmp/", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "/tmp/../../../path"} assert.Equal(t, "/path", sshCommand.getDestPath()) sshCommand.args = []string{"-t", ".."} assert.Equal(t, "/", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "."} assert.Equal(t, "/", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "//"} assert.Equal(t, "/", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "../.."} assert.Equal(t, "/", sshCommand.getDestPath()) sshCommand.args = []string{"-t", "/.."} assert.Equal(t, "/", sshCommand.getDestPath()) sshCommand.args = []string{"-f", "/a space.txt"} assert.Equal(t, "/a space.txt", sshCommand.getDestPath()) } func TestSSHParseCommandPayload(t *testing.T) { cmd := "command -a -f /ab\\ à/some\\ spaces\\ \\ \\(\\).txt" name, args, _ := parseCommandPayload(cmd) assert.Equal(t, "command", name) assert.Equal(t, 3, len(args)) assert.Equal(t, "/ab à/some spaces ().txt", args[2]) _, _, err := parseCommandPayload("") assert.Error(t, err, "parsing invalid command must fail") } func TestSSHCommandErrors(t *testing.T) { buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), ReadError: readErr, } server, client := net.Pipe() defer func() { err := server.Close() assert.NoError(t, err) }() defer func() { err := client.Close() assert.NoError(t, err) }() user := dataprovider.User{} user.Permissions = map[string][]string{ "/": {dataprovider.PermAny}, } connection := Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, "", "", user), channel: &mockSSHChannel, } cmd := sshCommand{ command: "md5sum", 
connection: &connection, args: []string{}, } err := cmd.handle() assert.Error(t, err, "ssh command must fail, we are sending a fake error") cmd = sshCommand{ command: "md5sum", connection: &connection, args: []string{"/../../test_file_ftp.dat"}, } err = cmd.handle() assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") user = dataprovider.User{} user.Permissions = map[string][]string{ "/": {dataprovider.PermAny}, } user.HomeDir = filepath.Clean(os.TempDir()) user.QuotaFiles = 1 user.UsedQuotaFiles = 2 cmd.connection.User = user _, err = cmd.connection.User.GetFilesystem("123") assert.NoError(t, err) cmd = sshCommand{ command: "sftpgo-remove", connection: &connection, args: []string{"/../../src"}, } err = cmd.handle() assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") cmd = sshCommand{ command: "sftpgo-copy", connection: &connection, args: []string{"/../../test_src", "."}, } err = cmd.handle() assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") err = common.Initialize(common.Config, 0) assert.NoError(t, err) } func TestCommandsWithExtensionsFilter(t *testing.T) { buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), } server, client := net.Pipe() defer server.Close() defer client.Close() user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: "test", HomeDir: os.TempDir(), Status: 1, }, } user.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: "/subdir", AllowedPatterns: []string{".jpg"}, DeniedPatterns: []string{}, }, } connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, "", "", user), channel: &mockSSHChannel, } cmd := sshCommand{ command: "md5sum", connection: connection, args: []string{"subdir/test.png"}, } err := cmd.handleHashCommands() assert.EqualError(t, err, common.ErrPermissionDenied.Error()) } func 
TestSSHCommandsRemoteFs(t *testing.T) { buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), } user := dataprovider.User{} user.FsConfig = vfs.Filesystem{ Provider: sdk.S3FilesystemProvider, S3Config: vfs.S3FsConfig{ BaseS3FsConfig: sdk.BaseS3FsConfig{ Bucket: "s3bucket", Endpoint: "endpoint", Region: "eu-west-1", }, }, } connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user), channel: &mockSSHChannel, } cmd := sshCommand{ command: "md5sum", connection: connection, args: []string{}, } err := cmd.handleSFTPGoCopy() assert.Error(t, err) cmd = sshCommand{ command: "sftpgo-remove", connection: connection, args: []string{}, } err = cmd.handleSFTPGoRemove() assert.Error(t, err) } func TestSSHCmdGetFsErrors(t *testing.T) { buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), } user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: "relative path", }, } user.Permissions = map[string][]string{} user.Permissions["/"] = []string{dataprovider.PermAny} connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user), channel: &mockSSHChannel, } cmd := sshCommand{ command: "sftpgo-remove", connection: connection, args: []string{"path"}, } err := cmd.handleSFTPGoRemove() assert.Error(t, err) cmd = sshCommand{ command: "sftpgo-copy", connection: connection, args: []string{"path1", "path2"}, } err = cmd.handleSFTPGoCopy() assert.Error(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestCommandGetFsError(t *testing.T) { user := dataprovider.User{ FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, }, } buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), 
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
	}
	conn := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", user),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: conn,
			args: []string{"-t", "/tmp"},
		},
	}
	err := scpCommand.handleRecursiveUpload()
	assert.Error(t, err)
	err = scpCommand.handleDownload("")
	assert.Error(t, err)
}

// TestSCPFileMode checks the octal strings built by getFileModeAsString:
// a zero mode falls back to "0755" for directories and "0644" for files,
// permission bits are rendered as-is, and the setuid/setgid/sticky bits
// populate the leading digit.
// NOTE(review): POSIX assigns 04000 to setuid and 02000 to setgid, yet the
// expectations below pin os.ModeSetgid -> "4644" and os.ModeSetuid -> "2600";
// confirm whether getFileModeAsString intentionally swaps the two bits.
func TestSCPFileMode(t *testing.T) {
	mode := getFileModeAsString(0, true)
	assert.Equal(t, "0755", mode)
	mode = getFileModeAsString(0700, true)
	assert.Equal(t, "0700", mode)
	mode = getFileModeAsString(0750, true)
	assert.Equal(t, "0750", mode)
	mode = getFileModeAsString(0777, true)
	assert.Equal(t, "0777", mode)
	mode = getFileModeAsString(0640, false)
	assert.Equal(t, "0640", mode)
	mode = getFileModeAsString(0600, false)
	assert.Equal(t, "0600", mode)
	mode = getFileModeAsString(0, false)
	assert.Equal(t, "0644", mode)
	// all three special bits together
	fileMode := uint32(0777)
	fileMode = fileMode | uint32(os.ModeSetgid)
	fileMode = fileMode | uint32(os.ModeSetuid)
	fileMode = fileMode | uint32(os.ModeSticky)
	mode = getFileModeAsString(os.FileMode(fileMode), false)
	assert.Equal(t, "7777", mode)
	fileMode = uint32(0644)
	fileMode = fileMode | uint32(os.ModeSetgid)
	mode = getFileModeAsString(os.FileMode(fileMode), false)
	assert.Equal(t, "4644", mode)
	fileMode = uint32(0600)
	fileMode = fileMode | uint32(os.ModeSetuid)
	mode = getFileModeAsString(os.FileMode(fileMode), false)
	assert.Equal(t, "2600", mode)
	fileMode = uint32(0044)
	fileMode = fileMode | uint32(os.ModeSticky)
	mode = getFileModeAsString(os.FileMode(fileMode), false)
	assert.Equal(t, "1044", mode)
}

// TestSCPUploadError drives the upload ("-t") path against a mock channel
// whose writes fail, then feeds a directory message with a non-numeric size.
// connection.channel holds &mockSSHChannel, so each reassignment of the
// mockSSHChannel variable below changes what the command reads and writes.
func TestSCPUploadError(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: writeErr,
	}
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			HomeDir: filepath.Join(os.TempDir()),
			Permissions: make(map[string][]string),
		},
	}
	user.Permissions["/"] = []string{dataprovider.PermAny}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-t", "/"},
		},
	}
	err := scpCommand.handle()
	assert.EqualError(t, err, writeErr.Error())
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer([]byte("D0755 0 testdir\n")),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: writeErr,
	}
	err = scpCommand.handleRecursiveUpload()
	assert.EqualError(t, err, writeErr.Error())
	// "a" is not a valid size in the directory message
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer([]byte("D0755 a testdir\n")),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}
	err = scpCommand.handleRecursiveUpload()
	assert.Error(t, err)
}

// TestSCPInvalidEndDir checks that an "E\n" message received as the very
// first upload message is rejected with "unacceptable end dir command".
func TestSCPInvalidEndDir(t *testing.T) {
	stdErrBuf := make([]byte, 65535)
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer([]byte("E\n")),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
	}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", dataprovider.User{
			BaseUser: sdk.BaseUser{
				HomeDir: os.TempDir(),
			},
		}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-t", "/tmp"},
		},
	}
	err := scpCommand.handleRecursiveUpload()
	assert.EqualError(t, err, "unacceptable end dir command")
}

// TestSCPParseUploadMessage feeds malformed upload messages to
// parseUploadMessage: an unknown message, a missing name, a non-numeric
// size and a blank name. All of them must fail.
func TestSCPParseUploadMessage(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", dataprovider.User{
			BaseUser: sdk.BaseUser{
				HomeDir: os.TempDir(),
			},
		}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-t", "/tmp"},
		},
	}
	_, _, err := scpCommand.parseUploadMessage(fs, "invalid")
	assert.Error(t, err, "parsing invalid upload message must fail")
	_, _, err = scpCommand.parseUploadMessage(fs, "D0755 0")
	assert.Error(t, err, "parsing incomplete upload message must fail")
	_, _, err = scpCommand.parseUploadMessage(fs, "D0755 invalidsize testdir")
	assert.Error(t, err, "parsing upload message with invalid size must fail")
	_, _, err = scpCommand.parseUploadMessage(fs, "D0755 0 ")
	assert.Error(t, err, "parsing upload message with invalid name must fail")
}

// TestSCPProtocolMessages exercises the low-level protocol helpers against
// mock read/write errors. It also checks that a confirmation made of 0x02,
// a text and a newline is surfaced as an error carrying exactly that text.
func TestSCPProtocolMessages(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-t", "/tmp"},
		},
	}
	_, err := scpCommand.readProtocolMessage()
	assert.EqualError(t, err, readErr.Error())
	err = scpCommand.sendConfirmationMessage()
	assert.EqualError(t, err, writeErr.Error())
	err = scpCommand.sendProtocolMessage("E\n")
	assert.EqualError(t, err, writeErr.Error())
	_, err = scpCommand.getNextUploadProtocolMessage()
	assert.EqualError(t, err, readErr.Error())
	// reads succeed here, so the expected error comes from the write side
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: writeErr,
	}
	scpCommand.connection.channel = &mockSSHChannel
	_, err = scpCommand.getNextUploadProtocolMessage()
	assert.EqualError(t, err, writeErr.Error())
	// 0x02 + text + newline
	respBuffer := []byte{0x02}
	protocolErrorMsg := "protocol error msg"
	respBuffer = append(respBuffer, protocolErrorMsg...)
	respBuffer = append(respBuffer, 0x0A)
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(respBuffer),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.readConfirmationMessage()
	if assert.Error(t, err) {
		assert.Equal(t, protocolErrorMsg, err.Error())
	}
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(respBuffer),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: writeErr,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.downloadDirs(nil, nil)
	assert.ErrorIs(t, err, writeErr)
}

// TestSCPTestDownloadProtocolMessages checks sendDownloadProtocolMessages,
// with and without the "-p" flag, against failing writes and failing reads.
// The second channel swap is not re-assigned into connection.channel; it
// still takes effect because the channel already points at mockSSHChannel.
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-f", "-p", "/tmp"},
		},
	}
	path := "testDir"
	err := os.Mkdir(path, os.ModePerm)
	assert.NoError(t, err)
	stat, err := os.Stat(path)
	assert.NoError(t, err)
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	assert.EqualError(t, err, writeErr.Error())
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: nil,
	}
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	assert.EqualError(t, err, readErr.Error())
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	// same checks without "-p"
	scpCommand.args = []string{"-f", "/tmp"}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	assert.EqualError(t, err, writeErr.Error())
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.sendDownloadProtocolMessages(path, stat)
	assert.EqualError(t, err, readErr.Error())
	err = os.Remove(path)
	assert.NoError(t, err)
}

// TestSCPCommandHandleErrors checks handle() for "-f" when reading the
// initial confirmation fails, and for an unsupported flag ("-i").
// NOTE(review): the net.Pipe ends are only closed, never used, here.
func TestSCPCommandHandleErrors(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	server, client := net.Pipe()
	defer func() {
		err := server.Close()
		assert.NoError(t, err)
	}()
	defer func() {
		err := client.Close()
		assert.NoError(t, err)
	}()
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-f", "/tmp"},
		},
	}
	err := scpCommand.handle()
	assert.EqualError(t, err, readErr.Error())
	scpCommand.args = []string{"-i", "/tmp"}
	err = scpCommand.handle()
	assert.Error(t, err, "invalid scp command must fail")
}

// TestSCPErrorsMockFs uploads a file through a mock filesystem built with a
// fake error and expects handleUploadFile to succeed.
// NOTE(review): the net.Pipe ends are only closed, never used, here.
func TestSCPErrorsMockFs(t *testing.T) {
	errFake := errors.New("fake error")
	u := dataprovider.User{}
	u.Username = "test"
	u.Permissions = make(map[string][]string)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	u.HomeDir = os.TempDir()
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
	}
	server, client := net.Pipe()
	defer func() {
		err := server.Close()
		assert.NoError(t, err)
	}()
	defer func() {
		err := client.Close()
		assert.NoError(t, err)
	}()
	connection := &Connection{
		channel: &mockSSHChannel,
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", u),
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-r", "-t", "/tmp"},
		},
	}
	testfile := filepath.Join(u.HomeDir, "testfile")
	err := os.WriteFile(testfile, []byte("test"), os.ModePerm)
	assert.NoError(t, err)
	fs := newMockOsFs(errFake, nil, true, "123", os.TempDir())
	err = scpCommand.handleUploadFile(fs, testfile, testfile, 0, false, 4, "/testfile")
	assert.NoError(t, err)
	err = os.Remove(testfile)
	assert.NoError(t, err)
}

// TestSCPRecursiveDownloadErrors runs handleRecursiveDownload on a path that
// does not exist, first with a failing channel and then with a working one.
func TestSCPRecursiveDownloadErrors(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	server, client := net.Pipe()
	defer func() {
		err := server.Close()
		assert.NoError(t, err)
	}()
	defer func() {
		err := client.Close()
		assert.NoError(t, err)
	}()
	fs := vfs.NewOsFs("123", os.TempDir(), "", nil)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{
			BaseUser: sdk.BaseUser{
				HomeDir: os.TempDir(),
			},
		}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-r", "-f", "/tmp"},
		},
	}
	path := "testDir"
	err := os.Mkdir(path, os.ModePerm)
	assert.NoError(t, err)
	stat, err := os.Stat(path)
	assert.NoError(t, err)
	err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat)
	assert.EqualError(t, err, writeErr.Error())
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat)
	assert.Error(t, err, "recursive upload download must fail for a non existing dir")
	err = os.Remove(path)
	assert.NoError(t, err)
}

// TestSCPRecursiveUploadErrors checks that handleRecursiveUpload fails when
// channel reads fail, both with and without a write error.
func TestSCPRecursiveUploadErrors(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-r", "-t", "/tmp"},
		},
	}
	err := scpCommand.handleRecursiveUpload()
	assert.Error(t, err, "recursive upload must fail, we send a fake error message")
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	err = scpCommand.handleRecursiveUpload()
	assert.Error(t, err, "recursive upload must fail, we send a fake error message")
}

// TestSCPCreateDirs expects handleCreateDir to fail for a user whose
// HomeDir is the relative path "home_rel_path".
func TestSCPCreateDirs(t *testing.T) {
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	u := dataprovider.User{}
	u.HomeDir = "home_rel_path"
	u.Username = "test"
	u.Permissions = make(map[string][]string)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}
	fs, err := u.GetFilesystem("123")
	assert.NoError(t, err)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", u),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-r", "-t", "/tmp"},
		},
	}
	err = scpCommand.handleCreateDir(fs, "invalid_dir")
	assert.Error(t, err, "create invalid dir must fail")
}

// TestSCPDownloadFileData checks that sendDownloadFileData propagates read
// and write channel errors, with and without the "-p" flag.
func TestSCPDownloadFileData(t *testing.T) {
	testfile := "testfile"
	buf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	stdErrBuf := make([]byte, 65535)
	mockSSHChannelReadErr := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: nil,
	}
	mockSSHChannelWriteErr := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: writeErr,
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}}),
		channel: &mockSSHChannelReadErr,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-r", "-f", "/tmp"},
		},
	}
	err := os.WriteFile(testfile, []byte("test"), os.ModePerm)
	assert.NoError(t, err)
	stat, err := os.Stat(testfile)
	assert.NoError(t, err)
	err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
	assert.EqualError(t, err, readErr.Error())
	scpCommand.connection.channel = &mockSSHChannelWriteErr
	err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
	assert.EqualError(t, err, writeErr.Error())
	// repeat with "-p"
	scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
	err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
	assert.EqualError(t, err, writeErr.Error())
	scpCommand.connection.channel = &mockSSHChannelReadErr
	err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
	assert.EqualError(t, err, readErr.Error())
	err = os.Remove(testfile)
	assert.NoError(t, err)
}

// TestSCPUploadFiledata checks getUploadFileData under several failure cases:
// - a failing write;
// - a failing read;
// - too little data ("12" followed by 0x02 for an expected size of 2);
// - an already closed transfer;
// - an already closed file.
// It also checks that no transfer is left registered at the end.
func TestSCPUploadFiledata(t *testing.T) {
	testfile := "testfile"
	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	readErr := fmt.Errorf("test read error")
	writeErr := fmt.Errorf("test write error")
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: writeErr,
	}
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "testuser",
		},
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", user),
		channel: &mockSSHChannel,
	}
	scpCommand := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
			args: []string{"-r", "-t", "/tmp"},
		},
	}
	file, err := os.Create(testfile)
	assert.NoError(t, err)

	baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(),
		"/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	transfer := newTransfer(baseTransfer, nil, nil, nil)

	err = scpCommand.getUploadFileData(2, transfer)
	assert.Error(t, err, "upload must fail, we send a fake write error message")

	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: readErr,
		WriteError: nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	file, err = os.Create(testfile)
	assert.NoError(t, err)
	transfer.File = file
	transfer.isFinished = false
	transfer.Connection.AddTransfer(transfer)
	err = scpCommand.getUploadFileData(2, transfer)
	assert.Error(t, err, "upload must fail, we send a fake read error message")

	respBuffer := []byte("12")
	respBuffer = append(respBuffer, 0x02)
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(respBuffer),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}
	scpCommand.connection.channel = &mockSSHChannel
	file, err = os.Create(testfile)
	assert.NoError(t, err)
	baseTransfer.File = file
	transfer = newTransfer(baseTransfer, nil, nil, nil)
	transfer.Connection.AddTransfer(transfer)
	err = scpCommand.getUploadFileData(2, transfer)
	assert.Error(t, err, "upload must fail, we have not enough data to read")

	// the file is already closed so we have an error on transfer closing
	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}

	transfer.Connection.AddTransfer(transfer)
	err = scpCommand.getUploadFileData(0, transfer)
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrTransferClosed.Error())
	}
	transfer.Connection.RemoveTransfer(transfer)

	mockSSHChannel = MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
		ReadError: nil,
		WriteError: nil,
	}

	transfer.Connection.AddTransfer(transfer)
	err = scpCommand.getUploadFileData(2, transfer)
	assert.ErrorContains(t, err, os.ErrClosed.Error())
	transfer.Connection.RemoveTransfer(transfer)

	err = os.Remove(testfile)
	assert.NoError(t, err)

	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestUploadError checks what happens, in atomic upload mode, when a
// transfer that recorded an error is closed:
// - Close returns ErrGenericFailure;
// - ErrTransfer keeps the original error;
// - no bytes are counted;
// - neither the target nor the temporary file is left on disk.
// NOTE(review): UploadMode is reset to UploadModeAtomicWithResume, not to the
// value saved before the test; confirm that this is the suite-wide default.
func TestUploadError(t *testing.T) {
	common.Config.UploadMode = common.UploadModeAtomic

	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "testuser",
		},
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", user),
	}

	testfile := "testfile"
	fileTempName := "temptestfile"
	file, err := os.Create(fileTempName)
	assert.NoError(t, err)
	baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(),
		testfile, common.TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	transfer := newTransfer(baseTransfer, nil, nil, nil)

	errFake := errors.New("fake error")
	transfer.TransferError(errFake)
	err = transfer.Close()
	if assert.Error(t, err) {
		assert.EqualError(t, err, common.ErrGenericFailure.Error())
	}
	if assert.Error(t, transfer.ErrTransfer) {
		assert.EqualError(t, transfer.ErrTransfer, errFake.Error())
	}
	assert.Equal(t, int64(0), transfer.BytesReceived.Load())

	assert.NoFileExists(t, testfile)
	assert.NoFileExists(t, fileTempName)

	common.Config.UploadMode = common.UploadModeAtomicWithResume
}

// TestTransferFailingReader checks ReadAt on write-oriented transfers:
// - a transfer opened via handleFilewrite reports ErrSSHFxOpUnsupported;
// - a transfer built with a read error reports ErrSSHFxFailure, both with
//   and without a pipe reader attached.
func TestTransferFailingReader(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "testuser",
			HomeDir: os.TempDir(),
		},
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret("crypt secret"),
			},
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}

	fs := newMockOsFs(nil, nil, true, "", os.TempDir())
	connection := &Connection{
		BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user),
	}

	request := sftp.NewRequest("Open", "afile.txt")
	request.Flags = 27 // read,write,create,truncate

	transfer, err := connection.handleFilewrite(request)
	require.NoError(t, err)
	buf := make([]byte, 32)
	_, err = transfer.ReadAt(buf, 0)
	assert.ErrorIs(t, err, sftp.ErrSSHFxOpUnsupported)
	if c, ok := transfer.(io.Closer); ok {
		err = c.Close()
		assert.NoError(t, err)
	}

	fsPath := filepath.Join(os.TempDir(), "afile.txt")

	r, _, err := pipeat.Pipe()
	assert.NoError(t, err)
	baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath),
		common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{})
	errRead := errors.New("read is not allowed")
	tr := newTransfer(baseTransfer, nil, vfs.NewPipeReader(r), errRead)
	_, err = tr.ReadAt(buf, 0)
	assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)

	err = tr.Close()
	assert.NoError(t, err)

	tr = newTransfer(baseTransfer, nil, nil, errRead)
	_, err = tr.ReadAt(buf, 0)
	assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)

	err = tr.Close()
	assert.NoError(t, err)

	err = os.Remove(fsPath)
	assert.NoError(t, err)
	assert.Len(t, connection.GetTransfers(), 0)
}

// TestConfigsFromProvider checks that loadFromProvider leaves the algorithm
// lists empty when no SFTPD configs are stored. Once configs are stored, the
// stored values must appear after the preferred defaults.
func TestConfigsFromProvider(t *testing.T) {
	err := dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	c := Configuration{}
	err = c.loadFromProvider()
	assert.NoError(t, err)
	assert.Len(t, c.HostKeyAlgorithms, 0)
	assert.Len(t, c.KexAlgorithms, 0)
	assert.Len(t, c.Ciphers, 0)
	assert.Len(t, c.MACs, 0)
	assert.Len(t, c.PublicKeyAlgorithms, 0)
	configs := dataprovider.Configs{
		SFTPD: &dataprovider.SFTPDConfigs{
			HostKeyAlgos: []string{ssh.KeyAlgoRSA},
			KexAlgorithms: []string{ssh.InsecureKeyExchangeDHGEXSHA1},
			Ciphers: []string{ssh.InsecureCipherAES128CBC},
			MACs: []string{ssh.HMACSHA512ETM},
			PublicKeyAlgos: []string{ssh.InsecureKeyAlgoDSA}, //nolint:staticcheck
		},
	}
	err = dataprovider.UpdateConfigs(&configs, "", "", "")
	assert.NoError(t, err)
	err = c.loadFromProvider()
	assert.NoError(t, err)
	// NOTE(review): append on the package-level preferred* slices writes into
	// their backing arrays if they ever have spare capacity; copying first
	// would make these expectations independent of how they are declared.
	expectedHostKeyAlgos := append(preferredHostKeyAlgos, configs.SFTPD.HostKeyAlgos...)
	expectedKEXs := append(preferredKexAlgos, configs.SFTPD.KexAlgorithms...)
	expectedCiphers := append(preferredCiphers, configs.SFTPD.Ciphers...)
	expectedMACs := append(preferredMACs, configs.SFTPD.MACs...)
	expectedPublicKeyAlgos := append(preferredPublicKeyAlgos, configs.SFTPD.PublicKeyAlgos...)
	assert.Equal(t, expectedHostKeyAlgos, c.HostKeyAlgorithms)
	assert.Equal(t, expectedKEXs, c.KexAlgorithms)
	assert.Equal(t, expectedCiphers, c.Ciphers)
	assert.Equal(t, expectedMACs, c.MACs)
	assert.Equal(t, expectedPublicKeyAlgos, c.PublicKeyAlgorithms)

	err = dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
}

// TestSupportedSecurityOptions checks configureSecurityOptions in three cases:
// - the full supported lists;
// - an unknown KEX, which must fail;
// - MAC names padded with spaces plus an extra KEX, which must be accepted.
// The expected KEX list inserts keyExchangeCurve25519SHA256LibSSH right
// after ssh.KeyExchangeCurve25519.
func TestSupportedSecurityOptions(t *testing.T) {
	c := Configuration{
		KexAlgorithms: supportedKexAlgos,
		MACs: supportedMACs,
		Ciphers: supportedCiphers,
	}
	var defaultKexs []string
	for _, k := range supportedKexAlgos {
		defaultKexs = append(defaultKexs, k)
		if k == ssh.KeyExchangeCurve25519 {
			defaultKexs = append(defaultKexs, keyExchangeCurve25519SHA256LibSSH)
		}
	}
	serverConfig := &ssh.ServerConfig{}
	err := c.configureSecurityOptions(serverConfig)
	assert.NoError(t, err)
	assert.Equal(t, supportedCiphers, serverConfig.Ciphers)
	assert.Equal(t, supportedMACs, serverConfig.MACs)
	assert.Equal(t, defaultKexs, serverConfig.KeyExchanges)
	// NOTE(review): c.KexAlgorithms shares its backing array with
	// supportedKexAlgos here, so this append (and the one below) may write
	// into that package-level slice if it has spare capacity.
	c.KexAlgorithms = append(c.KexAlgorithms, "not a kex")
	err = c.configureSecurityOptions(serverConfig)
	assert.Error(t, err)
	c.KexAlgorithms = append(supportedKexAlgos, "diffie-hellman-group18-sha512")
	c.MACs = []string{
		" hmac-sha2-256-etm@openssh.com ",
		" hmac-sha2-512-etm@openssh.com",
		"hmac-sha2-256",
		"hmac-sha2-512 ",
		"hmac-sha1 ",
		" hmac-sha1-96",
	}
	err = c.configureSecurityOptions(serverConfig)
	assert.NoError(t, err)
	assert.Equal(t, supportedCiphers, serverConfig.Ciphers)
	assert.Equal(t, supportedMACs, serverConfig.MACs)
	assert.Equal(t, defaultKexs, serverConfig.KeyExchanges)
}

// TestLoadHostKeys checks that checkAndLoadHostKeys fails for:
// - a directory or missing host key, and an unparsable host key file;
// - a missing non-default key;
// - an ECDSA key when HostKeyAlgorithms only lists RSA-SHA256.
// With the preferred algorithms it succeeds, and the default RSA, ECDSA and
// Ed25519 key files exist afterwards. On non-Windows systems the keys
// directory is made 0551 to check more failures, then restored.
func TestLoadHostKeys(t *testing.T) {
	serverConfig := &ssh.ServerConfig{}
	c := Configuration{}
	c.HostKeys = []string{".", "missing file"}
	err := c.checkAndLoadHostKeys(configDir, serverConfig)
	assert.Error(t, err)
	testfile := filepath.Join(os.TempDir(), "invalidkey")
	err = os.WriteFile(testfile, []byte("some bytes"), os.ModePerm)
	assert.NoError(t, err)
	c.HostKeys = []string{testfile}
	err = c.checkAndLoadHostKeys(configDir, serverConfig)
	assert.Error(t, err)
	err = os.Remove(testfile)
	assert.NoError(t, err)
	keysDir := filepath.Join(os.TempDir(), "keys")
	err = os.MkdirAll(keysDir, os.ModePerm)
	assert.NoError(t, err)
	rsaKeyName := filepath.Join(keysDir, defaultPrivateRSAKeyName)
	ecdsaKeyName := filepath.Join(keysDir, defaultPrivateECDSAKeyName)
	ed25519KeyName := filepath.Join(keysDir, defaultPrivateEd25519KeyName)
	nonDefaultKeyName := filepath.Join(keysDir, "akey")
	c.HostKeys = []string{nonDefaultKeyName, rsaKeyName, ecdsaKeyName, ed25519KeyName}
	err = c.checkAndLoadHostKeys(configDir, serverConfig)
	assert.Error(t, err)
	c.HostKeyAlgorithms = []string{ssh.KeyAlgoRSASHA256}
	c.HostKeys = []string{ecdsaKeyName}
	err = c.checkAndLoadHostKeys(configDir, serverConfig)
	assert.Error(t, err)
	c.HostKeyAlgorithms = preferredHostKeyAlgos
	err = c.checkAndLoadHostKeys(configDir, serverConfig)
	assert.NoError(t, err)
	assert.FileExists(t, rsaKeyName)
	assert.FileExists(t, ecdsaKeyName)
	assert.FileExists(t, ed25519KeyName)
	assert.NoFileExists(t, nonDefaultKeyName)
	err = os.Remove(rsaKeyName)
	assert.NoError(t, err)
	err = os.Remove(ecdsaKeyName)
	assert.NoError(t, err)
	err = os.Remove(ed25519KeyName)
	assert.NoError(t, err)
	if runtime.GOOS != osWindows {
		err = os.Chmod(keysDir, 0551)
		assert.NoError(t, err)
		c.HostKeys = nil
		err = c.checkAndLoadHostKeys(keysDir, serverConfig)
		assert.Error(t, err)
		c.HostKeys = []string{rsaKeyName, ecdsaKeyName}
		err = c.checkAndLoadHostKeys(configDir, serverConfig)
		assert.Error(t, err)
		c.HostKeys = []string{ecdsaKeyName, rsaKeyName}
		err = c.checkAndLoadHostKeys(configDir, serverConfig)
		assert.Error(t, err)
		c.HostKeys = []string{ed25519KeyName}
		err = c.checkAndLoadHostKeys(configDir, serverConfig)
		assert.Error(t, err)
		err = os.Chmod(keysDir, 0755)
		assert.NoError(t, err)
	}
	err = os.RemoveAll(keysDir)
	assert.NoError(t, err)
}

// TestCertCheckerInitErrors checks that initializeCertChecker fails for a
// directory, a missing file and an unparsable CA key file.
func TestCertCheckerInitErrors(t *testing.T) {
	c := Configuration{}
	c.TrustedUserCAKeys = []string{".", "missing file"}
	err := c.initializeCertChecker("")
	assert.Error(t, err)
	testfile := filepath.Join(os.TempDir(), "invalidkey")
	err = os.WriteFile(testfile, []byte("some bytes"), os.ModePerm)
	assert.NoError(t, err)
	c.TrustedUserCAKeys = []string{testfile}
	err = c.initializeCertChecker("")
	assert.Error(t, err)
	err = os.Remove(testfile)
	assert.NoError(t, err)
}

// TestRecoverer calls the connection handlers with nil/zero arguments.
// Presumably they panic internally and recover. It checks that:
// - sshCommand.handle and scpCommand.handle return ErrGenericFailure;
// - no connection or transfer is left registered.
func TestRecoverer(t *testing.T) {
	c := Configuration{}
	c.AcceptInboundConnection(nil, nil)
	connID := "connectionID"
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, "", "", dataprovider.User{}),
	}
	c.handleSftpConnection(nil, connection)
	sshCmd := sshCommand{
		command: "cd",
		connection: connection,
	}
	err := sshCmd.handle()
	assert.EqualError(t, err, common.ErrGenericFailure.Error())
	scpCmd := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: connection,
		},
	}
	err = scpCmd.handle()
	assert.EqualError(t, err, common.ErrGenericFailure.Error())
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestListernerAcceptErrors checks that serve returns the Accept error in two
// cases: a plain error, and a net.Error whose Temporary() reports false
// after a number of calls.
func TestListernerAcceptErrors(t *testing.T) {
	errFake := errors.New("a fake error")
	listener := newFakeListener(errFake)
	c := Configuration{}
	err := c.serve(listener, nil)
	require.EqualError(t, err, errFake.Error())
	err = listener.Close()
	require.NoError(t, err)

	errNetFake := &fakeNetError{error: errFake}
	listener = newFakeListener(errNetFake)
	err = c.serve(listener, nil)
	require.EqualError(t, err, errFake.Error())
	err = listener.Close()
	require.NoError(t, err)
}

// fakeNetError wraps an error and implements Timeout/Temporary so it can be
// used as a net.Error.
type fakeNetError struct {
	error
	count int
}

// Timeout always reports false.
func (e *fakeNetError) Timeout() bool {
	return false
}

// Temporary reports true for the first 9 calls and false from the 10th on.
func (e *fakeNetError) Temporary() bool {
	e.count++
	return e.count < 10
}

// Error returns the wrapped error's message.
func (e *fakeNetError) Error() string {
	return e.error.Error()
}

// fakeListener is a net.Listener backed by the two ends of a net.Pipe.
// Accept always returns the client end together with the configured error.
type fakeListener struct {
	server net.Conn
	client net.Conn
	err error
}

// Accept returns the client end of the pipe and the configured error.
func (l *fakeListener) Accept() (net.Conn, error) {
	return l.client, l.err
}

// Close closes both pipe ends. If both fail, the server-side error is
// reported.
func (l *fakeListener) Close() error {
	errClient := l.client.Close()
	errServer := l.server.Close()
	if errServer != nil {
		return errServer
	}
	return errClient
}

// Addr returns the local address of the server end of the pipe.
func (l *fakeListener) Addr() net.Addr {
	return l.server.LocalAddr()
}

// newFakeListener returns a fakeListener whose Accept always fails with err.
func newFakeListener(err error) net.Listener {
	server, client := net.Pipe()

	return &fakeListener{
		server: server,
		client: client,
		err: err,
	}
}

// TestLoadRevokedUserCertsFile checks revokedCertificates.load:
// - it succeeds when no file path is set;
// - it fails for a file that is not valid JSON;
// - it fails for a path that is a directory.
func TestLoadRevokedUserCertsFile(t *testing.T) {
	r := revokedCertificates{
		certs: map[string]bool{},
	}
	err := r.load()
	assert.NoError(t, err)
	r.filePath = filepath.Join(os.TempDir(), "sub", "testrevoked")
	err = os.MkdirAll(filepath.Dir(r.filePath), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(r.filePath, []byte(`no json`), 0644)
	assert.NoError(t, err)
	err = r.load()
	assert.Error(t, err)
	r.filePath = filepath.Dir(r.filePath)
	err = r.load()
	assert.Error(t, err)

	err = os.RemoveAll(r.filePath)
	assert.NoError(t, err)
}

// TestMaxUserSessions sets MaxSessions to 1 and registers one connection
// for the user. A second connection's SSH and SCP commands must then be
// rejected with "too many open sessions".
func TestMaxUserSessions(t *testing.T) {
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username: "user_max_sessions",
				HomeDir: filepath.Clean(os.TempDir()),
				MaxSessions: 1,
			},
		}),
	}
	err := common.Connections.Add(connection)
	assert.NoError(t, err)

	c := Configuration{}
	c.handleSftpConnection(nil, connection)

	buf := make([]byte, 65535)
	stdErrBuf := make([]byte, 65535)
	mockSSHChannel := MockChannel{
		Buffer: bytes.NewBuffer(buf),
		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
	}
	conn := &Connection{
		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username: "user_max_sessions",
				HomeDir: filepath.Clean(os.TempDir()),
				MaxSessions: 1,
			},
		}),
		channel: &mockSSHChannel,
	}
	sshCmd := sshCommand{
		command: "cd",
		connection: conn,
	}
	err = sshCmd.handle()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "too many open sessions")
	}
	scpCmd := scpCommand{
		sshCommand: sshCommand{
			command: "scp",
			connection: conn,
		},
	}
	err = scpCmd.handle()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "too many open sessions")
	}

	common.Connections.Remove(connection.GetID())
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestCanReadSymlink checks canReadLink:
// - a path under "/sub", where the user only has upload permission, is
//   denied;
// - a "*.txt" path hidden by a deny pattern reports "no such file".
func TestCanReadSymlink(t *testing.T) {
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username: "user_can_read_symlink",
				HomeDir: filepath.Clean(os.TempDir()),
				Permissions: map[string][]string{
					"/": {dataprovider.PermAny},
					"/sub": {dataprovider.PermUpload},
				},
			},
			Filters: dataprovider.UserFilters{
				BaseUserFilters: sdk.BaseUserFilters{
					FilePatterns: []sdk.PatternsFilter{
						{
							Path: "/denied",
							DeniedPatterns: []string{"*.txt"},
							DenyPolicy: sdk.DenyPolicyHide,
						},
					},
				},
			},
		}),
	}
	err := connection.canReadLink("/sub/link")
	assert.ErrorIs(t, err, sftp.ErrSSHFxPermissionDenied)

	err = connection.canReadLink("/denied/file.txt")
	assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
}

// TestAuthenticationErrors checks three properties of the errors returned
// by newAuthenticationError:
// - with errors.Is they match the value built with nil arguments;
// - they match util.ErrNotFound only when the wrapped cause does;
// - they expose the login method and username.
func TestAuthenticationErrors(t *testing.T) {
	sftpAuthError := newAuthenticationError(nil, "", "")
	loginMethod := dataprovider.SSHLoginMethodPassword
	username := "test user"
	err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")),
		loginMethod, username)
	assert.ErrorIs(t, err, sftpAuthError)
	assert.ErrorIs(t, err, util.ErrNotFound)
	var sftpAuthErr *authenticationError
	if assert.ErrorAs(t, err, &sftpAuthErr) {
		assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod())
		assert.Equal(t, username, sftpAuthErr.getUsername())
	}

	err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod, username)
	assert.ErrorIs(t, err, sftpAuthError)
	assert.NotErrorIs(t, err, util.ErrNotFound)
	err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod, username)
	assert.ErrorIs(t, err, sftpAuthError)
	assert.NotErrorIs(t, err, util.ErrNotFound)
	err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod, username)
	assert.ErrorIs(t, err, sftpAuthError)
	assert.NotErrorIs(t, err, util.ErrNotFound)
	err = newAuthenticationError(nil, loginMethod, username)
	assert.ErrorIs(t, err, sftpAuthError)
	assert.NotErrorIs(t, err, util.ErrNotFound)
}

// mockCommandExecutor returns canned Output/Err from CombinedOutput and
// ignores the context, command name and arguments.
type mockCommandExecutor struct {
	Output []byte
	Err error
}

func (f mockCommandExecutor) CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) {
	return f.Output, f.Err
}

// TestVerifyWithOPKSSH checks verifyWithOPKSSH against a mocked executor.
// It must fail when:
// - the executor returns an error;
// - the executor returns empty output;
// - the output is the certificate itself.
// It must succeed when the output is the certificate's signature key.
func TestVerifyWithOPKSSH(t *testing.T) {
	sshCert := []byte(`ssh-rsa-cert-v01@openssh.com 
AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg4+hKHVPKv183MU/Q7XD/mzDBFSc2YY3eraltxLMGJo0AAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzYAAAAAAd8sP4wAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTI1NgAAAQA/ByIegNZYJRRl413S/8LxGvTZnbxsPwaluoJ/54niGZV9P28THz7d9jXfSHPjalhH93jNPfTYXvI4opnDC37ua1Nu8KKfk40IWXnnDdZLWraUxEidIzhmfVtz8kGdGoFQ8H0EzubL7zKNOTlfSfOoDlmQVOuxT/+eh2mEp4ri0/+8J1mLfLBr8tREX0/iaNjK+RKdcyTMicKursAYMCDdu8vlaphxea+ocyHM9izSX/l33t44V13ueTqIOh2Zbl2UE2k+jk+0dc1CmV0SEoiWiIyt8TRM4yQry1vPlQLsrf28sYM/QMwnhCVhyZO3vs5F25aQWrB9d51VEzBW9/fd host.example.com`)
	key, _, _, _, err := ssh.ParseAuthorizedKey(sshCert) //nolint:dogsled
	require.NoError(t, err)
	cert, ok := key.(*ssh.Certificate)
	require.True(t, ok)

	c := Configuration{}
	c.executor = mockCommandExecutor{
		Err: errors.New("test error"),
	}
	err = c.verifyWithOPKSSH("user", cert)
	assert.Error(t, err)

	c.executor = mockCommandExecutor{}
	err = c.verifyWithOPKSSH("", cert)
	assert.Error(t, err)

	c.executor = mockCommandExecutor{
		Output: ssh.MarshalAuthorizedKey(cert),
	}
	err = c.verifyWithOPKSSH("", cert)
	assert.Error(t, err)

	c.executor = mockCommandExecutor{
		Output: ssh.MarshalAuthorizedKey(cert.SignatureKey),
	}
	err = c.verifyWithOPKSSH("", cert)
	assert.NoError(t, err)
}

================================================
FILE: internal/sftpd/lister.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package sftpd

import (
	"io"
	"os"
)

// listerAt serves an in-memory slice of os.FileInfo through ListAt.
type listerAt []os.FileInfo

// ListAt returns the number of entries copied and an io.EOF error if we made it to the end of the file list.
// Take a look at the pkg/sftp godoc for more information about how this function should work.
// It copies entries starting at offset into f. io.EOF is returned when
// offset is past the end or when fewer than len(f) entries were copied.
// NOTE(review): a negative offset would panic on l[offset:]; presumably
// callers never pass one.
func (l listerAt) ListAt(f []os.FileInfo, offset int64) (int, error) {
	if offset >= int64(len(l)) {
		return 0, io.EOF
	}

	n := copy(f, l[offset:])
	// a partially filled f means the list is exhausted
	if n < len(f) {
		return n, io.EOF
	}
	return n, nil
}

================================================
FILE: internal/sftpd/scp.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd

import (
	"errors"
	"fmt"
	"io"
	"math"
	"os"
	"path"
	"path/filepath"
	"runtime/debug"
	"strconv"
	"strings"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// SCP protocol response codes, sent as a single byte.
var (
	okMsg   = []byte{0x00}
	warnMsg = []byte{0x01} // must be followed by an optional message and a newline
	errMsg  = []byte{0x02} // must be followed by an optional message and a newline
	newLine = []byte{0x0A}
)

// scpCommand implements the server side of the SCP protocol on top of an SSH command.
type scpCommand struct {
	sshCommand
}

// handle runs the SCP command.
// With "-t" ("to") the client uploads files to us.
// With "-f" ("from") the client downloads files from us.
// Any other combination is rejected as unsupported.
// A panic is recovered and reported as common.ErrGenericFailure.
// The exit status is sent after a successful upload/download, or with the error
// for an unsupported command.
// Upload/download errors return early without sending it.
func (c *scpCommand) handle() (err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handle scp command: %q stack trace: %v", r, string(debug.Stack()))
			err = common.ErrGenericFailure
		}
	}()
	if err := common.Connections.Add(c.connection); err != nil {
		defer c.connection.CloseFS() //nolint:errcheck

		logger.Info(logSender, "", "unable to add SCP connection: %v", err)
		return c.sendErrorResponse(err)
	}
	defer common.Connections.Remove(c.connection.GetID())

	destPath := c.getDestPath()
	c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %s, dest path: %q",
		c.args, c.connection.User.Username, destPath)
	if c.hasFlag("t") {
		// -t means "to", so upload
		err = c.sendConfirmationMessage()
		if err != nil {
			return err
		}
		err = c.handleRecursiveUpload()
		if err != nil {
			return err
		}
	} else if c.hasFlag("f") {
		// -f means "from" so download
		err = c.readConfirmationMessage()
		if err != nil {
			return err
		}
		err = c.handleDownload(destPath)
		if err != nil {
			return err
		}
	} else {
		err = fmt.Errorf("scp command not supported, args: %v", c.args)
		c.connection.Log(logger.LevelDebug, "unsupported scp command, args: %v", c.args)
	}
	c.sendExitStatus(err)
	return err
}

// handleRecursiveUpload processes the upload protocol messages sent by the client.
// "D" enters a directory, "E" leaves it and "C" uploads a file. "T" messages are
// skipped by getNextUploadProtocolMessage.
// destPath tracks the current virtual destination directory: it moves down on "D"
// and back up on "E".
// io.EOF from the client ends the upload successfully.
func (c *scpCommand) handleRecursiveUpload() error {
	numDirs := 0
	destPath := c.getDestPath()

	for {
		fs, err := c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
		if err != nil {
			c.connection.Log(logger.LevelError, "error uploading file %q: %+v", destPath, err)
			c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %q", destPath))
			return err
		}
		command, err := c.getNextUploadProtocolMessage()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}
			c.sendErrorMessage(fs, err)
			return err
		}

		if strings.HasPrefix(command, "E") {
			numDirs--
			c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
			if numDirs < 0 {
				// more "E" than "D" messages: the client tries to go above the destination
				err = errors.New("unacceptable end dir command")
				c.sendErrorMessage(nil, err)
				return err
			}
			// the destination dir is now the parent directory
			destPath = path.Join(destPath, "..")
		} else {
			sizeToRead, name, err := c.parseUploadMessage(fs, command)
			if err != nil {
				return err
			}
			if strings.HasPrefix(command, "D") {
				numDirs++
				destPath = path.Join(destPath, name)
				fs, err = c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
				if err != nil {
					c.connection.Log(logger.LevelError, "error uploading file %q: %+v", destPath, err)
					c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %q", destPath))
					return err
				}
				err = c.handleCreateDir(fs, destPath)
				if err != nil {
					return err
				}
				c.connection.Log(logger.LevelDebug, "received start dir command, num dirs: %v destPath: %q", numDirs, destPath)
			} else if strings.HasPrefix(command, "C") {
				err = c.handleUpload(c.getFileUploadDestPath(fs, destPath, name), sizeToRead)
				if err != nil {
					return err
				}
			}
		}
		err = c.sendConfirmationMessage()
		if err != nil {
			return err
		}
	}
}

// handleCreateDir creates the virtual directory dirPath.
// It requires the create dirs permission on its parent.
// If dirPath already exists as a directory it succeeds without doing anything.
func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error {
	c.connection.UpdateLastActivity()

	p, err := fs.ResolvePath(dirPath)
	if err != nil {
		c.connection.Log(logger.LevelError, "error creating dir: %q, invalid file path, err: %v", dirPath, err)
		c.sendErrorMessage(fs, err)
		return err
	}
	if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(dirPath)) {
		c.connection.Log(logger.LevelError, "error creating dir: %q, permission denied", dirPath)
		c.sendErrorMessage(fs, common.ErrPermissionDenied)
		return common.ErrPermissionDenied
	}

	info, err := c.connection.DoStat(dirPath, 1, true)
	if err == nil && info.IsDir() {
		return nil
	}

	err = c.createDir(fs, p)
	if err != nil {
		return err
	}
	c.connection.Log(logger.LevelDebug, "created dir %q", dirPath)
	return nil
}

// getUploadFileData receives the content of an uploaded file.
// It confirms the "C" message, reads exactly sizeToRead bytes from the channel
// and writes them to transfer. Then it waits for the client confirmation and
// closes the transfer.
// The transfer is closed on every path, errors included.
func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) error {
	err := c.sendConfirmationMessage()
	if err != nil {
		transfer.TransferError(err)
		transfer.Close()
		return err
	}

	if sizeToRead > 0 {
		// we could replace this method with io.CopyN implementing "Write" method in transfer struct
		remaining := sizeToRead
		buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
		for {
			n, err := c.connection.channel.Read(buf)
			if err != nil {
				transfer.TransferError(err)
				transfer.Close()
				c.sendErrorMessage(transfer.Fs, err)
				return err
			}
			_, err = transfer.WriteAt(buf[:n], sizeToRead-remaining)
			if err != nil {
				transfer.Close()
				c.sendErrorMessage(transfer.Fs, err)
				return err
			}
			remaining -= int64(n)
			if remaining <= 0 {
				break
			}
			// shrink the buffer so we never read past the file data: the bytes
			// that follow belong to the client's confirmation message
			if remaining < int64(len(buf)) {
				buf = make([]byte, remaining)
			}
		}
	}

	err = c.readConfirmationMessage()
	if err != nil {
		transfer.TransferError(err)
		transfer.Close()
		return err
	}
	err = transfer.Close()
	if err != nil {
		c.sendErrorMessage(transfer.Fs, err)
		return err
	}
	return nil
}

// handleUploadFile runs the checks for a file upload and opens the destination.
// It checks the transfer-count limit, the disk and transfer quotas and the
// pre-upload action. It then creates filePath (possibly an atomic-upload temporary
// path) and adjusts quota/size bookkeeping when an existing file is overwritten.
// Finally it delegates the data transfer to getUploadFileData.
func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
	if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil {
		err := fmt.Errorf("denying file write due to transfer count limits")
		c.connection.Log(logger.LevelInfo, "denying file write due to transfer count limits")
		c.sendErrorMessage(nil, err)
		return err
	}
	diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath)
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		err := fmt.Errorf("denying file write due to quota limits")
		c.connection.Log(logger.LevelError, "error uploading file: %q, err: %v", filePath, err)
		c.sendErrorMessage(nil, err)
		return err
	}
	_, err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath,
		fileSize, os.O_TRUNC)
	if err != nil {
		c.connection.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
		err = c.connection.GetPermissionDeniedError()
		c.sendErrorMessage(fs, err)
		return err
	}

	maxWriteSize, _ := c.connection.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported())

	file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
		c.connection.GetCreateChecks(requestPath, isNewFile, false))
	if err != nil {
		c.connection.Log(logger.LevelError, "error creating file %q: %v", resolvedPath, err)
		c.sendErrorMessage(fs, err)
		return err
	}

	initialSize := int64(0)
	truncatedSize := int64(0) // bytes truncated and not included in quota
	if !isNewFile {
		if vfs.HasTruncateSupport(fs) {
			// the existing file was truncated: remove its size from the used quota
			vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
			if err == nil {
				dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, 0, -fileSize, false)
			} else {
				dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
			}
		} else {
			initialSize = fileSize
			truncatedSize = initialSize
		}
		if maxWriteSize > 0 {
			maxWriteSize += fileSize
		}
	}

	vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())

	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota)
	t := newTransfer(baseTransfer, w, nil, nil)

	return c.getUploadFileData(sizeToRead, t)
}

// handleUpload validates an upload to the virtual path uploadFilePath.
// It checks file filters and permissions: upload for new files and symlinks,
// overwrite for existing regular files. When atomic uploads are enabled and
// supported, an existing file is first renamed to the temporary upload path.
// The upload itself is done by handleUploadFile.
func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error {
	c.connection.UpdateLastActivity()

	fs, p, err := c.connection.GetFsAndResolvedPath(uploadFilePath)
	if err != nil {
		c.connection.Log(logger.LevelError, "error uploading file: %q, err: %v", uploadFilePath, err)
		c.sendErrorMessage(nil, err)
		return err
	}

	if ok, _ := c.connection.User.IsFileAllowed(uploadFilePath); !ok {
		c.connection.Log(logger.LevelWarn, "writing file %q is not allowed", uploadFilePath)
		c.sendErrorMessage(fs, c.connection.GetPermissionDeniedError())
		return common.ErrPermissionDenied
	}

	filePath := p
	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
		filePath = fs.GetAtomicUploadPath(p)
	}
	stat, statErr := fs.Lstat(p)
	if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
		if !c.connection.User.HasPerm(dataprovider.PermUpload, path.Dir(uploadFilePath)) {
			c.connection.Log(logger.LevelWarn, "cannot upload file: %q, permission denied", uploadFilePath)
			c.sendErrorMessage(fs, common.ErrPermissionDenied)
			return common.ErrPermissionDenied
		}
		return c.handleUploadFile(fs, p, filePath, sizeToRead, true, 0, uploadFilePath)
	}

	if statErr != nil {
		c.connection.Log(logger.LevelError, "error performing file stat %q: %v", p, statErr)
		c.sendErrorMessage(fs, statErr)
		return statErr
	}

	if stat.IsDir() {
		c.connection.Log(logger.LevelError, "attempted to open a directory for writing to: %q", p)
		err = fmt.Errorf("attempted to open a directory for writing: %q", p)
		c.sendErrorMessage(fs, err)
		return err
	}

	if !c.connection.User.HasPerm(dataprovider.PermOverwrite, uploadFilePath) {
		c.connection.Log(logger.LevelWarn, "cannot overwrite file: %q, permission denied", uploadFilePath)
		c.sendErrorMessage(fs, common.ErrPermissionDenied)
		return common.ErrPermissionDenied
	}

	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
		_, _, err = fs.Rename(p, filePath, 0)
		if err != nil {
			c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %v",
				p, filePath, err)
			c.sendErrorMessage(fs, err)
			return err
		}
	}

	return c.handleUploadFile(fs, p, filePath, sizeToRead, false, stat.Size(), uploadFilePath)
}

// sendModTimeMessage sends the "T" protocol message for stat and waits for the
// client confirmation.
// The modification time is used for both mtime and atime, with the microsecond
// fields set to 0.
func (c *scpCommand) sendModTimeMessage(stat os.FileInfo) error {
	// Unix() rather than UnixNano()/1e9: UnixNano is undefined for times outside
	// the int64 nanosecond range (roughly years 1678-2262)
	modTime := stat.ModTime().Unix()
	if err := c.sendProtocolMessage(fmt.Sprintf("T%d 0 %d 0\n", modTime, modTime)); err != nil {
		return err
	}
	return c.readConfirmationMessage()
}

// sendDownloadProtocolMessages starts the download of a directory.
// It sends the optional "T" message and the "D" message for virtualDirPath.
// The user's root directory is announced with the username as its name.
func (c *scpCommand) sendDownloadProtocolMessages(virtualDirPath string, stat os.FileInfo) error {
	if c.sendFileTime() {
		if err := c.sendModTimeMessage(stat); err != nil {
			return err
		}
	}

	dirName := path.Base(virtualDirPath)
	if dirName == "/" || dirName == "." {
		dirName = c.connection.User.Username
	}

	fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), dirName)
	if err := c.sendProtocolMessage(fileMode); err != nil {
		return err
	}
	return c.readConfirmationMessage()
}

// handleRecursiveDownload sends the directory dirPath (fs path) / virtualPath to the client.
// It first sends the files of the directory, including the virtual folders mounted
// inside it. It then recurses into the subdirectories via downloadDirs.
// It fails unless the "-r" flag was given.
func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath string, stat os.FileInfo) error {
	if !c.isRecursive() {
		err := errors.New("unable to send directory for non recursive copy")
		c.sendErrorMessage(nil, err)
		return err
	}

	c.connection.Log(logger.LevelDebug, "recursive download, dir path %q virtual path %q", dirPath, virtualPath)
	if err := c.sendDownloadProtocolMessages(virtualPath, stat); err != nil {
		return err
	}
	// dirPath is a fs path, not a virtual path
	lister, err := fs.ReadDir(dirPath)
	if err != nil {
		c.sendErrorMessage(fs, err)
		return err
	}
	defer lister.Close()

	vdirs := c.connection.User.GetVirtualFoldersInfo(virtualPath)

	var dirs []string
	for {
		files, err := lister.Next(vfs.ListerBatchSize)
		finished := errors.Is(err, io.EOF)
		if err != nil && !finished {
			c.sendErrorMessage(fs, err)
			return err
		}
		files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath))
		if len(vdirs) > 0 {
			// virtual folders are added only once, with the first batch
			files = append(files, vdirs...)
			vdirs = nil
		}
		for _, file := range files {
			filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name()))
			if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 {
				err = c.handleDownload(filePath)
				if err != nil {
					c.sendErrorMessage(fs, err)
					return err
				}
			} else if file.IsDir() {
				dirs = append(dirs, filePath)
			}
		}
		if finished {
			break
		}
	}
	// release the lister before recursing into subdirectories
	lister.Close()

	return c.downloadDirs(fs, dirs)
}

// downloadDirs downloads each of the given virtual directories and then sends
// the "E" message that closes the current directory.
func (c *scpCommand) downloadDirs(fs vfs.Fs, dirs []string) error {
	for _, dir := range dirs {
		if err := c.handleDownload(dir); err != nil {
			c.sendErrorMessage(fs, err)
			return err
		}
	}
	if err := c.sendProtocolMessage("E\n"); err != nil {
		return err
	}
	return c.readConfirmationMessage()
}

// sendDownloadFileData sends a single file.
// It sends the optional "T" message, then the "C" message, then the file content
// read from transfer, then a final confirmation.
// For CryptFs the reported size and mode come from the converted FileInfo.
func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error {
	if c.sendFileTime() {
		if err := c.sendModTimeMessage(stat); err != nil {
			return err
		}
	}
	if vfs.IsCryptOsFs(fs) {
		stat = fs.(*vfs.CryptFs).ConvertFileInfo(stat)
	}

	fileSize := stat.Size()
	readed := int64(0)
	fileMode := fmt.Sprintf("C%v %v %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), fileSize, filepath.Base(filePath))
	err := c.sendProtocolMessage(fileMode)
	if err != nil {
		return err
	}
	err = c.readConfirmationMessage()
	if err != nil {
		return err
	}

	// we could replace this method with io.CopyN implementing "Read" method in transfer struct
	buf := make([]byte, 32768)
	var n int
	for {
		n, err = transfer.ReadAt(buf, readed)
		if err == nil || err == io.EOF {
			if n > 0 {
				_, err = c.connection.channel.Write(buf[:n])
			}
		}
		readed += int64(n)
		if err != nil {
			break
		}
	}
	// io.EOF is the only successful way out of the loop
	if err != io.EOF {
		c.sendErrorMessage(fs, err)
		return err
	}

	err = c.sendConfirmationMessage()
	if err != nil {
		return err
	}
	return c.readConfirmationMessage()
}

// handleDownload sends the virtual path filePath to the client.
// Directories are sent recursively. Files require the transfer-count limit,
// download quota, download permission, file filters and pre-download action
// checks to pass.
func (c *scpCommand) handleDownload(filePath string) error {
	c.connection.UpdateLastActivity()

	if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil {
		err := fmt.Errorf("denying file read due to transfer count limits")
		c.connection.Log(logger.LevelInfo, "denying file read due to transfer count limits")
		c.sendErrorMessage(nil, err)
		return err
	}
	transferQuota := c.connection.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		c.connection.Log(logger.LevelInfo, "denying file read due to quota limits")
		c.sendErrorMessage(nil, c.connection.GetReadQuotaExceededError())
		return c.connection.GetReadQuotaExceededError()
	}

	fs, p, err := c.connection.GetFsAndResolvedPath(filePath)
	if err != nil {
		c.connection.Log(logger.LevelError, "error downloading file %q: %+v", filePath, err)
		c.sendErrorMessage(nil, fmt.Errorf("unable to download file %q: %w", filePath, err))
		return err
	}

	var stat os.FileInfo
	if stat, err = fs.Stat(p); err != nil {
		c.connection.Log(logger.LevelError, "error downloading file: %q->%q, err: %v", filePath, p, err)
		c.sendErrorMessage(fs, err)
		return err
	}

	if stat.IsDir() {
		if !c.connection.User.HasPerm(dataprovider.PermDownload, filePath) {
			c.connection.Log(logger.LevelWarn, "error downloading dir: %q, permission denied", filePath)
			c.sendErrorMessage(fs, common.ErrPermissionDenied)
			return common.ErrPermissionDenied
		}
		return c.handleRecursiveDownload(fs, p, filePath, stat)
	}

	if !c.connection.User.HasPerm(dataprovider.PermDownload, path.Dir(filePath)) {
		c.connection.Log(logger.LevelWarn, "error downloading file: %q, permission denied", filePath)
		c.sendErrorMessage(fs, common.ErrPermissionDenied)
		return common.ErrPermissionDenied
	}

	if ok, policy := c.connection.User.IsFileAllowed(filePath); !ok {
		c.connection.Log(logger.LevelWarn, "reading file %q is not allowed", filePath)
		c.sendErrorMessage(fs, c.connection.GetErrorForDeniedFile(policy))
		return common.ErrPermissionDenied
	}

	if _, err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreDownload, p, filePath, 0, 0); err != nil {
		c.connection.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", filePath, err)
		c.sendErrorMessage(fs, common.ErrPermissionDenied)
		return common.ErrPermissionDenied
	}

	file, r, cancelFn, err := fs.Open(p, 0)
	if err != nil {
		c.connection.Log(logger.LevelError, "could not open file %q for reading: %v", p, err)
		c.sendErrorMessage(fs, err)
		return err
	}

	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
		common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota)
	t := newTransfer(baseTransfer, nil, r, nil)

	err = c.sendDownloadFileData(fs, p, stat, t)
	// we need to call Close anyway and return close error if any and
	// if we have no previous error
	if err == nil {
		err = t.Close()
	} else {
		t.TransferError(err)
		t.Close()
	}
	return err
}

// sendFileTime reports whether the "-p" flag (preserve times) was given.
func (c *scpCommand) sendFileTime() bool {
	return c.hasFlag("p")
}

// isRecursive reports whether the "-r" flag was given.
func (c *scpCommand) isRecursive() bool {
	return c.hasFlag("r")
}

// hasFlag reports whether flag appears in a short-option argument, i.e. one that
// starts with "-" but not "--".
// The last argument is never inspected (presumably it is the target path).
func (c *scpCommand) hasFlag(flag string) bool {
	for idx := 0; idx < len(c.args)-1; idx++ {
		arg := c.args[idx]
		if !strings.HasPrefix(arg, "--") && strings.HasPrefix(arg, "-") && strings.Contains(arg, flag) {
			return true
		}
	}
	return false
}

// readConfirmationMessage reads the SCP response byte and the optional text message.
// Both a warning (0x01) and an error (0x02) response are returned as an error
// carrying the received text.
// The channel is closed on read errors and after warning/error responses.
func (c *scpCommand) readConfirmationMessage() error {
	var msg strings.Builder
	buf := make([]byte, 1)
	n, err := c.connection.channel.Read(buf)
	if err != nil {
		c.connection.channel.Close()
		return err
	}
	if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
		isError := buf[0] == errMsg[0]
		for {
			n, err = c.connection.channel.Read(buf)
			readed := buf[:n]
			if err != nil || (n == 1 && readed[0] == newLine[0]) {
				break
			}
			if n > 0 {
				msg.Write(readed)
			}
		}
		c.connection.Log(logger.LevelInfo, "scp error message received: %v is error: %v", msg.String(), isError)
		err = errors.New(msg.String())
		c.connection.channel.Close()
	}
	return err
}

// readProtocolMessage reads a newline-terminated protocol message, one byte at a time,
// and returns it without the newline.
// The channel is closed on read errors other than io.EOF.
func (c *scpCommand) readProtocolMessage() (string, error) {
	var command strings.Builder
	var err error
	buf := make([]byte, 1)
	for {
		var n int
		n, err = c.connection.channel.Read(buf)
		if err != nil {
			break
		}
		if n > 0 {
			readed := buf[:n]
			if n == 1 && readed[0] == newLine[0] {
				break
			}
			command.Write(readed)
		}
	}
	if err != nil && !errors.Is(err, io.EOF) {
		c.connection.channel.Close()
	}
	return command.String(), err
}

// sendErrorMessage sends an error message and close the channel
// we don't check write errors here, we have to close the channel anyway
//
//nolint:errcheck
func (c *scpCommand) sendErrorMessage(fs vfs.Fs, err error) {
	c.connection.channel.Write(errMsg)
	if fs != nil {
		c.connection.channel.Write([]byte(c.connection.GetFsError(fs, err).Error()))
	} else {
		c.connection.channel.Write([]byte(err.Error()))
	}
	c.connection.channel.Write(newLine)
	c.connection.channel.Close()
}

// sendConfirmationMessage sends the SCP "ok" byte and closes the channel if the write fails.
func (c *scpCommand) sendConfirmationMessage() error {
	_, err := c.connection.channel.Write(okMsg)
	if err != nil {
		c.connection.channel.Close()
	}
	return err
}

// sendProtocolMessage sends message as is and closes the channel if the write fails.
func (c *scpCommand) sendProtocolMessage(message string) error {
	_, err := c.connection.channel.Write([]byte(message))
	if err != nil {
		c.connection.Log(logger.LevelError, "error sending protocol message: %v, err: %v", message, err)
		c.connection.channel.Close()
	}
	return err
}

// getNextUploadProtocolMessage returns the next upload protocol message.
// "T" (times) messages are acknowledged and skipped.
func (c *scpCommand) getNextUploadProtocolMessage() (string, error) {
	for {
		command, err := c.readProtocolMessage()
		if err != nil {
			return command, err
		}
		if !strings.HasPrefix(command, "T") {
			return command, nil
		}
		if err = c.sendConfirmationMessage(); err != nil {
			return command, err
		}
	}
}

// createDir creates the fs path dirPath and applies the user's uid/gid to it.
// On error it also sends the error to the client.
func (c *scpCommand) createDir(fs vfs.Fs, dirPath string) error {
	err := fs.Mkdir(dirPath)
	if err != nil {
		c.connection.Log(logger.LevelError, "error creating dir %q: %v", dirPath, err)
		c.sendErrorMessage(fs, err)
		return err
	}
	vfs.SetPathPermissions(fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
	return nil
}

// parseUploadMessage parses upload protocol messages such as:
//
//	D0755 0 testdir
//	C0644 6 testfile
//
// and returns the size and the file/directory name. The mode field is not used.
// The name must be a single path component: like OpenSSH's scp sink, names
// containing "/" or equal to ".." are rejected. Such names would otherwise let a
// "D" message move the destination more than one level, so the following "E"
// would no longer return to the right parent directory.
func (c *scpCommand) parseUploadMessage(fs vfs.Fs, command string) (int64, string, error) {
	var size int64
	var name string
	var err error
	if !strings.HasPrefix(command, "C") && !strings.HasPrefix(command, "D") {
		err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v",
			command, c.args, c.connection.User.Username)
		c.connection.Log(logger.LevelError, "error: %v", err)
		c.sendErrorMessage(fs, err)
		return size, name, err
	}
	parts := strings.SplitN(command, " ", 3)
	if len(parts) != 3 {
		err = fmt.Errorf("unable to split upload message: %q", command)
		c.connection.Log(logger.LevelError, "error: %v", err)
		c.sendErrorMessage(fs, err)
		return size, name, err
	}
	size, err = strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		c.connection.Log(logger.LevelError, "error getting size from upload message: %v", err)
		c.sendErrorMessage(fs, err)
		return size, name, err
	}
	name = parts[2]
	if name == "" {
		err = fmt.Errorf("error getting name from upload message, cannot be empty")
		c.connection.Log(logger.LevelError, "error: %v", err)
		c.sendErrorMessage(fs, err)
		return size, name, err
	}
	if name == ".." || strings.Contains(name, "/") {
		err = fmt.Errorf("unexpected file name in upload message: %q", name)
		c.connection.Log(logger.LevelError, "error: %v", err)
		c.sendErrorMessage(fs, err)
		return size, name, err
	}
	return size, name, nil
}

// getFileUploadDestPath returns the virtual path where an uploaded file is written.
// For a recursive upload, or when scpDestPath ends with "/", the file goes inside
// scpDestPath.
// For a non-recursive upload, scpDestPath is the target file name unless it is an
// existing directory. In that case the file goes inside it, as the scp command does.
func (c *scpCommand) getFileUploadDestPath(fs vfs.Fs, scpDestPath, fileName string) string {
	if !c.isRecursive() {
		// if the upload is not recursive and the destination path does not end with "/"
		// then scpDestPath is the wanted filename, for example:
		// scp fileName.txt user@127.0.0.1:/newFileName.txt
		// or
		// scp fileName.txt user@127.0.0.1:/fileName.txt
		if !strings.HasSuffix(scpDestPath, "/") {
			// but if scpDestPath is an existing directory then we put the uploaded file
			// inside that directory this is as scp command works, for example:
			// scp fileName.txt user@127.0.0.1:/existing_dir
			if p, err := fs.ResolvePath(scpDestPath); err == nil {
				if stat, err := fs.Stat(p); err == nil {
					if stat.IsDir() {
						return path.Join(scpDestPath, fileName)
					}
				}
			}
			return scpDestPath
		}
	}
	// if the upload is recursive or scpDestPath has the "/" suffix then the destination
	// file is relative to scpDestPath
	return path.Join(scpDestPath, fileName)
}

// getFileModeAsString returns fileMode as the four octal digits used in SCP "C"
// and "D" protocol messages.
// The first digit holds the special bits, as in chmod: setuid=4, setgid=2, sticky=1.
// The other three are the user, group and other permission digits.
// A zero fileMode yields "0755" for directories and "0644" for files.
func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
	if fileMode == 0 {
		if isDir {
			return "0755"
		}
		return "0644"
	}
	special := 0
	if fileMode&os.ModeSetuid != 0 {
		special += 4
	}
	if fileMode&os.ModeSetgid != 0 {
		special += 2
	}
	if fileMode&os.ModeSticky != 0 {
		special++
	}
	return fmt.Sprintf("%d%03o", special, uint32(fileMode.Perm()))
}

================================================
FILE: internal/sftpd/server.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package sftpd import ( "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "io/fs" "maps" "net" "os" "os/exec" "path/filepath" "runtime/debug" "slices" "strings" "sync" "time" "github.com/pkg/sftp" "github.com/sftpgo/sdk/plugin/notifier" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( defaultPrivateRSAKeyName = "id_rsa" defaultPrivateECDSAKeyName = "id_ecdsa" defaultPrivateEd25519KeyName = "id_ed25519" sourceAddressCriticalOption = "source-address" keyExchangeCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org" extraDataPartialSuccessErrKey = "partialSuccessErr" extraDataUserKey = "user" extraDataKeyIDKey = "keyID" extraDataLoginMethodKey = "login_method" ) var ( supportedAlgos = ssh.SupportedAlgorithms() insecureAlgos = ssh.InsecureAlgorithms() sftpExtensions = []string{"statvfs@openssh.com"} supportedHostKeyAlgos = append(supportedAlgos.HostKeys, insecureAlgos.HostKeys...) preferredHostKeyAlgos = []string{ ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512, ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521, ssh.KeyAlgoED25519, } supportedPublicKeyAlgos = append(supportedAlgos.PublicKeyAuths, insecureAlgos.PublicKeyAuths...) 
preferredPublicKeyAlgos = supportedAlgos.PublicKeyAuths supportedKexAlgos = append(supportedAlgos.KeyExchanges, insecureAlgos.KeyExchanges...) preferredKexAlgos = supportedAlgos.KeyExchanges supportedCiphers = append(supportedAlgos.Ciphers, insecureAlgos.Ciphers...) preferredCiphers = supportedAlgos.Ciphers supportedMACs = append(supportedAlgos.MACs, insecureAlgos.MACs...) preferredMACs = []string{ ssh.HMACSHA256ETM, ssh.HMACSHA256, } revokedCertManager = revokedCertificates{ certs: map[string]bool{}, } ) type commandExecutor interface { CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) } type defaultExecutor struct{} func (d defaultExecutor) CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) { cmd := exec.CommandContext(ctx, name, args...) cmd.Env = []string{} return cmd.CombinedOutput() } // Binding defines the configuration for a network listener type Binding struct { // The address to listen on. A blank value means listen on all available network interfaces. Address string `json:"address" mapstructure:"address"` // The port used for serving requests Port int `json:"port" mapstructure:"port"` // Apply the proxy configuration, if any, for this binding ApplyProxyConfig bool `json:"apply_proxy_config" mapstructure:"apply_proxy_config"` } // GetAddress returns the binding address func (b *Binding) GetAddress() string { return fmt.Sprintf("%s:%d", b.Address, b.Port) } // IsValid returns true if the binding port is > 0 func (b *Binding) IsValid() bool { return b.Port > 0 } // HasProxy returns true if the proxy protocol is active for this binding func (b *Binding) HasProxy() bool { return b.ApplyProxyConfig && common.Config.ProxyProtocol > 0 } // Configuration for the SFTP server type Configuration struct { // Addresses and ports to bind to Bindings []Binding `json:"bindings" mapstructure:"bindings"` // Maximum number of authentication attempts permitted per connection. 
// If set to a negative number, the number of attempts is unlimited. // If set to zero, the number of attempts are limited to 6. MaxAuthTries int `json:"max_auth_tries" mapstructure:"max_auth_tries"` // HostKeys define the daemon's private host keys. // Each host key can be defined as a path relative to the configuration directory or an absolute one. // If empty or missing, the daemon will search or try to generate "id_rsa" and "id_ecdsa" host keys // inside the configuration directory. HostKeys []string `json:"host_keys" mapstructure:"host_keys"` // HostCertificates defines public host certificates. // Each certificate can be defined as a path relative to the configuration directory or an absolute one. // Certificate's public key must match a private host key otherwise it will be silently ignored. HostCertificates []string `json:"host_certificates" mapstructure:"host_certificates"` // HostKeyAlgorithms lists the public key algorithms that the server will accept for host // key authentication. HostKeyAlgorithms []string `json:"host_key_algorithms" mapstructure:"host_key_algorithms"` // KexAlgorithms specifies the available KEX (Key Exchange) algorithms in // preference order. KexAlgorithms []string `json:"kex_algorithms" mapstructure:"kex_algorithms"` // Ciphers specifies the ciphers allowed Ciphers []string `json:"ciphers" mapstructure:"ciphers"` // MACs Specifies the available MAC (message authentication code) algorithms // in preference order MACs []string `json:"macs" mapstructure:"macs"` // PublicKeyAlgorithms lists the supported public key algorithms for client authentication. PublicKeyAlgorithms []string `json:"public_key_algorithms" mapstructure:"public_key_algorithms"` // TrustedUserCAKeys specifies a list of public keys paths of certificate authorities // that are trusted to sign user certificates for authentication. 
// The paths can be absolute or relative to the configuration directory TrustedUserCAKeys []string `json:"trusted_user_ca_keys" mapstructure:"trusted_user_ca_keys"` // Path to a file containing the revoked user certificates. // This file must contain a JSON list with the public key fingerprints of the revoked certificates. // Example content: // ["SHA256:bsBRHC/xgiqBJdSuvSTNpJNLTISP/G356jNMCRYC5Es","SHA256:119+8cL/HH+NLMawRsJx6CzPF1I3xC+jpM60bQHXGE8"] RevokedUserCertsFile string `json:"revoked_user_certs_file" mapstructure:"revoked_user_certs_file"` // Absolute path to the opkssh binary used for OpenPubkey SSH integration OPKSSHPath string `json:"opkssh_path" mapstructure:"opkssh_path"` // Expected SHA256 checksum of the opkssh binary. It is verified at application startup OPKSSHChecksum string `json:"opkssh_checksum" mapstructure:"opkssh_checksum"` // LoginBannerFile the contents of the specified file, if any, are sent to // the remote user before authentication is allowed. LoginBannerFile string `json:"login_banner_file" mapstructure:"login_banner_file"` // List of enabled SSH commands. // We support the following SSH commands: // - "scp". SCP is an experimental feature, we have our own SCP implementation since // we can't rely on scp system command to proper handle permissions, quota and // user's home dir restrictions. // The SCP protocol is quite simple but there is no official docs about it, // so we need more testing and feedbacks before enabling it by default. // We may not handle some borderline cases or have sneaky bugs. // Please do accurate tests yourself before enabling SCP and let us known // if something does not work as expected for your use cases. // SCP between two remote hosts is supported using the `-3` scp option. // - "md5sum", "sha1sum", "sha256sum", "sha384sum", "sha512sum". Useful to check message // digests for uploaded files. 
These commands are implemented inside SFTPGo so they // work even if the matching system commands are not available, for example on Windows. // - "cd", "pwd". Some mobile SFTP clients does not support the SFTP SSH_FXP_REALPATH and so // they use "cd" and "pwd" SSH commands to get the initial directory. // Currently `cd` do nothing and `pwd` always returns the "/" path. // // The following SSH commands are enabled by default: "md5sum", "sha1sum", "cd", "pwd". // "*" enables all supported SSH commands. EnabledSSHCommands []string `json:"enabled_ssh_commands" mapstructure:"enabled_ssh_commands"` // KeyboardInteractiveAuthentication specifies whether keyboard interactive authentication is allowed. // If no keyboard interactive hook or auth plugin is defined the default is to prompt for the user password and then the // one time authentication code, if defined. KeyboardInteractiveAuthentication bool `json:"keyboard_interactive_authentication" mapstructure:"keyboard_interactive_authentication"` // Absolute path to an external program or an HTTP URL to invoke for keyboard interactive authentication. // Leave empty to disable this authentication mode. KeyboardInteractiveHook string `json:"keyboard_interactive_auth_hook" mapstructure:"keyboard_interactive_auth_hook"` // PasswordAuthentication specifies whether password authentication is allowed. 
PasswordAuthentication bool `json:"password_authentication" mapstructure:"password_authentication"` certChecker *ssh.CertChecker parsedUserCAKeys []ssh.PublicKey executor commandExecutor } type authenticationError struct { err error loginMethod string username string } func (e *authenticationError) Error() string { return fmt.Sprintf("Authentication error: %v", e.err) } // Is reports if target matches func (e *authenticationError) Is(target error) bool { _, ok := target.(*authenticationError) return ok } // Unwrap returns the wrapped error func (e *authenticationError) Unwrap() error { return e.err } func (e *authenticationError) getLoginMethod() string { return e.loginMethod } func (e *authenticationError) getUsername() string { return e.username } func newAuthenticationError(err error, loginMethod, username string) *authenticationError { return &authenticationError{err: err, loginMethod: loginMethod, username: username} } // ShouldBind returns true if there is at least a valid binding func (c *Configuration) ShouldBind() bool { for _, binding := range c.Bindings { if binding.IsValid() { return true } } return false } func (c *Configuration) getServerConfig() *ssh.ServerConfig { serverConfig := &ssh.ServerConfig{ NoClientAuth: false, MaxAuthTries: c.MaxAuthTries, PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { sp, err := c.validatePublicKeyCredentials(conn, pubKey) if err != nil { return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err), dataprovider.SSHLoginMethodPublicKey, conn.User()) } return sp, nil }, VerifiedPublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey, permissions *ssh.Permissions, signatureAlgorithm string) (*ssh.Permissions, error) { if partialErr, ok := permissions.ExtraData[extraDataPartialSuccessErrKey]; ok { logger.Info(logSender, hex.EncodeToString(conn.SessionID()), "user %q authenticated with partial success, signature algorithm %q", 
conn.User(), signatureAlgorithm) return nil, partialErr.(error) } method := dataprovider.SSHLoginMethodPublicKey user := permissions.ExtraData[extraDataUserKey].(dataprovider.User) keyID := permissions.ExtraData[extraDataKeyIDKey].(string) sshPerm, err := loginUser(&user, method, fmt.Sprintf("%s (%s)", keyID, signatureAlgorithm), conn) if err == nil { // if we have a SSH user cert we need to merge certificate permissions with our ones // we only set Extensions, so CriticalOptions are always the ones from the certificate sshPerm.CriticalOptions = permissions.CriticalOptions if permissions.Extensions != nil { if sshPerm.Extensions == nil { sshPerm.Extensions = make(map[string]string) } maps.Copy(sshPerm.Extensions, permissions.Extensions) } if sshPerm.ExtraData == nil { sshPerm.ExtraData = make(map[any]any) } } user.Username = conn.User() ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) updateLoginMetrics(&user, ipAddr, method, err) return sshPerm, err }, ServerVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)), } if c.PasswordAuthentication { serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { return c.validatePasswordCredentials(conn, password, dataprovider.LoginMethodPassword) } serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.LoginMethodPassword) } serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) return serverConfig } func (c *Configuration) updateSupportedAuthentications() { serviceStatus.Authentications = util.RemoveDuplicates(serviceStatus.Authentications, false) if slices.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) && slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) { serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndPassword) } if 
slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) && slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) { serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndKeyboardInt) } } func (c *Configuration) loadFromProvider() error { configs, err := dataprovider.GetConfigs() if err != nil { return fmt.Errorf("unable to load config from provider: %w", err) } configs.SetNilsToEmpty() if len(configs.SFTPD.HostKeyAlgos) > 0 { if len(c.HostKeyAlgorithms) == 0 { c.HostKeyAlgorithms = preferredHostKeyAlgos } c.HostKeyAlgorithms = append(c.HostKeyAlgorithms, configs.SFTPD.HostKeyAlgos...) } if len(configs.SFTPD.PublicKeyAlgos) > 0 { if len(c.PublicKeyAlgorithms) == 0 { c.PublicKeyAlgorithms = preferredPublicKeyAlgos } c.PublicKeyAlgorithms = append(c.PublicKeyAlgorithms, configs.SFTPD.PublicKeyAlgos...) } if len(configs.SFTPD.KexAlgorithms) > 0 { if len(c.KexAlgorithms) == 0 { c.KexAlgorithms = preferredKexAlgos } c.KexAlgorithms = append(c.KexAlgorithms, configs.SFTPD.KexAlgorithms...) } if len(configs.SFTPD.Ciphers) > 0 { if len(c.Ciphers) == 0 { c.Ciphers = preferredCiphers } c.Ciphers = append(c.Ciphers, configs.SFTPD.Ciphers...) } if len(configs.SFTPD.MACs) > 0 { if len(c.MACs) == 0 { c.MACs = preferredMACs } c.MACs = append(c.MACs, configs.SFTPD.MACs...) } return nil } // Initialize the SFTP server and add a persistent listener to handle inbound SFTP connections. func (c *Configuration) Initialize(configDir string) error { c.executor = defaultExecutor{} if err := c.loadFromProvider(); err != nil { return fmt.Errorf("unable to load configs from provider: %w", err) } serviceStatus = ServiceStatus{} serverConfig := c.getServerConfig() if !c.ShouldBind() { return common.ErrNoBinding } sftp.SetSFTPExtensions(sftpExtensions...) 
//nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error sftp.MaxFilelist = 250 if err := c.configureSecurityOptions(serverConfig); err != nil { return err } if err := c.checkAndLoadHostKeys(configDir, serverConfig); err != nil { serviceStatus.HostKeys = nil return err } if err := c.initializeCertChecker(configDir); err != nil { return err } if err := c.initializeOPKSSH(); err != nil { return err } c.configureKeyboardInteractiveAuth(serverConfig) c.configureLoginBanner(serverConfig, configDir) c.checkSSHCommands() exitChannel := make(chan error, 1) serviceStatus.Bindings = nil for _, binding := range c.Bindings { if !binding.IsValid() { continue } serviceStatus.Bindings = append(serviceStatus.Bindings, binding) go func(binding Binding) { addr := binding.GetAddress() util.CheckTCP4Port(binding.Port) listener, err := net.Listen("tcp", addr) if err != nil { logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err) exitChannel <- err return } if binding.ApplyProxyConfig && common.Config.ProxyProtocol > 0 { proxyListener, err := common.Config.GetProxyListener(listener) if err != nil { logger.Warn(logSender, "", "error enabling proxy listener: %v", err) exitChannel <- err return } listener = proxyListener } exitChannel <- c.serve(listener, serverConfig) }(binding) } serviceStatus.IsActive = true serviceStatus.SSHCommands = c.EnabledSSHCommands c.updateSupportedAuthentications() return <-exitChannel } func (c *Configuration) serve(listener net.Listener, serverConfig *ssh.ServerConfig) error { logger.Info(logSender, "", "server listener registered, address: %s", listener.Addr().String()) var tempDelay time.Duration // how long to sleep on accept failure for { conn, err := listener.Accept() if err != nil { // see https://github.com/golang/go/blob/4aa1efed4853ea067d665a952eee77c52faac774/src/net/http/server.go#L3046 if ne, ok := err.(net.Error); ok && ne.Temporary() { //nolint:staticcheck if tempDelay == 0 { tempDelay = 5 * 
time.Millisecond } else { tempDelay *= 2 } if maxDelay := 1 * time.Second; tempDelay > maxDelay { tempDelay = maxDelay } logger.Warn(logSender, "", "accept error: %v; retrying in %v", err, tempDelay) time.Sleep(tempDelay) continue } logger.Warn(logSender, "", "unrecoverable accept error: %v", err) return err } tempDelay = 0 go c.AcceptInboundConnection(conn, serverConfig) } } func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error { if len(c.HostKeyAlgorithms) == 0 { c.HostKeyAlgorithms = preferredHostKeyAlgos } else { c.HostKeyAlgorithms = util.RemoveDuplicates(c.HostKeyAlgorithms, true) } for _, hostKeyAlgo := range c.HostKeyAlgorithms { if !slices.Contains(supportedHostKeyAlgos, hostKeyAlgo) { return fmt.Errorf("unsupported host key algorithm %q", hostKeyAlgo) } } if len(c.PublicKeyAlgorithms) > 0 { c.PublicKeyAlgorithms = util.RemoveDuplicates(c.PublicKeyAlgorithms, true) for _, algo := range c.PublicKeyAlgorithms { if !slices.Contains(supportedPublicKeyAlgos, algo) { return fmt.Errorf("unsupported public key authentication algorithm %q", algo) } } } else { c.PublicKeyAlgorithms = preferredPublicKeyAlgos } serverConfig.PublicKeyAuthAlgorithms = c.PublicKeyAlgorithms serviceStatus.PublicKeyAlgorithms = c.PublicKeyAlgorithms return nil } func (c *Configuration) checkKeyExchangeAlgorithms() { var kexs []string for _, k := range c.KexAlgorithms { if k == "diffie-hellman-group18-sha512" { logger.Warn(logSender, "", "KEX %q is not supported and will be ignored", k) continue } kexs = append(kexs, k) if strings.TrimSpace(k) == keyExchangeCurve25519SHA256LibSSH { kexs = append(kexs, ssh.KeyExchangeCurve25519) } if strings.TrimSpace(k) == ssh.KeyExchangeCurve25519 { kexs = append(kexs, keyExchangeCurve25519SHA256LibSSH) } } c.KexAlgorithms = util.RemoveDuplicates(kexs, true) } func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) error { if err := c.configureKeyAlgos(serverConfig); err != nil { return err } if 
len(c.KexAlgorithms) > 0 { c.checkKeyExchangeAlgorithms() for _, kex := range c.KexAlgorithms { if kex == keyExchangeCurve25519SHA256LibSSH { continue } if !slices.Contains(supportedKexAlgos, kex) { return fmt.Errorf("unsupported key-exchange algorithm %q", kex) } } } else { c.KexAlgorithms = preferredKexAlgos c.checkKeyExchangeAlgorithms() } serverConfig.KeyExchanges = c.KexAlgorithms serviceStatus.KexAlgorithms = c.KexAlgorithms if len(c.Ciphers) > 0 { c.Ciphers = util.RemoveDuplicates(c.Ciphers, true) for _, cipher := range c.Ciphers { if slices.Contains([]string{"aes192-cbc", "aes256-cbc"}, cipher) { continue } if !slices.Contains(supportedCiphers, cipher) { return fmt.Errorf("unsupported cipher %q", cipher) } } } else { c.Ciphers = preferredCiphers } serverConfig.Ciphers = c.Ciphers serviceStatus.Ciphers = c.Ciphers if len(c.MACs) > 0 { c.MACs = util.RemoveDuplicates(c.MACs, true) for _, mac := range c.MACs { if !slices.Contains(supportedMACs, mac) { return fmt.Errorf("unsupported MAC algorithm %q", mac) } } } else { c.MACs = preferredMACs } serverConfig.MACs = c.MACs serviceStatus.MACs = c.MACs return nil } func (c *Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, configDir string) { if c.LoginBannerFile != "" { bannerFilePath := c.LoginBannerFile if !filepath.IsAbs(bannerFilePath) { bannerFilePath = filepath.Join(configDir, bannerFilePath) } bannerContent, err := os.ReadFile(bannerFilePath) if err == nil { banner := util.BytesToString(bannerContent) serverConfig.BannerCallback = func(_ ssh.ConnMetadata) string { return banner } } else { logger.WarnToConsole("unable to read SFTPD login banner file: %v", err) logger.Warn(logSender, "", "unable to read login banner file: %v", err) } } } func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.ServerConfig) { if !c.KeyboardInteractiveAuthentication { return } if c.KeyboardInteractiveHook != "" { if !strings.HasPrefix(c.KeyboardInteractiveHook, "http") { if 
!filepath.IsAbs(c.KeyboardInteractiveHook) { c.KeyboardInteractiveAuthentication = false logger.WarnToConsole("invalid keyboard interactive authentication program: %q must be an absolute path", c.KeyboardInteractiveHook) logger.Warn(logSender, "", "invalid keyboard interactive authentication program: %q must be an absolute path", c.KeyboardInteractiveHook) return } _, err := os.Stat(c.KeyboardInteractiveHook) if err != nil { c.KeyboardInteractiveAuthentication = false logger.WarnToConsole("invalid keyboard interactive authentication program:: %v", err) logger.Warn(logSender, "", "invalid keyboard interactive authentication program:: %v", err) return } } } serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyboardInteractive, false) } serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) } // AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not. 
func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { //nolint:gocyclo defer func() { if r := recover(); r != nil { logger.Error(logSender, "", "panic in AcceptInboundConnection: %q stack trace: %v", r, string(debug.Stack())) } }() ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) common.Connections.AddClientConnection(ipAddr) defer common.Connections.RemoveClientConnection(ipAddr) if !canAcceptConnection(ipAddr) { conn.Close() return } // Before beginning a handshake must be performed on the incoming net.Conn // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck sconn, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { logger.Debug(logSender, "", "failed to accept an incoming connection from ip %q: %v", ipAddr, err) checkAuthError(ipAddr, err) return } // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on conn.SetDeadline(time.Time{}) //nolint:errcheck go ssh.DiscardRequests(reqs) defer sconn.Close() user := sconn.Permissions.ExtraData[extraDataUserKey].(dataprovider.User) loginType := sconn.Permissions.ExtraData[extraDataLoginMethodKey].(string) connectionID := hex.EncodeToString(sconn.SessionID()) defer user.CloseFs() //nolint:errcheck if err = user.CheckFsRoot(connectionID); err != nil { logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err) go discardAllChannels(chans, "invalid root fs", connectionID) return } logger.LoginLog(user.Username, ipAddr, loginType, common.ProtocolSSH, connectionID, util.BytesToString(sconn.ClientVersion()), true, fmt.Sprintf("negotiated algorithms: %+v", sconn.Conn.(ssh.AlgorithmsConnMetadata).Algorithms())) dataprovider.UpdateLastLogin(&user) sshConnection := common.NewSSHConnection(connectionID, sconn) common.Connections.AddSSHConnection(sshConnection) defer 
common.Connections.RemoveSSHConnection(connectionID) channelCounter := int64(0) for newChannel := range chans { // If its not a session channel we just move on because its not something we // know how to handle at this point. if newChannel.ChannelType() != "session" { logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v", newChannel.ChannelType()) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck continue } channel, requests, err := newChannel.Accept() if err != nil { logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err) continue } channelCounter++ // Channels have a type that is dependent on the protocol. For SFTP this is "subsystem" // with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc) go func(in <-chan *ssh.Request, counter int64) { for req := range in { ok := false connID := fmt.Sprintf("%s_%d", connectionID, counter) switch req.Type { case "subsystem": if bytes.Equal(req.Payload[4:], []byte("sftp")) { ok = true sshConnection.UpdateLastActivity() connection := &Connection{ BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, conn.LocalAddr().String(), conn.RemoteAddr().String(), user), ClientVersion: util.BytesToString(sconn.ClientVersion()), RemoteAddr: conn.RemoteAddr(), LocalAddr: conn.LocalAddr(), channel: channel, } go c.handleSftpConnection(channel, connection) } case "exec": // protocol will be set later inside processSSHCommand it could be SSH or SCP connection := Connection{ BaseConnection: common.NewBaseConnection(connID, "sshd_exec", conn.LocalAddr().String(), conn.RemoteAddr().String(), user), ClientVersion: util.BytesToString(sconn.ClientVersion()), RemoteAddr: conn.RemoteAddr(), LocalAddr: conn.LocalAddr(), channel: channel, } ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) if ok { sshConnection.UpdateLastActivity() } } if req.WantReply 
{ req.Reply(ok, nil) //nolint:errcheck } } }(requests, channelCounter) } } func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) { defer func() { if r := recover(); r != nil { logger.Error(logSender, "", "panic in handleSftpConnection: %q stack trace: %v", r, string(debug.Stack())) } }() if err := common.Connections.Add(connection); err != nil { defer connection.CloseFS() //nolint:errcheck errClose := connection.Disconnect() logger.Info(logSender, "", "unable to add connection: %v, close err: %v", err, errClose) return } defer common.Connections.Remove(connection.GetID()) // Create the server instance for the channel using the handler we created above. server := sftp.NewRequestServer(channel, c.createHandlers(connection), sftp.WithStartDirectory(connection.User.Filters.StartDirectory)) defer server.Close() if err := server.Serve(); errors.Is(err, io.EOF) { exitStatus := sshSubsystemExitStatus{Status: uint32(0)} _, err = channel.SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) connection.Log(logger.LevelInfo, "connection closed, sent exit status %+v error: %v", exitStatus, err) } else if err != nil { connection.Log(logger.LevelError, "connection closed with error: %v", err) } } func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers { return sftp.Handlers{ FileGet: connection, FilePut: connection, FileCmd: connection, FileList: connection, } } func canAcceptConnection(ip string) bool { if common.IsBanned(ip, common.ProtocolSSH) { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %q is banned", ip) return false } if err := common.Connections.IsNewConnectionAllowed(ip, common.ProtocolSSH); err != nil { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err) return false } _, err := common.LimitRate(common.ProtocolSSH, ip) if err != nil { return false } if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err 
!= nil { return false } return true } func discardAllChannels(in <-chan ssh.NewChannel, message, connectionID string) { for req := range in { err := req.Reject(ssh.ConnectionFailed, message) logger.Debug(logSender, connectionID, "discarded channel request, message %q err: %v", message, err) } } func checkAuthError(ip string, err error) { var authErrors *ssh.ServerAuthError if errors.As(err, &authErrors) { // check public key auth errors here for _, err := range authErrors.Errors { var sftpAuthErr *authenticationError if errors.As(err, &sftpAuthErr) { if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey { event := common.HostEventLoginFailed logEv := notifier.LogEventTypeLoginFailed if errors.Is(err, util.ErrNotFound) { event = common.HostEventUserNotFound logEv = notifier.LogEventTypeLoginNoUser } common.AddDefenderEvent(ip, common.ProtocolSSH, event) plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, sftpAuthErr.getUsername(), ip, "", err) return } } } } else { logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTried, common.ProtocolSSH, err.Error()) metric.AddNoAuthTried() common.AddDefenderEvent(ip, common.ProtocolSSH, common.HostEventNoLoginTried) dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTried, ip, common.ProtocolSSH, err) logEv := notifier.LogEventTypeNoLoginTried var negotiationError *ssh.AlgorithmNegotiationError if errors.As(err, &negotiationError) { logEv = notifier.LogEventTypeNotNegotiated } plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, "", ip, "", err) } } func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) { connectionID := "" if conn != nil { connectionID = hex.EncodeToString(conn.SessionID()) } if !filepath.IsAbs(user.HomeDir) { logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. 
Home dir must be an absolute path, login not allowed", user.Username, user.HomeDir) return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir) } if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) { logger.Info(logSender, connectionID, "cannot login user %q, protocol SSH is not allowed", user.Username) return nil, fmt.Errorf("protocol SSH is not allowed for user %q", user.Username) } if user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) if activeSessions >= user.MaxSessions { logger.Info(logSender, "", "authentication refused for user: %q, too many open sessions: %v/%v", user.Username, activeSessions, user.MaxSessions) return nil, fmt.Errorf("too many open sessions: %v", activeSessions) } } if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolSSH) { logger.Info(logSender, connectionID, "cannot login user %q, login method %q is not allowed", user.Username, loginMethod) return nil, fmt.Errorf("login method %q is not allowed for user %q", loginMethod, user.Username) } if user.MustSetSecondFactorForProtocol(common.ProtocolSSH) { logger.Info(logSender, connectionID, "cannot login user %q, second factor authentication is not set", user.Username) return nil, fmt.Errorf("second factor authentication is not set for user %q", user.Username) } remoteAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) if !user.IsLoginFromAddrAllowed(remoteAddr) { logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", user.Username, remoteAddr) return nil, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, remoteAddr) } if publicKey != "" { loginMethod = fmt.Sprintf("%v: %v", loginMethod, publicKey) } p := &ssh.Permissions{} p.ExtraData = make(map[any]any) p.ExtraData[extraDataUserKey] = *user p.ExtraData[extraDataLoginMethodKey] = loginMethod return p, nil } func (c *Configuration) checkSSHCommands() { if 
slices.Contains(c.EnabledSSHCommands, "*") { c.EnabledSSHCommands = GetSupportedSSHCommands() return } sshCommands := []string{} for _, command := range c.EnabledSSHCommands { command = strings.TrimSpace(command) if slices.Contains(supportedSSHCommands, command) { sshCommands = append(sshCommands, command) } else { logger.Warn(logSender, "", "unsupported ssh command: %q ignored", command) logger.WarnToConsole("unsupported ssh command: %q ignored", command) } } c.EnabledSSHCommands = sshCommands logger.Debug(logSender, "", "enabled SSH commands %v", c.EnabledSSHCommands) } func (c *Configuration) generateDefaultHostKeys(configDir string) error { var err error defaultHostKeys := []string{defaultPrivateRSAKeyName, defaultPrivateECDSAKeyName, defaultPrivateEd25519KeyName} for _, k := range defaultHostKeys { autoFile := filepath.Join(configDir, k) if _, err = os.Stat(autoFile); errors.Is(err, fs.ErrNotExist) { logger.Info(logSender, "", "No host keys configured and %q does not exist; try to create a new host key", autoFile) logger.InfoToConsole("No host keys configured and %q does not exist; try to create a new host key", autoFile) switch k { case defaultPrivateRSAKeyName: err = util.GenerateRSAKeys(autoFile) case defaultPrivateECDSAKeyName: err = util.GenerateECDSAKeys(autoFile) default: err = util.GenerateEd25519Keys(autoFile) } if err != nil { logger.Warn(logSender, "", "error creating host key %q: %v", autoFile, err) logger.WarnToConsole("error creating host key %q: %v", autoFile, err) return err } } c.HostKeys = append(c.HostKeys, k) } return err } func (c *Configuration) checkHostKeyAutoGeneration(configDir string) error { for _, k := range c.HostKeys { k = strings.TrimSpace(k) if filepath.IsAbs(k) { if _, err := os.Stat(k); errors.Is(err, fs.ErrNotExist) { keyName := filepath.Base(k) switch keyName { case defaultPrivateRSAKeyName: logger.Info(logSender, "", "try to create non-existent host key %q", k) logger.InfoToConsole("try to create non-existent host key %q", 
k) err = util.GenerateRSAKeys(k) if err != nil { logger.Warn(logSender, "", "error creating host key %q: %v", k, err) logger.WarnToConsole("error creating host key %q: %v", k, err) return err } case defaultPrivateECDSAKeyName: logger.Info(logSender, "", "try to create non-existent host key %q", k) logger.InfoToConsole("try to create non-existent host key %q", k) err = util.GenerateECDSAKeys(k) if err != nil { logger.Warn(logSender, "", "error creating host key %q: %v", k, err) logger.WarnToConsole("error creating host key %q: %v", k, err) return err } case defaultPrivateEd25519KeyName: logger.Info(logSender, "", "try to create non-existent host key %q", k) logger.InfoToConsole("try to create non-existent host key %q", k) err = util.GenerateEd25519Keys(k) if err != nil { logger.Warn(logSender, "", "error creating host key %q: %v", k, err) logger.WarnToConsole("error creating host key %q: %v", k, err) return err } default: logger.Warn(logSender, "", "non-existent host key %q will not be created", k) logger.WarnToConsole("non-existent host key %q will not be created", k) } } } } if len(c.HostKeys) == 0 { if err := c.generateDefaultHostKeys(configDir); err != nil { return err } } return nil } func (c *Configuration) getHostKeyAlgorithms(keyFormat string) []string { var algos []string for _, algo := range algorithmsForKeyFormat(keyFormat) { if slices.Contains(c.HostKeyAlgorithms, algo) { algos = append(algos, algo) } } return algos } // If no host keys are defined we try to use or generate the default ones. 
func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh.ServerConfig) error { if err := c.checkHostKeyAutoGeneration(configDir); err != nil { return err } hostCertificates, err := c.loadHostCertificates(configDir) if err != nil { return err } serviceStatus.HostKeys = nil for _, hostKey := range c.HostKeys { hostKey = strings.TrimSpace(hostKey) if !util.IsFileInputValid(hostKey) { logger.Warn(logSender, "", "unable to load invalid host key %q", hostKey) logger.WarnToConsole("unable to load invalid host key %q", hostKey) continue } if !filepath.IsAbs(hostKey) { hostKey = filepath.Join(configDir, hostKey) } logger.Info(logSender, "", "Loading private host key %q", hostKey) privateBytes, err := os.ReadFile(hostKey) if err != nil { return err } private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { return err } k := HostKey{ Path: hostKey, Fingerprint: ssh.FingerprintSHA256(private.PublicKey()), Algorithms: c.getHostKeyAlgorithms(private.PublicKey().Type()), } mas, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), k.Algorithms) if err != nil { return fmt.Errorf("could not create signer for key %q with algorithms %+v: %w", k.Path, k.Algorithms, err) } serviceStatus.HostKeys = append(serviceStatus.HostKeys, k) logger.Info(logSender, "", "Host key %q loaded, type %q, fingerprint %q, algorithms %+v", hostKey, private.PublicKey().Type(), k.Fingerprint, k.Algorithms) // Add private key to the server configuration. 
serverConfig.AddHostKey(mas) for _, cert := range hostCertificates { signer, err := ssh.NewCertSigner(cert.Certificate, mas) if err == nil { var algos []string for _, algo := range algorithmsForKeyFormat(signer.PublicKey().Type()) { if underlyingAlgo, ok := certKeyAlgoNames[algo]; ok { if slices.Contains(mas.Algorithms(), underlyingAlgo) { algos = append(algos, algo) } } } serviceStatus.HostKeys = append(serviceStatus.HostKeys, HostKey{ Path: cert.Path, Fingerprint: ssh.FingerprintSHA256(signer.PublicKey()), Algorithms: algos, }) serverConfig.AddHostKey(signer) logger.Info(logSender, "", "Host certificate loaded for host key %q, fingerprint %q, algorithms %+v", hostKey, ssh.FingerprintSHA256(signer.PublicKey()), algos) } } } var fp []string for idx := range serviceStatus.HostKeys { h := &serviceStatus.HostKeys[idx] fp = append(fp, h.Fingerprint) } vfs.SetSFTPFingerprints(fp) return nil } func (c *Configuration) loadHostCertificates(configDir string) ([]hostCertificate, error) { var certs []hostCertificate for _, certPath := range c.HostCertificates { certPath = strings.TrimSpace(certPath) if !util.IsFileInputValid(certPath) { logger.Warn(logSender, "", "unable to load invalid host certificate %q", certPath) logger.WarnToConsole("unable to load invalid host certificate %q", certPath) continue } if !filepath.IsAbs(certPath) { certPath = filepath.Join(configDir, certPath) } certBytes, err := os.ReadFile(certPath) if err != nil { return certs, fmt.Errorf("unable to load host certificate %q: %w", certPath, err) } parsed, _, _, _, err := ssh.ParseAuthorizedKey(certBytes) if err != nil { return nil, fmt.Errorf("unable to parse host certificate %q: %w", certPath, err) } cert, ok := parsed.(*ssh.Certificate) if !ok { return nil, fmt.Errorf("the file %q is not an SSH certificate", certPath) } if cert.CertType != ssh.HostCert { return nil, fmt.Errorf("the file %q is not an host certificate", certPath) } certs = append(certs, hostCertificate{ Path: certPath, Certificate: cert, 
}) } return certs, nil } func (c *Configuration) initializeOPKSSH() error { if c.OPKSSHPath != "" { if len(c.parsedUserCAKeys) > 0 { return errors.New("opkssh and certificate authorities are mutually exclusive") } if !util.IsFileInputValid(c.OPKSSHPath) || !filepath.IsAbs(c.OPKSSHPath) { return fmt.Errorf("opkssh path %q is not valid, it must be an absolute path", c.OPKSSHPath) } if c.OPKSSHChecksum == "" { if _, err := os.Stat(c.OPKSSHPath); err != nil { return fmt.Errorf("error validating opkssh path %q: %w", c.OPKSSHPath, err) } } else { if err := util.VerifyFileChecksum(c.OPKSSHPath, sha256.New(), c.OPKSSHChecksum, 100*1024*1024); err != nil { return fmt.Errorf("error validating opkssh checksum: %w", err) } } } return nil } func (c *Configuration) verifyWithOPKSSH(username string, cert *ssh.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() args := []string{"verify", username, util.BytesToString(ssh.MarshalAuthorizedKey(cert)), cert.Type()} out, err := c.executor.CombinedOutput(ctx, c.OPKSSHPath, args...) 
if err != nil { logger.Debug(logSender, "", "unable to execute opk verifier: %s", string(out)) return fmt.Errorf("unable to execute opk verifier: %w", err) } pubKey, _, _, _, err := ssh.ParseAuthorizedKey(out) //nolint:dogsled if err != nil { logger.Debug(logSender, "", "unable to validate the opk verifier output: %s", string(out)) return fmt.Errorf("unable to validate the opk verifier output: %w", err) } if !bytes.Equal(pubKey.Marshal(), cert.SignatureKey.Marshal()) { return errors.New("unable to validate opk result") } return nil } func (c *Configuration) initializeCertChecker(configDir string) error { for _, keyPath := range c.TrustedUserCAKeys { keyPath = strings.TrimSpace(keyPath) if !util.IsFileInputValid(keyPath) { logger.Warn(logSender, "", "unable to load invalid trusted user CA key %q", keyPath) logger.WarnToConsole("unable to load invalid trusted user CA key %q", keyPath) continue } if !filepath.IsAbs(keyPath) { keyPath = filepath.Join(configDir, keyPath) } keyBytes, err := os.ReadFile(keyPath) if err != nil { logger.Warn(logSender, "", "error loading trusted user CA key %q: %v", keyPath, err) logger.WarnToConsole("error loading trusted user CA key %q: %v", keyPath, err) return err } parsedKey, _, _, _, err := ssh.ParseAuthorizedKey(keyBytes) if err != nil { logger.Warn(logSender, "", "error parsing trusted user CA key %q: %v", keyPath, err) logger.WarnToConsole("error parsing trusted user CA key %q: %v", keyPath, err) return err } c.parsedUserCAKeys = append(c.parsedUserCAKeys, parsedKey) } c.certChecker = &ssh.CertChecker{ SupportedCriticalOptions: []string{ sourceAddressCriticalOption, }, IsUserAuthority: func(k ssh.PublicKey) bool { for _, key := range c.parsedUserCAKeys { if bytes.Equal(k.Marshal(), key.Marshal()) { return true } } return false }, } if c.RevokedUserCertsFile != "" { if !util.IsFileInputValid(c.RevokedUserCertsFile) { return fmt.Errorf("invalid revoked user certificate: %q", c.RevokedUserCertsFile) } if 
!filepath.IsAbs(c.RevokedUserCertsFile) { c.RevokedUserCertsFile = filepath.Join(configDir, c.RevokedUserCertsFile) } } revokedCertManager.filePath = c.RevokedUserCertsFile return revokedCertManager.load() } func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error { err := &ssh.PartialSuccessError{} if c.PasswordAuthentication && slices.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) { err.Next.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword) } } if c.KeyboardInteractiveAuthentication && slices.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) { err.Next.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt, true) } } return err } func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { var user dataprovider.User var certPerm *ssh.Permissions method := dataprovider.SSHLoginMethodPublicKey ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) cert, ok := pubKey.(*ssh.Certificate) var certFingerprint string if ok { certFingerprint = ssh.FingerprintSHA256(cert.Key) if c.OPKSSHPath != "" { if err := c.verifyWithOPKSSH(conn.User(), cert); err != nil { err := fmt.Errorf("ssh: verification with OPK failed: %v", err) user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } } else { if cert.CertType != ssh.UserCert { err := fmt.Errorf("ssh: cert has type %d", cert.CertType) user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } if !c.certChecker.IsUserAuthority(cert.SignatureKey) { err := errors.New("ssh: certificate signed by unrecognized authority") 
user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } if len(cert.ValidPrincipals) == 0 { err := fmt.Errorf("ssh: certificate %s has no valid principals, user: \"%s\"", certFingerprint, conn.User()) user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } if revokedCertManager.isRevoked(certFingerprint) { err := fmt.Errorf("ssh: certificate %s is revoked", certFingerprint) user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } if err := c.certChecker.CheckCert(conn.User(), cert); err != nil { user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } } certPerm = &cert.Permissions } user, keyID, err := dataprovider.CheckUserAndPubKey(conn.User(), pubKey.Marshal(), ipAddr, common.ProtocolSSH, ok) if err != nil { user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) return nil, err } if ok { keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint, cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey)) } if certPerm == nil { certPerm = &ssh.Permissions{} } certPerm.ExtraData = make(map[any]any) certPerm.ExtraData[extraDataKeyIDKey] = keyID certPerm.ExtraData[extraDataUserKey] = user if user.IsPartialAuth() { certPerm.ExtraData[extraDataPartialSuccessErrKey] = c.getPartialSuccessError(user.GetNextAuthMethods()) } return certPerm, nil } func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte, method string) (*ssh.Permissions, error) { var err error var user dataprovider.User var sshPerm *ssh.Permissions ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) if user, err = dataprovider.CheckUserAndPass(conn.User(), util.BytesToString(pass), ipAddr, common.ProtocolSSH); err == nil { sshPerm, err = loginUser(&user, method, "", conn) } user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) if err != nil { 
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), method, conn.User()) } return sshPerm, nil } func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge, method string, isPartialAuth bool, ) (*ssh.Permissions, error) { var err error var user dataprovider.User var sshPerm *ssh.Permissions ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client, ipAddr, common.ProtocolSSH, isPartialAuth); err == nil { sshPerm, err = loginUser(&user, method, "", conn) } user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) if err != nil { return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), method, conn.User()) } return sshPerm, nil } func updateLoginMetrics(user *dataprovider.User, ip, method string, err error) { metric.AddLoginAttempt(method) if err == nil { plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolSSH, user.Username, ip, "", err) common.DelayLogin(nil) } else { logger.ConnectionFailedLog(user.Username, ip, method, common.ProtocolSSH, err.Error()) if method != dataprovider.SSHLoginMethodPublicKey { // some clients try all available public keys for a user, we // record failed login key auth only once for session if the // authentication fails in checkAuthError event := common.HostEventLoginFailed logEv := notifier.LogEventTypeLoginFailed if errors.Is(err, util.ErrNotFound) { event = common.HostEventUserNotFound logEv = notifier.LogEventTypeLoginNoUser } common.AddDefenderEvent(ip, common.ProtocolSSH, event) plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, user.Username, ip, "", err) if method != dataprovider.SSHLoginMethodPublicKey { common.DelayLogin(err) } } } metric.AddLoginResult(method, err) dataprovider.ExecutePostLoginHook(user, 
method, ip, common.ProtocolSSH, err) } type revokedCertificates struct { filePath string mu sync.RWMutex certs map[string]bool } func (r *revokedCertificates) load() error { if r.filePath == "" { return nil } logger.Debug(logSender, "", "loading revoked user certificate file %q", r.filePath) info, err := os.Stat(r.filePath) if err != nil { return fmt.Errorf("unable to load revoked user certificate file %q: %w", r.filePath, err) } maxSize := int64(1048576 * 5) // 5MB if info.Size() > maxSize { return fmt.Errorf("unable to load revoked user certificate file %q size too big: %v/%v bytes", r.filePath, info.Size(), maxSize) } content, err := os.ReadFile(r.filePath) if err != nil { return fmt.Errorf("unable to read revoked user certificate file %q: %w", r.filePath, err) } var certs []string err = json.Unmarshal(content, &certs) if err != nil { return fmt.Errorf("unable to parse revoked user certificate file %q: %w", r.filePath, err) } r.mu.Lock() defer r.mu.Unlock() r.certs = map[string]bool{} for _, fp := range certs { r.certs[fp] = true } logger.Debug(logSender, "", "revoked user certificate file %q loaded, entries: %v", r.filePath, len(r.certs)) return nil } func (r *revokedCertificates) isRevoked(fp string) bool { r.mu.RLock() defer r.mu.RUnlock() return r.certs[fp] } // Reload reloads the list of revoked user certificates func Reload() error { return revokedCertManager.load() } func algorithmsForKeyFormat(keyFormat string) []string { switch keyFormat { case ssh.KeyAlgoRSA: return []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512, ssh.KeyAlgoRSA} case ssh.CertAlgoRSAv01: return []string{ssh.CertAlgoRSASHA256v01, ssh.CertAlgoRSASHA512v01, ssh.CertAlgoRSAv01} default: return []string{keyFormat} } } ================================================ FILE: internal/sftpd/sftpd.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU 
Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// Package sftpd implements the SSH File Transfer Protocol as described in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02.
// It uses pkg/sftp library:
// https://github.com/pkg/sftp
package sftpd

import (
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)

const (
	// logSender is passed as the sender to this package's logger calls.
	logSender = "sftpd"
	// handshakeTimeout is presumably the maximum duration allowed for the SSH
	// handshake; it is not used in this file, so confirm at the call site.
	handshakeTimeout = 2 * time.Minute
)

var (
	// supportedSSHCommands lists every SSH command that can be enabled;
	// configured commands not in this list are ignored with a warning.
	supportedSSHCommands = []string{"scp", "md5sum", "sha1sum", "sha256sum", "sha384sum", "sha512sum", "cd", "pwd", "sftpgo-copy", "sftpgo-remove"}
	// defaultSSHCommands lists the SSH commands enabled by default.
	defaultSSHCommands = []string{"md5sum", "sha1sum", "sha256sum", "cd", "pwd", "scp"}
	// sshHashCommands lists the checksum commands among the supported ones.
	sshHashCommands = []string{"md5sum", "sha1sum", "sha256sum", "sha384sum", "sha512sum"}
	// serviceStatus is the status returned by GetStatus; its HostKeys field is
	// rebuilt every time the host keys are loaded.
	serviceStatus ServiceStatus
	// certKeyAlgoNames maps each certificate signature algorithm to the
	// algorithm of the underlying key. It is used to decide which certificate
	// algorithms a host certificate may advertise, given the algorithms
	// enabled for its host key.
	certKeyAlgoNames = map[string]string{
		ssh.CertAlgoRSAv01:         ssh.KeyAlgoRSA,
		ssh.CertAlgoRSASHA256v01:   ssh.KeyAlgoRSASHA256,
		ssh.CertAlgoRSASHA512v01:   ssh.KeyAlgoRSASHA512,
		ssh.InsecureCertAlgoDSAv01: ssh.InsecureKeyAlgoDSA, //nolint:staticcheck
		ssh.CertAlgoECDSA256v01:    ssh.KeyAlgoECDSA256,
		ssh.CertAlgoECDSA384v01:    ssh.KeyAlgoECDSA384,
		ssh.CertAlgoECDSA521v01:    ssh.KeyAlgoECDSA521,
		ssh.CertAlgoSKECDSA256v01:  ssh.KeyAlgoSKECDSA256,
		ssh.CertAlgoED25519v01:     ssh.KeyAlgoED25519,
		ssh.CertAlgoSKED25519v01:   ssh.KeyAlgoSKED25519,
	}
)

// sshSubsystemExitStatus is presumably the payload of an SSH "exit-status"
// channel request (RFC 4254, section 6.10). Confirm at the usage sites.
type sshSubsystemExitStatus struct {
	Status uint32
}

// sshSubsystemExecMsg is presumably the payload of an SSH "exec" channel
// request carrying the command line. Confirm at the usage sites.
type sshSubsystemExecMsg struct {
	Command string
}

// hostCertificate is a parsed SSH host certificate together with the path of
// the file it was loaded from.
type hostCertificate struct {
	Certificate *ssh.Certificate
	Path        string
}

// HostKey defines the details for a used host key
type HostKey struct {
	// Path is the file the host key, or host certificate, was loaded from.
	Path string `json:"path"`
	// Fingerprint is the SHA256 fingerprint of the public key.
	Fingerprint string `json:"fingerprint"`
	// Algorithms lists the signature algorithms enabled for this key.
	Algorithms []string `json:"algorithms"`
}

// GetAlgosAsString returns the host key algorithms as comma separated string
func (h *HostKey) GetAlgosAsString() string {
	return strings.Join(h.Algorithms, ", ")
}

// ServiceStatus defines the service status
type ServiceStatus struct {
	IsActive            bool      `json:"is_active"`
	Bindings            []Binding `json:"bindings"`
	SSHCommands         []string  `json:"ssh_commands"`
	HostKeys            []HostKey `json:"host_keys"`
	Authentications     []string  `json:"authentications"`
	MACs                []string  `json:"macs"`
	KexAlgorithms       []string  `json:"kex_algorithms"`
	Ciphers             []string  `json:"ciphers"`
	PublicKeyAlgorithms []string  `json:"public_key_algorithms"`
}

// GetSSHCommandsAsString returns enabled SSH commands as comma separated string
func (s *ServiceStatus) GetSSHCommandsAsString() string {
	return strings.Join(s.SSHCommands, ", ")
}

// GetSupportedAuthsAsString returns the supported authentications as comma separated string
func (s *ServiceStatus) GetSupportedAuthsAsString() string {
	return strings.Join(s.Authentications, ", ")
}

// GetMACsAsString returns the enabled MAC algorithms as comma separated string
func (s *ServiceStatus) GetMACsAsString() string {
	return strings.Join(s.MACs, ", ")
}

// GetKEXsAsString returns the enabled KEX algorithms as comma separated string
func (s *ServiceStatus) GetKEXsAsString() string {
	return strings.Join(s.KexAlgorithms, ", ")
}

// GetCiphersAsString returns the enabled ciphers as comma separated string
func (s *ServiceStatus) GetCiphersAsString() string {
	return strings.Join(s.Ciphers, ", ")
}

// GetPublicKeysAlgosAsString returns enabled public key authentication
// algorithms as comma separated string
func (s *ServiceStatus) GetPublicKeysAlgosAsString() string {
	return strings.Join(s.PublicKeyAlgorithms, ", ")
}

// GetStatus returns the server status.
// NOTE: the returned value is a shallow copy. Its slice fields share backing
// arrays with the package-level status, so callers must not modify them.
func GetStatus() ServiceStatus {
	return serviceStatus
}

// GetDefaultSSHCommands returns the SSH commands enabled as default.
// A copy is returned so callers cannot modify the package-level list.
func GetDefaultSSHCommands() []string {
	result := make([]string, len(defaultSSHCommands))
	copy(result, defaultSSHCommands)
	return result
}

// GetSupportedSSHCommands returns the supported SSH commands.
// A copy is returned so callers cannot modify the package-level list.
func GetSupportedSSHCommands() []string {
	result := make([]string, len(supportedSSHCommands))
	copy(result, supportedSSHCommands)
	return result
}

================================================
FILE: internal/sftpd/sftpd_test.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd_test import ( "bufio" "bytes" "context" "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/binary" "encoding/hex" "encoding/json" "errors" "fmt" "hash" "io" "io/fs" "math" "net" "net/http" "os" "os/exec" "path" "path/filepath" "runtime" "slices" "strconv" "strings" "sync" "sync/atomic" "testing" "time" _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v5/stdlib" _ "github.com/mattn/go-sqlite3" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/pkg/sftp" "github.com/rs/zerolog" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( logSender = "sftpdTesting" sftpServerAddr = "127.0.0.1:2022" sftpSrvAddr2222 = "127.0.0.1:2222" defaultUsername = "test_user_sftp" defaultPassword = "test_password" defaultSFTPUsername = "test_sftpfs_user" testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" testPubKey1 = "ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAABgQCd60+/j+y8f0tLftihWV1YN9RSahMI9btQMDIMqts/jeNbD8jgoogM3nhF7KxfcaMKURuD47KC4Ey6iAJUJ0sWkSNNxOcIYuvA+5MlspfZDsa8Ag76Fe1vyz72WeHMHMeh/hwFo2TeIeIXg480T1VI6mzfDrVp2GzUx0SS0dMsQBjftXkuVR8YOiOwMCAH2a//M1OrvV7d/NBk6kBN0WnuIBb2jKm15PAA7+jQQG7tzwk2HedNH3jeL5GH31xkSRwlBczRK0xsCQXehAlx6cT/e/s44iJcJTHfpPKoSk6UAhPJYe7Z1QnuoawY9P9jQaxpyeImBZxxUEowhjpj2avBxKdRGBVK8R7EL8tSOeLbhdyWe5Mwc1+foEbq9Zz5j5Kd+hn3Wm1UnsGCrXUUUoZp1jnlNl0NakCto+5KmqnT9cHxaY+ix2RLUWAZyVFlRq71OYux1UHJnEJPiEI1/tr4jFBSL46qhQZv/TfpkfVW8FLz0lErfqu0gQEZnNHr3Fc= nicola@p1" testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn NhAAAAAwEAAQAAAYEAtN449A/nY5O6cSH/9Doa8a3ISU0WZJaHydTaCLuO+dkqtNpnV5mq zFbKidXAI1eSwVctw9ReVOl1uK6aZF3lbXdOD8W9PXobR9KUUT2qBx5QC4ibfAqDKWymDA PG9ylzz64hsYBqJr7VNk9kTFEUsDmWzLabLoH42Elnp8mF/lTkWIcpVp0ly/etS08gttXo XenekJ1vRuxOYWDCEzGPU7kGc920TmM14k7IDdPoOh5+3sRUKedKeOUrVDH1f0n7QjHQsZ cbshp8tgqzf734zu8cTqNrr+6taptdEOOij1iUL/qYGfzny/hA48tO5+UFUih5W8ftp0+E NBIDkkGgk2MJ92I7QAXyMVsIABXco+mJT7pQi9tqlODGIQ3AOj0gcA3X/Ib8QX77Ih3TPi XEh77/P1XiYZOgpp2cRmNH8QbqaL9u898hDvJwIPJPuj2lIltTElH7hjBf5LQfCzrLV7BD 10rM7sl4jr+A2q8jl1Ikp+25kainBBZSbrDummT9AAAFgDU/VLk1P1S5AAAAB3NzaC1yc2 EAAAGBALTeOPQP52OTunEh//Q6GvGtyElNFmSWh8nU2gi7jvnZKrTaZ1eZqsxWyonVwCNX ksFXLcPUXlTpdbiummRd5W13Tg/FvT16G0fSlFE9qgceUAuIm3wKgylspgwDxvcpc8+uIb GAaia+1TZPZExRFLA5lsy2my6B+NhJZ6fJhf5U5FiHKVadJcv3rUtPILbV6F3p3pCdb0bs TmFgwhMxj1O5BnPdtE5jNeJOyA3T6Doeft7EVCnnSnjlK1Qx9X9J+0Ix0LGXG7IafLYKs3 +9+M7vHE6ja6/urWqbXRDjoo9YlC/6mBn858v4QOPLTuflBVIoeVvH7adPhDQSA5JBoJNj CfdiO0AF8jFbCAAV3KPpiU+6UIvbapTgxiENwDo9IHAN1/yG/EF++yId0z4lxIe+/z9V4m GToKadnEZjR/EG6mi/bvPfIQ7ycCDyT7o9pSJbUxJR+4YwX+S0Hws6y1ewQ9dKzO7JeI6/ gNqvI5dSJKftuZGopwQWUm6w7ppk/QAAAAMBAAEAAAGAHKnC+Nq0XtGAkIFE4N18e6SAwy 0WSWaZqmCzFQM0S2AhJnweOIG/0ZZHjsRzKKauOTmppQk40dgVsejpytIek9R+aH172gxJ 2n4Cx0UwduRU5x8FFQlNc/kl722B0JWfJuB/snOZXv6LJ4o5aObIkozt2w9tVFeAqjYn2S 
1UsNOfRHBXGsTYwpRDwFWP56nKo2d2wBBTHDhCy6fb2dLW1fvSi/YspueOGIlHpvlYKi2/ CWqvs9xVrwcScMtiDoQYq0khhO0efLCxvg/o+W9CLMVM2ms4G1zoSUQKN0oYWWQJyW4+VI YneWO8UpN0J3ElXKi7bhgAat7dBaM1g9IrAzk153DiEFZNsPxGOgL/+YdQN7zUBx/z7EkI jyv80RV7fpUXvcq2p+qNl6UVig3VSzRrnsaJkUWu/A0u59ha7ocv6NxDIXjxpIDJme16GF quiGVBQNnYJymS/vFEbGf6bgf7iRmMCRUMG4nqLA6fPYP9uAtch+CmDfVLZC/fIdC5AAAA wQCDissV4zH6bfqgxJSuYNk8Vbb+19cF3b7gH1rVlB3zxpCAgcRgMHC+dP1z2NRx7UW9MR nye6kjpkzZZ0OigLqo7TtEq8uTglD9o6W7mRXqhy5A/ySOmqPL3ernHHQhGuoNODYAHkOU u2Rh8HXi+VLwKZcLInPOYJvcuLG4DxN8WfeVvlMHwhAOaTNNOtL4XZDHQeIPc4qHmJymmv sV7GuyQ6yW5C10uoGdxRPd90Bh4z4h2bKfZFjvEBbSBVkqrlAAAADBAN/zNtNayd/dX7Cr Nb4sZuzCh+CW4BH8GOePZWNCATwBbNXBVb5cR+dmuTqYm+Ekz0VxVQRA1TvKncluJOQpoa Xj8r0xdIgqkehnfDPMKtYVor06B9Fl1jrXtXU0Vrr6QcBWruSVyK1ZxqcmcNK/+KolVepe A6vcl/iKaG4U7su166nxLST06M2EgcSVsFJHpKn5+WAXC+X0Gx8kNjWIIb3GpiChdc0xZD mq02xZthVJrTCVw/e7gfDoB2QRsNV8HwAAAMEAzsCghZVp+0YsYg9oOrw4tEqcbEXEMhwY 0jW8JNL8Spr1Ibp5Dw6bRSk5azARjmJtnMJhJ3oeHfF0eoISqcNuQXGndGQbVM9YzzAzc1 NbbCNsVroqKlChT5wyPNGS+phi2bPARBno7WSDvshTZ7dAVEP2c9MJW0XwoSevwKlhgSdt RLFFQ/5nclJSdzPBOmQouC0OBcMFSrYtMeknJ4VvueVvve5HcHFaEsaMc7ABAGaLYaBQOm iixITGvaNZh/tjAAAACW5pY29sYUBwMQE= -----END OPENSSH PRIVATE KEY-----` // password protected private key testPrivateKeyPwd = `-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABAvfwQQcs +PyMsCLTNFcKiQAAAAEAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q +8w23flfgskjIlKViEwMfjJR4mrbAAAAkHp5xgG8J1XW90M/fT59ZUQht8sZzzP17rEKlX waYKvLzDxkPK6LFIYs55W1EX1eVt/2Maq+zQ7k2SOUmhPNknsUOlPV2gytX3uIYvXF7u2F FTBIJuzZ+UQ14wFbraunliE9yye9DajVG1kz2cz2wVgXUbee+gp5NyFVvln+TcTxXwMsWD qwlk5iw/jQekxThg== -----END OPENSSH PRIVATE KEY----- ` testPubKeyPwd = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q+8w23flfgskjIlKViEwMfjJR4mrb" privateKeyPwd = "password" // test CA user key. 
// % ssh-keygen -f ca_user_key testCAUserKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDF5fcwZHiyixmnE6IlOZJpZhWXoh62gN+yadAA0GJ509SAEaZVLPDP8S5RsE8mUikR3wxynVshxHeqMhrkS+RlNbhSlOXDdNg94yTrq/xF8Z/PgKRInvef74k5i7bAIytza7jERzFJ/ujTEy3537T5k5EYQJ15ZQGuvzynSdv+6o99SjI4jFplyQOZ2QcYbEAmhHm5GgQlIiEFG/RlDtLksOulKZxOY3qPzP0AyQxtZJXn/5vG40aW9LTbwxCJqWlgrkFXMqAAVCbuU5YspwhiXmKt1PsldiXw23oloa4caCKN1jzbFiGuZNXEU2Ebx7JIvjQCPaUYwLjEbkRDxDqN/vmwZqBuKYiuG9Eafx+nFSQkr7QYb5b+mT+/1IFHnmeRGn38731kBqtH7tpzC/t+soRX9p2HtJM+9MYhblO2OqTSPGTlxihWUkyiRBekpAhaiHld16TsG+A3bOJHrojGcX+5g6oGarKGLAMcykL1X+rZqT993Mo6d2Z7q43MOXE= root@p1" // this is testPubKey signed using testCAUserKey. // % ssh-keygen -s ca_user_key -I test_user_sftp -n test_user_sftp -V always:forever -O source-address=127.0.0.1 -z 1 /tmp/test.pub testCertValid = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgm2fil1IIoTixrA2QE9tk7Vbspj/JdEY90e3K2htxYv8AAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAQAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAACMAAAAOc291cmNlLWFkZHJlc3MAAAANAAAACTEyNy4wLjAuMQAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWm
XJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgMNenD7d1J9cF7JWgHA1DYpJ5+5knPtdXbbIgZAznsTxX7qOdptjeeYOuzhQ5Bwklh3fjewiJpGR1rBqbULP+6PAKeYqd7dNLH/upfKBfJweRf5pdXDpoknHaVuIhi4Uu6FeI4NkAzX9nqNKjFAflhJ+7GLGkLNb0UVZxgxr/t0rPmxc5iTg2ZRM+rk1Ij0S5RnGiKVsdAClqNA6h4TDzu5lJVdK5XvuNKBsKVRCvsVBOgJQTtRTLywQaqWR+HBfCiMj8X8EI7atDlJ6XIAlTLOO/f1sM8QPLjT0+tCHZaGFzg/lKPh3/yFQ4MvddZCptMy1Ll1xvj7cz2ynhGR4PiDfikV3YzgJU/KtL5y+ZB4jU08oPRiOP612PjwZZ+MqYOVOFCKUpMpZQs5UJHME+zNKr4LEj8M0x4YFKIciC+RsrCo4ujbJHmz61ionCadU+fmngvl3C3QjmUdgULBevODeUeIpJv4yFahNxrG1SKRTAa8VVDwJ9GdDTtmXM0mrwA== nicola@p1" // this is testPubKey signed using a CA user key different from testCAUserKey testCertUntrustedCA = "ssh-rsa-cert-v01@openssh.com 
AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg8oFPWpjYy/DowMmtOjWj7Dq20d2N/4Rxzr/c710tOOUAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAAAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAGXAAAAB3NzaC1yc2EAAAADAQABAAABgQCqgm2gVlptULThfpRR0oCb4SAU3368ULlJaiZOUdq6b94KTfgmu4hTLs7u3a8hyZnVxrKrJ93uAVCwa/HGtgiN96CNC6JUt/QnPqTJ8LQ207RdoE9fbOe6mGwOle5z45+5JFoIi5ZZuD8JsBGodVoa92UepoMyBcNtZyl9q2GP4yT2tIYRon79dtG9AXiDYyhSgePqaObN67dn3ivMc4ZGNukK3cG07cYPic5y0wxX16wSMG3pGQDyUkAu+s4AqpnV9EWHM4PE7SYkCXE99++tUK3QALYqvGZKrLHgzmDKi6n+e14vHYUppAeGDZzwlawiY4oGP9eOW2KUfjZe2ZeL22JTFDYzH2lNV2WtUpeKRGGTSGaUblRVC9hRt6hKCT4c7qpW4kO4kPhE39JpcNPGLql7srNkw+3xXBs8xghMPtH3nOl1Rz2mxnX5tAqmPBb+KiPepnrs+pBRu7i+nCVp8az+iN87STYHy+zPtvTR+QURC8BpNraPOfXwpwM2HaMAAAGUAAAADHJzYS1zaGEyLTUxMgAAAYBnTXCL6tXUO3/Gtsm7lnH9Sulzca8FOoI4Y/4bVYhq4iUNu7Ca452m+Xr9qmCEoIyIJF0LEEcJ8jcS4rfX15e7tNNoknv7JbYXBFAbp1Y/76iqVf89FjfVcbEyH2ToAf7eyQAWzQ3gEKS8mQIkLnAwmCboUXC4GRodSIiOXiTt5Q6T02MVc8TxkhmlTA0uVLd5XgstySgE/oLBnL59lhJcwQmdhHL+m480+PaW55CtMuC36RTwk/tOyuWCDC5qMXnoveNB3yu45o3L/U4hoyJ0/5FyP5C8ahgydY0LoRZQG/mNzuraY4433rK+IfkQvZTyaDtcjhxE6hCD5F40aDDh88i6XaKAPikD6fqra6BN8PoPgLuRHzOJuqsMXBWM99s7qPgSnBbmXlekz/1jvvFiCh3zvAFTxFz2KyE4+SbDcCrhpxkNL7idw6r/ZsHaI/2+zhDcgSs5MgBwYLJEj6zUqVdp5XsF8YfC7yNZV5/qy68qY2+zXrC57SPifU2SCPE= 
nicola@p1" // this is testPubKey signed as host certificate. // % ssh-keygen -s ca_user_key -I test_user_sftp -h -n test_user_sftp -V always:forever -z 2 /tmp/test.pub testHostCert = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg7O2LDpLO1jGTX3SSzEMILoAYJb9DdggyyaUMXUUg3L4AAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAgAAAAIAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAAAAAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgHlAWMTTzNrE6pxHlkr09ZXsHgJi8U2p7eifs56DOLgklYIXVUJPEEcnzMKGdpPBnqJsvg3+PccqxgOr5L1dFuOmekQ/dGiHd1enrESiGvJOvDfm0WsuBjxEZkSNFWgC9Z2NltToMmRlhVBmb4ZRZtAmi9DAFlJ/BDV4t8ikXZ5oUsigwIeOeLkdPFx3C3x9KZIpuwuAIV4Nfmz75q1NMWY2K1hv682QCKwMYqOWSotz1vWunNmZ0yPRl9UwqAq+nqwO3AApnlrQ3MmHujWQ5tl65PyhfpI8oghhUtB6YrJIAuRXNI/S0+KewCpiYm7nbFBtv9lpecujxAeTibYBrFZ5VODEUm3sdQ/HMdTmkhi6xNgPDQVlvKFqBJAaqoO3tbhKTbEZ865tJMqhyxmZ4XY08wduvSVobrNr7s3rm42/FXLIpung+UOVXonHyeIv9zQ0iJ/bvqKQ1fOsTisZdcD0lz80ZGsjdgJ
t7yKfUNBnAyVbTXm048E3WsZslJIYCA== nicola@p1" // this is testPubKey signed using testCAUserKey but with source address 172.16.34.45. // % ssh-keygen -s ca_user_key -I test_user_sftp -n test_user_sftp -V always:forever -O source-address=172.16.34.45 -z 3 /tmp/test.pub testCertOtherSourceAddress = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgZ4Su0250R4sQRNYJqJH9VTp9OyeYMAvqY5+lJRI4LzMAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAwAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAACYAAAAOc291cmNlLWFkZHJlc3MAAAAQAAAADDE3Mi4xNi4zNC40NQAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgL34Q3Li8AJIxZLU+fh4i8ehUWpm31vEvlNjXVCeP70xI+5hWuEt6E1TgKw7GCL5GeD4KehX4vVcNs+A2eOdIUZfDBZIFxn88BN8xcMlDpAMJXgvNqGttiOwcspL6X3N288djUgpCI718lLRdz
8nvFqcuYBhSpBm5KL4JzH5o1o8yqv75wMJsH8CJYwGhvWi0OgWOqaLRAk3IUxq3Fbgo/nX11NgrkY/dHIZCkGBFaLJ/M5mfmt/K/5hJAVgLcSxMwB/ryyGaziB9Pv7CwZ9uwnMoRcAvyr96lqgdtLt7LNY8ktugAJ7EnBWjQn4+EJAjjRK2sCaiwpdP37ckDZgmk0OWGEL1yVy8VXgl9QBd7Mb1EVl+lhRyw8jlgBXZOGqpdDrmKCdBYGtU7ujyndLXmxZEAlqhef0yCsyZPTkYH3RhjCYs8ATrEqndEpiL59Nej5uUGQURYijJfHep08AMb4rCxvIZATVm1Ocxu48rGCGolv8jZFJzSJq84HCrVRKMw== nicola@p1" // this is testPubKey signed using testCAUserKey but expired. // % ssh-keygen -s ca_user_key -I test_user_sftp -n test_user_sftp -V 20100101123000:20110101123000 -z 4 /tmp/test.pub testCertExpired = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgU3TLP5285k20fBSsdZioI78oJUpaRXFlgx5IPg6gWg8AAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAABAAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAEs93LgAAAAATR8QOAAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAGXAAAAB3NzaC1yc2EAAAADAQABAAABgQDF5fcwZHiyixmnE6IlOZJpZhWXoh62gN+yadAA0GJ509SAEaZVLPDP8S5RsE8mUikR3wxynVshxHeqMhrkS+RlNbhSlOXDdNg94yTrq/xF8Z/PgKRInvef74k5i7bAIytza7jERzFJ/ujTEy3537T5k5EYQJ15ZQGuvzynSdv+6o99SjI4jFplyQOZ2QcYbEAmhHm5GgQlIiEFG/RlDtLksOulKZxOY3qPzP0AyQxtZJXn/5vG40aW9LTbwxCJqWlgrkFXMqAAVCbuU5YspwhiXmKt1PsldiXw23oloa4caCKN1jzbFiGuZNXEU2Ebx7JIvjQCPaUYwLjEbkRDxDqN/vmwZqBuKYiuG9Eafx+nFSQkr7QYb5b+mT+/1IFHnmeRGn38731kBqtH7tpzC/t+soRX9p2HtJM+9MYhblO2OqTSPGTlxihWUkyiRBekpAhaiHld1
6TsG+A3bOJHrojGcX+5g6oGarKGLAMcykL1X+rZqT993Mo6d2Z7q43MOXEAAAGUAAAADHJzYS1zaGEyLTUxMgAAAYAlH3hhj8J6xLyVpeLZjblzwDKrxp/MWiH30hQ965ExPrPRcoAZFEKVqOYdj6bp4Q19Q4Yzqdobg3aN5ym2iH0b2TlOY0mM901CAoHbNJyiLs+0KiFRoJ+30EDj/hcKusg6v8ln2yixPagAyQu3zyiWo4t1ZuO3I86xchGlptStxSdHAHPFCfpbhcnzWFZctiMqUutl82C4ROWyjOZcRzdVdWHeN5h8wnooXuvba2VkT8QPmjYYyRGuQ3Hg+ySdh8Tel4wiix1Dg5MX7Wjh4hKEx80No9UPy+0iyZMNc07lsWAtrY6NRxGM5CzB6mklscB8TzFrVSnIl9u3bquLfaCrFt/Mft5dR7Yy4jmF+zUhjia6h6giCZ91J+FZ4hV+WkBtPCvTfrGWoA1BgEB/iI2xOq/NPqJ7UXRoMXk/l0NPgRPT2JS1adegqnt4ddr6IlmPyZxaSEvXhanjKdfMlEFYO1wz7ouqpYUozQVy4KXBlzFlNwyD1hI+k4+/A6AIYeI= nicola@p1" // this is testPubKey signed without a principal // ssh-keygen -s ca_user_key -I test_user_sftp -V always:forever -O source-address=127.0.0.1 -z 1 /tmp/test.pub testCertNoPrincipals = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg2Bx0s8nafJtriqoBuQfbFByhdQMkjDIZhV90JZSGN8AAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAQAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAAAAAAAAAAAAAD//////////wAAACMAAAAOc291cmNlLWFkZHJlc3MAAAANAAAACTEyNy4wLjAuMQAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uS
w66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgHgax/++NA5YZXDHH180BcQtDBve8Vc+XJzqQUe8xBiqd+KJnas6He7vW62qMaAfu63i0Uycj2Djfjy5dyx1GB9wup8YuP5mXlmJTx+7UPPjwbfrZWtk8iJ7KhFAwjh0KRZD4uIvoeecK8QE9zh64k2LNVqlWbFTdoPulRC29cGcXDpMU2eToFEyWbceHOZyyifXf98ZMZbaQzWzwSZ5rFucJ1b0aeT6aAJWB+Dq7mIQWf/jCWr8kNaeCzMKJsFQkQEfmHls29ChV92sNRhngUDxll0Ir0wpPea1fFEBnUhLRTLC8GhDDbWAzsZtXqx9fjoAkb/gwsU6TGxevuOMxEABjDA9PyJiTXJI9oTUCwDIAUVVFLsCEum3o/BblngXajUGibaif5ZSKBocpP70oTeAngQYB7r1/vquQzGsGFhTN4FUXLSpLu9Zqi1z58/qa7SgKSfNp98X/4zrhltAX73ZEvg0NUMv2HwlwlqHdpF3FYolAxInp7c2jBTncQ2l3w== nicola@p1" osWindows = "windows" testFileName = "test_file_sftp.dat" testDLFileName = "test_download_sftp.dat" ) var ( configDir = filepath.Join(".", "..", "..") allPerms = []string{dataprovider.PermAny} homeBasePath string scpPath string scpForce bool gitPath string sshPath string hookCmdPath string pubKeyPath string privateKeyPath string trustedCAUserKey string revokeUserCerts string gitWrapPath string extAuthPath string keyIntAuthPath string preLoginPath string postConnectPath string preDownloadPath string preUploadPath string checkPwdPath string logFilePath string hostKeyFPs []string ) func TestMain(m *testing.M) { logFilePath = filepath.Join(configDir, "sftpgo_sftpd_test.log") loginBannerFileName := "login_banner" loginBannerFile := filepath.Join(configDir, loginBannerFileName) logger.InitLogger(logFilePath, 10, 1, 28, false, false, zerolog.DebugLevel) err := os.WriteFile(loginBannerFile, []byte("simple login banner\n"), os.ModePerm) if err != nil { logger.ErrorToConsole("error creating login banner: %v", err) } os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") 
os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") err = config.LoadConfig(configDir, "") if err != nil { logger.ErrorToConsole("error loading configuration: %v", err) os.Exit(1) } providerConf := config.GetProviderConf() logger.InfoToConsole("Starting SFTPD tests, provider: %v", providerConf.Driver) commonConf := config.GetCommonConfig() homeBasePath = os.TempDir() checkSystemCommands() var scriptArgs string if runtime.GOOS == osWindows { scriptArgs = "%*" } else { commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete", "ssh_cmd", "pre-download", "pre-upload"} commonConf.Actions.Hook = hookCmdPath scriptArgs = "$@" } err = dataprovider.Initialize(providerConf, configDir, true) if err != nil { logger.ErrorToConsole("error initializing data provider: %v", err) os.Exit(1) } err = dataprovider.UpdateConfigs(nil, "", "", "") if err != nil { logger.ErrorToConsole("error resetting configs: %v", err) os.Exit(1) } err = common.Initialize(commonConf, 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) } httpConfig := config.GetHTTPConfig() httpConfig.Initialize(configDir) //nolint:errcheck kmsConfig := config.GetKMSConfig() err = kmsConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing kms: %v", err) os.Exit(1) } mfaConfig := config.GetMFAConfig() err = mfaConfig.Initialize() if err != nil { logger.ErrorToConsole("error initializing MFA: %v", err) os.Exit(1) } sftpdConf := config.GetSFTPDConfig() httpdConf := config.GetHTTPDConfig() sftpdConf.Bindings = []sftpd.Binding{ { Port: 2022, ApplyProxyConfig: true, }, } sftpdConf.KexAlgorithms = []string{"curve25519-sha256@libssh.org", ssh.KeyExchangeECDHP256, ssh.KeyExchangeECDHP384} sftpdConf.Ciphers = []string{ssh.CipherChaCha20Poly1305, ssh.CipherAES128GCM, ssh.CipherAES256CTR} sftpdConf.LoginBannerFile = loginBannerFileName // we 
need to test all supported ssh commands sftpdConf.EnabledSSHCommands = []string{"*"} keyIntAuthPath = filepath.Join(homeBasePath, "keyintauth.sh") err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm) if err != nil { logger.ErrorToConsole("error writing keyboard interactive script: %v", err) os.Exit(1) } sftpdConf.KeyboardInteractiveAuthentication = true sftpdConf.KeyboardInteractiveHook = keyIntAuthPath createInitialFiles(scriptArgs) sftpdConf.TrustedUserCAKeys = append(sftpdConf.TrustedUserCAKeys, trustedCAUserKey) sftpdConf.RevokedUserCertsFile = revokeUserCerts go func(cfg sftpd.Configuration) { logger.Debug(logSender, "", "initializing SFTP server with config %+v", sftpdConf) if err := cfg.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server: %v", err) os.Exit(1) } }(sftpdConf) go func() { if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } }() waitTCPListening(sftpdConf.Bindings[0].GetAddress()) waitTCPListening(httpdConf.Bindings[0].GetAddress()) sftpdConf.Bindings = []sftpd.Binding{ { Port: 2222, ApplyProxyConfig: true, }, } sftpdConf.PasswordAuthentication = false common.Config.ProxyProtocol = 1 go func(cfg sftpd.Configuration) { logger.Debug(logSender, "", "initializing SFTP server with config %+v and proxy protocol %v", sftpdConf, common.Config.ProxyProtocol) if err := cfg.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server with proxy protocol 1: %v", err) os.Exit(1) } }(sftpdConf) waitTCPListening(sftpdConf.Bindings[0].GetAddress()) sftpdConf.Bindings = []sftpd.Binding{ { Port: 2226, ApplyProxyConfig: false, }, } sftpdConf.PasswordAuthentication = true go func(cfg sftpd.Configuration) { logger.Debug(logSender, "", "initializing SFTP server with config %+v and proxy protocol %v", cfg, common.Config.ProxyProtocol) if err := 
cfg.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server with proxy protocol 2: %v", err) os.Exit(1) } }(sftpdConf) waitTCPListening(sftpdConf.Bindings[0].GetAddress()) sftpdConf.Bindings = []sftpd.Binding{ { Port: 2224, ApplyProxyConfig: true, }, } sftpdConf.PasswordAuthentication = true common.Config.ProxyProtocol = 2 go func() { logger.Debug(logSender, "", "initializing SFTP server with config %+v and proxy protocol %v", sftpdConf, common.Config.ProxyProtocol) if err := sftpdConf.Initialize(configDir); err != nil { logger.ErrorToConsole("could not start SFTP server with proxy protocol 2: %v", err) os.Exit(1) } }() waitTCPListening(sftpdConf.Bindings[0].GetAddress()) getHostKeysFingerprints(sftpdConf.HostKeys) startHTTPFs() exitCode := m.Run() os.Remove(logFilePath) os.Remove(loginBannerFile) os.Remove(pubKeyPath) os.Remove(privateKeyPath) os.Remove(trustedCAUserKey) os.Remove(revokeUserCerts) os.Remove(gitWrapPath) os.Remove(extAuthPath) os.Remove(preLoginPath) os.Remove(postConnectPath) os.Remove(preDownloadPath) os.Remove(preUploadPath) os.Remove(keyIntAuthPath) os.Remove(checkPwdPath) os.Exit(exitCode) } func TestInitialization(t *testing.T) { err := config.LoadConfig(configDir, "") assert.NoError(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.Bindings = []sftpd.Binding{ { Port: 2022, ApplyProxyConfig: true, }, { Port: 0, }, } sftpdConf.LoginBannerFile = "invalid_file" sftpdConf.EnabledSSHCommands = append(sftpdConf.EnabledSSHCommands, "ls") err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.KeyboardInteractiveAuthentication = true sftpdConf.KeyboardInteractiveHook = "invalid_file" err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.KeyboardInteractiveAuthentication = true sftpdConf.KeyboardInteractiveHook = filepath.Join(homeBasePath, "invalid_file") err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.KeyboardInteractiveAuthentication = false err = 
sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.Bindings = []sftpd.Binding{ { Port: 4444, ApplyProxyConfig: true, }, } common.Config.ProxyProtocol = 1 assert.True(t, sftpdConf.Bindings[0].HasProxy()) common.Config.ProxyProtocol = 0 sftpdConf.HostKeys = []string{"missing key"} err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.HostKeys = nil sftpdConf.TrustedUserCAKeys = []string{"missing ca key"} err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.Bindings = nil err = sftpdConf.Initialize(configDir) assert.EqualError(t, err, common.ErrNoBinding.Error()) sftpdConf = config.GetSFTPDConfig() sftpdConf.Ciphers = []string{"not a cipher"} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported cipher") } sftpdConf.Ciphers = nil sftpdConf.MACs = []string{"not a MAC"} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported MAC algorithm") } sftpdConf.MACs = nil sftpdConf.KexAlgorithms = []string{"diffie-hellman-group-exchange-sha1", "not a KEX"} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported key-exchange algorithm") } sftpdConf.KexAlgorithms = nil sftpdConf.PublicKeyAlgorithms = []string{"not a pub key algo"} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported public key authentication algorithm") } sftpdConf.PublicKeyAlgorithms = nil sftpdConf.HostKeyAlgorithms = []string{"not a host key algo"} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported host key algorithm") } sftpdConf.HostKeyAlgorithms = nil sftpdConf.HostCertificates = []string{"missing file"} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to load host certificate") } sftpdConf.HostCertificates = []string{"."} err = 
sftpdConf.Initialize(configDir) assert.Error(t, err) hostCertPath := filepath.Join(os.TempDir(), "host_cert.pub") err = os.WriteFile(hostCertPath, []byte(testCertValid), 0600) assert.NoError(t, err) sftpdConf.HostKeys = []string{privateKeyPath} sftpdConf.HostCertificates = []string{hostCertPath} err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is not an host certificate") } err = os.WriteFile(hostCertPath, []byte(testPubKey), 0600) assert.NoError(t, err) err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is not an SSH certificate") } err = os.WriteFile(hostCertPath, []byte("abc"), 0600) assert.NoError(t, err) err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to parse host certificate") } err = os.WriteFile(hostCertPath, []byte(testHostCert), 0600) assert.NoError(t, err) err = sftpdConf.Initialize(configDir) assert.Error(t, err) err = os.Remove(hostCertPath) assert.NoError(t, err) sftpdConf.HostKeys = nil sftpdConf.HostCertificates = nil sftpdConf.OPKSSHPath = "relative path" err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.OPKSSHPath = filepath.Join(os.TempDir(), "missing path") err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.OPKSSHChecksum = "invalid checksum" err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.OPKSSHPath = "" sftpdConf.OPKSSHChecksum = "" sftpdConf.RevokedUserCertsFile = "." 
err = sftpdConf.Initialize(configDir) assert.Error(t, err) sftpdConf.RevokedUserCertsFile = "a missing file" err = sftpdConf.Initialize(configDir) assert.ErrorIs(t, err, os.ErrNotExist) err = createTestFile(revokeUserCerts, 10*1024*1024) assert.NoError(t, err) sftpdConf.RevokedUserCertsFile = revokeUserCerts err = sftpdConf.Initialize(configDir) assert.Error(t, err) err = os.WriteFile(revokeUserCerts, []byte(`[]`), 0644) assert.NoError(t, err) err = sftpdConf.Initialize(configDir) assert.Error(t, err) err = dataprovider.Close() assert.NoError(t, err) err = sftpdConf.Initialize(configDir) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to load configs from provider") } err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) } func TestBasicSFTPHandling(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.QuotaSize = 6553600 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) expectedQuotaSize := user.UsedQuotaSize + testFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client) assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), user.FirstUpload) assert.Equal(t, int64(0), user.FirstDownload) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) 
assert.NoError(t, err) assert.Greater(t, user.FirstUpload, int64(0)) assert.Equal(t, int64(0), user.FirstDownload) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) err = client.Remove(testFileName) assert.NoError(t, err) _, err = client.Lstat(testFileName) assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize) assert.Greater(t, user.FirstUpload, int64(0)) assert.Greater(t, user.FirstDownload, int64(0)) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } u.Username = "missing user" _, _, err = getSftpClient(u, false) assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) status := sftpd.GetStatus() assert.True(t, status.IsActive) sshCommands := status.GetSSHCommandsAsString() assert.NotEmpty(t, sshCommands) sshAuths := status.GetSupportedAuthsAsString() assert.NotEmpty(t, sshAuths) assert.NotEmpty(t, status.HostKeys[0].GetAlgosAsString()) assert.NotEmpty(t, status.GetMACsAsString()) assert.NotEmpty(t, status.GetKEXsAsString()) assert.NotEmpty(t, status.GetCiphersAsString()) assert.NotEmpty(t, status.GetPublicKeysAlgosAsString()) } func TestBasicSFTPFsHandling(t *testing.T) { usePubKey := true baseUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) u := getTestSFTPUser(usePubKey) u.QuotaSize = 6553600 u.FsConfig.SFTPConfig.DisableCouncurrentReads = true 
user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) testLinkName := testFileName + ".link" testLinkToLinkName := testLinkName + ".link" expectedQuotaSize := testFileSize expectedQuotaFiles := 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) err = client.Symlink(testFileName, testLinkName) assert.NoError(t, err) info, err := client.Lstat(testLinkName) if assert.NoError(t, err) { assert.True(t, info.Mode()&os.ModeSymlink != 0) } info, err = client.Stat(testLinkName) if assert.NoError(t, err) { assert.True(t, info.Mode()&os.ModeSymlink == 0) } val, err := client.ReadLink(testLinkName) if assert.NoError(t, err) { assert.Equal(t, path.Join("/", testFileName), val) } linkDir := "linkDir" err = client.Mkdir(linkDir) assert.NoError(t, err) linkToLinkPath := path.Join(linkDir, testLinkToLinkName) err = client.Symlink(path.Join("/", testLinkName), linkToLinkPath) assert.NoError(t, err) info, err = client.Lstat(linkToLinkPath) if assert.NoError(t, err) { assert.True(t, info.Mode()&os.ModeSymlink != 0) } info, err = client.Stat(linkToLinkPath) if assert.NoError(t, err) { assert.True(t, info.Mode()&os.ModeSymlink == 0) } val, err = client.ReadLink(linkToLinkPath) if assert.NoError(t, err) { assert.Equal(t, 
path.Join("/", testLinkName), val) } err = client.Remove(linkToLinkPath) assert.NoError(t, err) err = client.RemoveDirectory(linkDir) assert.NoError(t, err) err = client.Remove(testFileName) assert.NoError(t, err) _, err = client.Lstat(testFileName) assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, user.UsedQuotaFiles) assert.Equal(t, int64(0), user.UsedQuotaSize) // now overwrite the symlink err = sftpUploadFile(testFilePath, testLinkName, testFileSize, client) assert.NoError(t, err) contents, err := client.ReadDir("/") if assert.NoError(t, err) { assert.Len(t, contents, 1) assert.Equal(t, testFileSize, contents[0].Size()) assert.Equal(t, testLinkName, contents[0].Name()) assert.False(t, contents[0].IsDir()) assert.True(t, contents[0].Mode().IsRegular()) } user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) stat, err := client.StatVFS("/") assert.NoError(t, err) assert.Equal(t, uint64(u.QuotaSize/4096), stat.Blocks) assert.Equal(t, uint64((u.QuotaSize-testFileSize)/4096), stat.Bfree) assert.Equal(t, uint64(1), stat.Files-stat.Ffree) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(baseUser.GetHomeDir()) assert.NoError(t, err) } func TestSFTPFsPasswordProtectedPrivateKey(t *testing.T) { usePubKey := false u := getTestUser(true) u.PublicKeys = []string{testPubKeyPwd} baseUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKeyPwd) u.FsConfig.SFTPConfig.KeyPassphrase = 
kms.NewPlainSecret(privateKeyPwd) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } // update the user, the key must be preserved _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(baseUser.GetHomeDir()) assert.NoError(t, err) } func TestSFTPFsEscapeHomeDir(t *testing.T) { usePubKey := true baseUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) u := getTestSFTPUser(usePubKey) sftpPrefix := "/prefix" u.FsConfig.SFTPConfig.Prefix = sftpPrefix user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) dirName := "dir" linkName := "link" err := client.Mkdir(dirName) assert.NoError(t, err) err = os.Symlink(baseUser.GetHomeDir(), filepath.Join(baseUser.GetHomeDir(), sftpPrefix, dirName, linkName)) assert.NoError(t, err) err = os.Symlink(filepath.Join(baseUser.GetHomeDir(), sftpPrefix, dirName, linkName), filepath.Join(baseUser.GetHomeDir(), sftpPrefix, linkName)) assert.NoError(t, err) // linkName points to a link inside the home dir and this link points to a dir outside the home dir _, err = client.ReadLink(linkName) assert.ErrorIs(t, err, os.ErrPermission) _, err = client.RealPath(linkName) assert.ErrorIs(t, err, os.ErrPermission) _, err = client.ReadDir(linkName) 
assert.ErrorIs(t, err, os.ErrPermission) _, err = client.ReadDir(path.Join(dirName, linkName)) assert.ErrorIs(t, err, os.ErrPermission) _, err = client.ReadDir("/") assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(baseUser.GetHomeDir()) assert.NoError(t, err) } func TestReadDirLongNames(t *testing.T) { usePubKey := true user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() numFiles := 1000 for i := 0; i < 1000; i++ { fPath := filepath.Join(user.GetHomeDir(), hex.EncodeToString(util.GenerateRandomBytes(127))) err = os.WriteFile(fPath, util.GenerateRandomBytes(30), 0666) assert.NoError(t, err) } entries, err := client.ReadDir("/") assert.NoError(t, err) assert.Len(t, entries, numFiles) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestGroupSettingsOverride(t *testing.T) { usePubKey := true g := getTestGroup() g.UserSettings.Filters.StartDirectory = "/%username%" group, _, err := httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err) u := getTestUser(usePubKey) u.Groups = []sdk.GroupMapping{ { Name: group.Name, Type: sdk.GroupTypePrimary, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() currentDir, err := client.Getwd() assert.NoError(t, err) assert.Equal(t, "/"+user.Username, currentDir) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) } 
func TestStartDirectory(t *testing.T) { usePubKey := false startDir := "/st@ rt/dir" u := getTestUser(usePubKey) u.Filters.StartDirectory = startDir localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.Filters.StartDirectory = startDir sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() currentDir, err := client.Getwd() assert.NoError(t, err) assert.Equal(t, startDir, currentDir) entries, err := client.ReadDir(".") assert.NoError(t, err) assert.Len(t, entries, 0) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) _, err = client.Stat(testFileName) assert.NoError(t, err) err = client.Rename(testFileName, testFileName+"_rename") assert.NoError(t, err) entries, err = client.ReadDir(".") assert.NoError(t, err) assert.Len(t, entries, 1) currentDir, err = client.RealPath("..") assert.NoError(t, err) assert.Equal(t, path.Dir(startDir), currentDir) currentDir, err = client.RealPath("../..") assert.NoError(t, err) assert.Equal(t, "/", currentDir) currentDir, err = client.RealPath("../../..") assert.NoError(t, err) assert.Equal(t, "/", currentDir) err = client.Remove(testFileName + "_rename") assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = 
httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestLoginNonExistentUser(t *testing.T) { usePubKey := true user := getTestUser(usePubKey) _, _, err := getSftpClient(user, usePubKey) assert.Error(t, err) } func TestRateLimiter(t *testing.T) { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.RateLimitersConfig = []common.RateLimiterConfig{ { Average: 1, Period: 1000, Burst: 1, Type: 1, Protocols: []string{common.ProtocolSSH}, }, } err := common.Initialize(cfg, 0) assert.NoError(t, err) usePubKey := false user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } _, _, err = getSftpClient(user, usePubKey) assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } func TestDefender(t *testing.T) { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 cfg.DefenderConfig.ScoreValid = 1 err := common.Initialize(cfg, 0) assert.NoError(t, err) usePubKey := false user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } user.Password = "wrong_pwd" _, _, err = getSftpClient(user, usePubKey) assert.Error(t, err) hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) if assert.Len(t, hosts, 1) { host := hosts[0] assert.Empty(t, 
host.GetBanTime()) assert.Equal(t, 1, host.Score) } for i := 0; i < 2; i++ { _, _, err = getSftpClient(user, usePubKey) assert.Error(t, err) } user.Password = defaultPassword _, _, err = getSftpClient(user, usePubKey) assert.Error(t, err) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } func TestOpenReadWrite(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.QuotaSize = 6553600 localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.QuotaSize = 6553600 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("sample test data") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) buffer := make([]byte, 128) n, err = sftpFile.ReadAt(buffer, 1) assert.EqualError(t, err, io.EOF.Error()) assert.Equal(t, len(testData)-1, n) assert.Equal(t, testData[1:], buffer[:n]) err = sftpFile.Close() assert.NoError(t, err) } sftpFile, err = client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("new test data") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) buffer := make([]byte, 128) n, err = sftpFile.ReadAt(buffer, 1) assert.EqualError(t, err, io.EOF.Error()) assert.Equal(t, len(testData)-1, n) assert.Equal(t, testData[1:], buffer[:n]) err = sftpFile.Close() assert.NoError(t, err) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = 
httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestOpenReadWritePerm(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) // we cannot read inside "/sub", rename is needed otherwise the atomic upload will fail for the sftpfs user u.Permissions["/sub"] = []string{dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermRename} localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.Permissions["/sub"] = []string{dataprovider.PermUpload, dataprovider.PermListItems} sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("sub") assert.NoError(t, err) sftpFileName := path.Join("sub", "file.txt") sftpFile, err := client.OpenFile(sftpFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("test data") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) buffer := make([]byte, 128) _, err = sftpFile.ReadAt(buffer, 1) if assert.Error(t, err) { assert.Contains(t, strings.ToLower(err.Error()), "permission denied") } err = sftpFile.Close() assert.NoError(t, err) } if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) 
assert.NoError(t, err) } func TestConcurrency(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 0 usePubKey := true numLogins := 50 u := getTestUser(usePubKey) u.QuotaFiles = numLogins + 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) var wg sync.WaitGroup testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(262144) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) var closedConns atomic.Int32 for i := 0; i < numLogins; i++ { wg.Add(1) go func(counter int) { defer wg.Done() defer closedConns.Add(1) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client) assert.NoError(t, err) assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0) client.Close() conn.Close() } }(i) } wg.Add(1) go func() { defer wg.Done() maxConns := 0 maxSessions := 0 for { servedReqs := closedConns.Load() if servedReqs > 0 { stats := common.Connections.GetStats("") if len(stats) > maxConns { maxConns = len(stats) } activeSessions := common.Connections.GetActiveSessions(defaultUsername) if activeSessions > maxSessions { maxSessions = activeSessions } } if servedReqs >= int32(numLogins) { break } time.Sleep(1 * time.Millisecond) } assert.Greater(t, maxConns, 0) assert.Greater(t, maxSessions, 0) }() wg.Wait() conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { files, err := client.ReadDir(".") assert.NoError(t, err) assert.Len(t, files, numLogins) client.Close() conn.Close() } assert.Eventually(t, func() bool { return common.Connections.GetActiveSessions(defaultUsername) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = 
os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.MaxPerHostConnections = oldValue } func TestProxyProtocol(t *testing.T) { usePubKey := true user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) // remove the home dir to test auto creation err = os.RemoveAll(user.HomeDir) assert.NoError(t, err) conn, client, err := getSftpClientWithAddr(user, usePubKey, sftpSrvAddr2222) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } _, _, err = getSftpClientWithAddr(user, usePubKey, "127.0.0.1:2224") assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestRealPath(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { p, err := client.RealPath("../..") assert.NoError(t, err) assert.Equal(t, "/", p) p, err = client.RealPath("../test") assert.NoError(t, err) assert.Equal(t, "/test", p) subdir := "testsubdir" err = client.Mkdir(subdir) assert.NoError(t, err) linkName := testFileName + "_link" err = client.Symlink(path.Join("/", testFileName), path.Join(subdir, linkName)) assert.NoError(t, err) p, err = client.RealPath(path.Join(subdir, linkName)) assert.NoError(t, err) assert.Equal(t, path.Join("/", testFileName), p) // an existing path sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("hello world") 
n, err := sftpFile.WriteAt(testData, 0) assert.NoError(t, err) assert.Equal(t, len(testData), n) } p, err = client.RealPath(path.Join(subdir, linkName)) assert.NoError(t, err) assert.Equal(t, path.Join("/", testFileName), p) // now a link outside the home dir err = os.Symlink(filepath.Clean(os.TempDir()), filepath.Join(localUser.GetHomeDir(), subdir, "temp")) assert.NoError(t, err) _, err = client.RealPath(path.Join(subdir, "temp")) assert.ErrorIs(t, err, os.ErrPermission) conn.Close() client.Close() err = os.Remove(filepath.Join(localUser.GetHomeDir(), subdir, "temp")) assert.NoError(t, err) if user.Username == localUser.Username { err = os.RemoveAll(filepath.Join(localUser.GetHomeDir(), subdir)) assert.NoError(t, err) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestBufferedSFTP(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.FsConfig.SFTPConfig.BufferSize = 2 u.HomeDir = filepath.Join(os.TempDir(), u.Username) sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(sftpUser, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) appendDataSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) initialHash, err := computeHashForFile(sha256.New(), testFilePath) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = appendToTestFile(testFilePath, appendDataSize) assert.NoError(t, err) err = 
sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, false, client) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") } localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) downloadedFileHash, err := computeHashForFile(sha256.New(), localDownloadPath) assert.NoError(t, err) assert.Equal(t, initialHash, downloadedFileHash) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("sample test sftp data") n, err := sftpFile.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) err = sftpFile.Truncate(0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") } err = sftpFile.Truncate(4) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") } buffer := make([]byte, 128) _, err = sftpFile.Read(buffer) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") } err = sftpFile.Close() assert.NoError(t, err) info, err := client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, int64(len(testData)), info.Size()) } } // test WriteAt sftpFile, err = client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) if assert.NoError(t, err) { testData := []byte("hello world") n, err := sftpFile.WriteAt(testData[:6], 0) assert.NoError(t, err) assert.Equal(t, 6, n) n, err = sftpFile.WriteAt(testData[6:], 6) assert.NoError(t, err) assert.Equal(t, 5, n) err = sftpFile.Close() assert.NoError(t, err) info, err := client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, int64(len(testData)), info.Size()) } } // test ReadAt sftpFile, err = client.OpenFile(testFileName, os.O_RDONLY) if assert.NoError(t, err) { buffer 
:= make([]byte, 128) n, err := sftpFile.ReadAt(buffer, 6) assert.ErrorIs(t, err, io.EOF) assert.Equal(t, 5, n) assert.Equal(t, []byte("world"), buffer[:n]) err = sftpFile.Close() assert.NoError(t, err) } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(sftpUser.GetHomeDir()) assert.NoError(t, err) } func TestUploadResume(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestUser(usePubKey) u.FsConfig.OSConfig = sdk.OSFsConfig{ WriteBufferSize: 1, ReadBufferSize: 1, } u.Username += "_buffered" u.HomeDir += "_with_buf" bufferedUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser, bufferedUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) appendDataSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = appendToTestFile(testFilePath, appendDataSize) assert.NoError(t, err) err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, false, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize+appendDataSize, client) assert.NoError(t, err) initialHash, err := 
computeHashForFile(sha256.New(), testFilePath) assert.NoError(t, err) downloadedFileHash, err := computeHashForFile(sha256.New(), localDownloadPath) assert.NoError(t, err) assert.Equal(t, initialHash, downloadedFileHash) err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, true, client) assert.Error(t, err, "resume uploading file with invalid offset must fail") err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(bufferedUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(bufferedUser.GetHomeDir()) assert.NoError(t, err) } func TestDirCommands(t *testing.T) { usePubKey := false user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) // remove the home dir to test auto creation err = os.RemoveAll(user.HomeDir) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("test1") assert.NoError(t, err) err = client.Rename("test1", "test") assert.NoError(t, err) // rename a missing file err = client.Rename("test1", "test2") assert.Error(t, err) _, err = client.Lstat("/test1") assert.Error(t, err, "stat for renamed dir must not succeed") err = client.PosixRename("test", "test1") assert.NoError(t, err) err = client.Remove("test1") assert.NoError(t, 
err) err = client.Mkdir("/test/test1") assert.Error(t, err, "recursive mkdir must fail") err = client.Mkdir("/test") assert.NoError(t, err) err = client.Mkdir("/test/test1") assert.NoError(t, err) _, err = client.ReadDir("/this/dir/does/not/exist") assert.Error(t, err, "reading a missing dir must fail") err = client.RemoveDirectory("/test/test1") assert.NoError(t, err) err = client.RemoveDirectory("/test") assert.NoError(t, err) _, err = client.Lstat("/test") assert.Error(t, err, "stat for deleted dir must not succeed") _, err = client.Stat("/test") assert.Error(t, err, "stat for deleted dir must not succeed") err = client.RemoveDirectory("/test") assert.Error(t, err, "remove missing path must fail") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestRemove(t *testing.T) { usePubKey := true user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("test") assert.NoError(t, err) err = client.Mkdir("/test/test1") assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join("/test", testFileName), testFileSize, client) assert.NoError(t, err) err = client.Remove("/test") assert.Error(t, err, "remove non empty dir must fail") err = client.RemoveDirectory(path.Join("/test", testFileName)) assert.Error(t, err, "remove a file with rmdir must fail") err = client.Remove(path.Join("/test", testFileName)) assert.NoError(t, err) err = client.Remove(path.Join("/test", testFileName)) assert.Error(t, err, "remove missing file must fail") err = client.Remove("/test/test1") assert.NoError(t, err) err = client.Remove("/test") assert.NoError(t, 
err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestLink(t *testing.T) { usePubKey := false user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Symlink(testFileName, testFileName+".link") assert.NoError(t, err) linkName, err := client.ReadLink(testFileName + ".link") assert.NoError(t, err) assert.Equal(t, path.Join("/", testFileName), linkName) err = client.Symlink(testFileName, testFileName+".link") assert.Error(t, err, "creating a symlink to an existing one must fail") err = client.Link(testFileName, testFileName+".hlink") assert.Error(t, err, "hard link is not supported and must fail") err = client.Remove(testFileName + ".link") assert.NoError(t, err) err = client.Remove(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestStat(t *testing.T) { usePubKey := false localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, 
testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) _, err := client.Lstat(testFileName) assert.NoError(t, err) _, err = client.Stat(testFileName) assert.NoError(t, err) // stat a missing path we should get an fs.ErrNotExist error _, err = client.Stat("missing path") assert.True(t, errors.Is(err, fs.ErrNotExist)) _, err = client.Lstat("missing path") assert.True(t, errors.Is(err, fs.ErrNotExist)) // mode 0666 and 0444 works on Windows too newPerm := os.FileMode(0666) err = client.Chmod(testFileName, newPerm) assert.NoError(t, err) newFi, err := client.Lstat(testFileName) assert.NoError(t, err) assert.Equal(t, newPerm, newFi.Mode().Perm()) newPerm = os.FileMode(0444) err = client.Chmod(testFileName, newPerm) assert.NoError(t, err) newFi, err = client.Lstat(testFileName) if assert.NoError(t, err) { assert.Equal(t, newPerm, newFi.Mode().Perm()) } _, err = client.ReadLink(testFileName) assert.Error(t, err, "readlink on a file must fail") symlinkName := testFileName + ".sym" err = client.Symlink(testFileName, symlinkName) assert.NoError(t, err) info, err := client.Lstat(symlinkName) if assert.NoError(t, err) { assert.True(t, info.Mode()&os.ModeSymlink != 0) } info, err = client.Stat(symlinkName) if assert.NoError(t, err) { assert.False(t, info.Mode()&os.ModeSymlink != 0) } linkName, err := client.ReadLink(symlinkName) assert.NoError(t, err) assert.Equal(t, path.Join("/", testFileName), linkName) newPerm = os.FileMode(0666) err = client.Chmod(testFileName, newPerm) assert.NoError(t, err) err = client.Truncate(testFileName, 100) assert.NoError(t, err) fi, err := client.Stat(testFileName) if assert.NoError(t, err) { assert.Equal(t, int64(100), fi.Size()) } f, err := client.OpenFile(testFileName, os.O_WRONLY) if assert.NoError(t, err) { err = f.Truncate(5) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) } 
f, err = client.OpenFile(testFileName, os.O_WRONLY) newPerm = os.FileMode(0444) if assert.NoError(t, err) { err = f.Chmod(newPerm) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) } newFi, err = client.Lstat(testFileName) if assert.NoError(t, err) { assert.Equal(t, newPerm, newFi.Mode().Perm()) } newPerm = os.FileMode(0666) err = client.Chmod(testFileName, newPerm) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestStatChownChmod(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("chown is not supported on Windows, chmod is partially supported") } usePubKey := true localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Chown(testFileName, os.Getuid(), os.Getgid()) assert.NoError(t, err) newPerm := os.FileMode(0600) err = client.Chmod(testFileName, newPerm) 
assert.NoError(t, err) newFi, err := client.Lstat(testFileName) assert.NoError(t, err) assert.Equal(t, newPerm, newFi.Mode().Perm()) err = client.Remove(testFileName) assert.NoError(t, err) err = client.Chmod(testFileName, newPerm) assert.EqualError(t, err, os.ErrNotExist.Error()) err = client.Chown(testFileName, os.Getuid(), os.Getgid()) assert.EqualError(t, err, os.ErrNotExist.Error()) err = os.Remove(testFilePath) assert.NoError(t, err) if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestSFTPFsLoginWrongFingerprint(t *testing.T) { usePubKey := true localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(sftpUser, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } sftpUser.FsConfig.SFTPConfig.Fingerprints = append(sftpUser.FsConfig.SFTPConfig.Fingerprints, "wrong") _, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(sftpUser, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } out, err := runSSHCommand("md5sum", sftpUser, usePubKey) assert.NoError(t, err) assert.Contains(t, string(out), "d41d8cd98f00b204e9800998ecf8427e") 
sftpUser.FsConfig.SFTPConfig.Fingerprints = []string{"wrong"}
	_, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(sftpUser, usePubKey)
	if !assert.Error(t, err) {
		defer conn.Close()
		defer client.Close()
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestChtimes checks, for both a local user and an SFTP-backed user, that
// Chtimes updates the modification time of a file and of a directory (within
// one second) and that it fails with os.ErrNotExist for a missing path.
func TestChtimes(t *testing.T) {
	usePubKey := false
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			testDir := "test" //nolint:goconst
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			acmodTime := time.Now()
			err = client.Chtimes(testFileName, acmodTime, acmodTime)
			assert.NoError(t, err)
			newFi, err := client.Lstat(testFileName)
			assert.NoError(t, err)
			diff := math.Abs(newFi.ModTime().Sub(acmodTime).Seconds())
			assert.LessOrEqual(t, diff, float64(1))
			err = client.Chtimes("invalidFile", acmodTime, acmodTime)
			assert.EqualError(t, err, os.ErrNotExist.Error())
			err = client.Mkdir(testDir)
			assert.NoError(t, err)
			err = client.Chtimes(testDir, acmodTime, acmodTime)
			assert.NoError(t, err)
			newFi, err = client.Lstat(testDir)
			assert.NoError(t, err)
			diff = math.Abs(newFi.ModTime().Sub(acmodTime).Seconds())
			assert.LessOrEqual(t, diff, float64(1))
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			if user.Username == defaultUsername {
				// Wipe and recreate the local user with a clean home dir.
				// NOTE(review): presumably needed because the SFTP-backed user
				// iterated next logs in against this account — confirm.
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestEscapeHomeDir runs basic virtual-chroot checks; it should be improved to
// cover more cases. Symbolic links pointing outside the home directory must not
// be listable, downloadable, overwritable or chmod-able, while a relative upload
// path such as "../../file" must resolve inside the home directory.
func TestEscapeHomeDir(t *testing.T) {
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	dirOutsideHome := filepath.Join(homeBasePath, defaultUsername+"1", "dir")
	err = os.MkdirAll(dirOutsideHome, os.ModePerm)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		testDir := "testDir" //nolint:goconst
		linkPath := filepath.Join(homeBasePath, defaultUsername, testDir)
		err = os.Symlink(homeBasePath, linkPath)
		assert.NoError(t, err)
		_, err = client.ReadDir(testDir)
		assert.Error(t, err, "reading a symbolic link outside home dir should not succeeded")
		err = os.Remove(linkPath)
		assert.NoError(t, err)
		err = os.Symlink(dirOutsideHome, linkPath)
		assert.NoError(t, err)
		// plain assignment: do not shadow err for the rest of this block
		_, err = client.ReadDir(testDir)
		assert.Error(t, err, "reading a symbolic link outside home dir should not succeeded")
		err = client.Chmod(path.Join(testDir, "sub", "dir"), os.ModePerm)
		assert.ErrorIs(t, err, os.ErrPermission)
		assert.Error(t, err, "setstat on a file outside home dir must fail")
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// "../../file" is expected to be clamped to the home root, so the
		// upload succeeds and the file is then visible at testFileName
		remoteDestPath := path.Join("..", "..", testFileName)
		err = sftpUploadFile(testFilePath, remoteDestPath, testFileSize, client)
		assert.NoError(t, err)
		_, err = client.Lstat(testFileName)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		linkPath = filepath.Join(homeBasePath, defaultUsername, testFileName)
		err = os.Symlink(homeBasePath, linkPath)
		assert.NoError(t, err)
		err = sftpDownloadFile(testFileName, testFilePath, 0, client)
		assert.Error(t, err, "download file outside home dir must fail")
		err = sftpUploadFile(testFilePath, remoteDestPath, testFileSize, client)
		assert.Error(t, err, "overwrite a file outside home dir must fail")
		err = client.Chmod(remoteDestPath, 0644)
		assert.Error(t, err, "setstat on a file outside home dir must fail")
		err = os.Remove(linkPath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Join(homeBasePath, defaultUsername+"1"))
	assert.NoError(t, err)
}

// TestEscapeSFTPFsPrefix checks that an SFTP-backed user confined to a prefix
// cannot list or create inside symlinks (created by the local user) that point
// outside the prefix, including "/pre" (a string prefix of "/prefix") and
// "/prefix1" (a sibling that shares the prefix string).
func TestEscapeSFTPFsPrefix(t *testing.T) {
	usePubKey := false
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestSFTPUser(usePubKey)
	sftpPrefix := "/prefix"
	outPrefix1 := "/pre"
	outPrefix2 := sftpPrefix + "1"
	out1 := "out1"
	out2 := "out2"
	u.FsConfig.SFTPConfig.Prefix = sftpPrefix
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(localUser, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(sftpPrefix)
		assert.NoError(t, err)
		err = client.Mkdir(outPrefix1)
		assert.NoError(t, err)
		err = client.Mkdir(outPrefix2)
		assert.NoError(t, err)
		err = client.Symlink(outPrefix1, path.Join(sftpPrefix, out1))
		assert.NoError(t, err)
		err = client.Symlink(outPrefix2, path.Join(sftpPrefix, out2))
		assert.NoError(t, err)
	}
	conn, client, err = getSftpClient(sftpUser, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		contents, err := client.ReadDir("/")
		assert.NoError(t, err)
		assert.Len(t, contents, 2)
		_, err = client.ReadDir(out1)
		assert.Error(t, err)
		_, err = client.ReadDir(out2)
		assert.Error(t, err)
		err = client.Mkdir(path.Join(out1, "subout1"))
		assert.Error(t, err)
		err = client.Mkdir(path.Join(out2, "subout2"))
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestGetMimeTypeSFTPFs writes a UTF-8 text file as the local user and checks
// that the SFTP filesystem backend reports "text/plain; charset=utf-8" for it.
func TestGetMimeTypeSFTPFs(t *testing.T) {
	usePubKey := false
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(localUser, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
		if assert.NoError(t, err) {
			testData := []byte("some UTF-8 text so we should get a text/plain mime type")
			n, err := sftpFile.Write(testData)
			assert.NoError(t, err)
			assert.Equal(t, len(testData), n)
			err = sftpFile.Close()
			assert.NoError(t, err)
		}
	}
	sftpUser.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	sftpUser.FsConfig.SFTPConfig.PrivateKey = kms.NewEmptySecret()
	fs, err := sftpUser.GetFilesystem("connID")
	if assert.NoError(t, err) {
		assert.True(t, vfs.IsSFTPFs(fs))
		mime, err := fs.GetMimeType(testFileName)
		assert.NoError(t, err)
		assert.Equal(t, "text/plain; charset=utf-8", mime)
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestHomeSpecialChars checks basic SFTP operations for a user whose home
// directory name contains spaces, accented letters and '#', '&', '%'.
func TestHomeSpecialChars(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.HomeDir = filepath.Join(homeBasePath, "abc açà#&%lk")
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		files, err := client.ReadDir(".")
		assert.NoError(t, err)
		assert.Equal(t, 1, len(files))
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLogin covers password and public key logins (including the last login
// update), an invalid password, an unauthorized public key and a user with
// several public keys where only one matches the client key.
func TestLogin(t *testing.T) {
	u := getTestUser(false)
	u.PublicKeys = []string{testPubKey}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, false)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin)
	}
	conn, client, err = getSftpClient(user, true)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	user.Password = "invalid password"
	conn, client, err = getSftpClient(user, false)
	if !assert.Error(t, err, "login with invalid password must fail") {
		client.Close()
		conn.Close()
	}
	// testPubKey1 is not authorized
	user.PublicKeys = []string{testPubKey1}
	user.Password = ""
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, true)
	if !assert.Error(t, err, "login with invalid public key must fail") {
		defer conn.Close()
		defer client.Close()
	}
	// login a user with multiple public keys, only the second one is valid
	user.PublicKeys = []string{testPubKey1, testPubKey}
	user.Password = ""
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, true)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginUserCert covers SSH user certificate logins. It checks that login
// succeeds with a cert from a trusted CA and fails when that cert is revoked
// through the revokeUserCerts file. It also checks rejection for an untrusted
// CA, a host certificate, a disallowed source address, an expired cert, a cert
// without principals, a missing user and a username outside the principals.
// The revocation file is reset to an empty list at the end.
func TestLoginUserCert(t *testing.T) {
	u := getTestUser(true)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// try login using a cert signed from a trusted CA
	signer, err := getSignerForUserCert([]byte(testCertValid))
	assert.NoError(t, err)
	conn, client, err := getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	// revoke the certificate
	certs := []string{"SHA256:OkxVB1ImSJ2XeI8nA2Wg+6zJVlxdevD1FYBSEJjFEN4"}
	data, err := json.Marshal(certs)
	assert.NoError(t, err)
	err = os.WriteFile(revokeUserCerts, data, 0644)
	assert.NoError(t, err)
	err = sftpd.Reload()
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	// if we remove the revoked certificate login should work again:
	// two distinct revoked fingerprints, neither matching testCertValid
	certs = []string{"SHA256:bsBRHC/xgiqBJdSuvSTNpJNLTISP/G356jNMCRYC5Es", "SHA256:1kxVB1ImSJ2XeI8nA2Wg+6zJVlxdevD1FYBSEJjFEN4"}
	data, err = json.Marshal(certs)
	assert.NoError(t, err)
	err = os.WriteFile(revokeUserCerts, data, 0644)
	assert.NoError(t, err)
	err = sftpd.Reload()
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	// try login using a cert signed from an untrusted CA
	signer, err = getSignerForUserCert([]byte(testCertUntrustedCA))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	// try login using an host certificate instead of an user certificate
	signer, err = getSignerForUserCert([]byte(testHostCert))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	// try login using a user certificate with an authorized source address different from localhost
	signer, err = getSignerForUserCert([]byte(testCertOtherSourceAddress))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	// try login using an expired certificate
	signer, err = getSignerForUserCert([]byte(testCertExpired))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	// try login using a certificate with no principals
	signer, err = getSignerForUserCert([]byte(testCertNoPrincipals))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the user does not exist
	signer, err = getSignerForUserCert([]byte(testCertValid))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	// now login with a username not in the set of valid principals for the given certificate
	u.Username += "1"
	user, _, err = httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	signer, err = getSignerForUserCert([]byte(testCertValid))
	assert.NoError(t, err)
	conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	err = os.WriteFile(revokeUserCerts, []byte(`[]`), 0644)
	assert.NoError(t, err)
	err = sftpd.Reload()
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMultiStepLoginKeyAndPwd checks a user whose only allowed SSH method is
// public key followed by password. Public key alone and password alone must
// both fail. Key then password succeeds on the default listener. It fails on
// sftpSrvAddr2222, where password auth is disabled, and in the reversed order.
func TestMultiStepLoginKeyAndPwd(t *testing.T) {
	u := getTestUser(true)
	u.Password = defaultPassword
	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}...)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, true)
	if !assert.Error(t, err, "login with public key is disallowed and must fail") {
		client.Close()
		conn.Close()
	}
	// usePubKey=false: exercise the denied single-step password method
	conn, client, err = getSftpClient(user, false)
	if !assert.Error(t, err, "login with password is disallowed and must fail") {
		client.Close()
		conn.Close()
	}
	signer, _ := ssh.ParsePrivateKey([]byte(testPrivateKey))
	authMethods := []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.Password(defaultPassword),
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods, "")
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods, sftpSrvAddr2222)
	if !assert.Error(t, err, "password auth is disabled on port 2222, multi-step auth must fail") {
		client.Close()
		conn.Close()
	}
	authMethods = []ssh.AuthMethod{
		ssh.Password(defaultPassword),
		ssh.PublicKeys(signer),
	}
	_, _, err = getCustomAuthSftpClient(user, authMethods, "")
	assert.Error(t, err, "multi step auth login with wrong order must fail")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMultiStepLoginKeyAndKeyInt checks multi-step public key + keyboard
// interactive logins. The user is first allowed only that combination: it must
// work on both listeners and fail in the wrong order or with key + password.
// The allowed combination is then switched to key + password, which must fail
// on sftpSrvAddr2222 and succeed on the default listener.
func TestMultiStepLoginKeyAndKeyInt(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	u := getTestUser(true)
	u.Password = defaultPassword
	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
		dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}...)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, true)
	if !assert.Error(t, err, "login with public key is disallowed and must fail") {
		client.Close()
		conn.Close()
	}
	signer, _ := ssh.ParsePrivateKey([]byte(testPrivateKey))
	authMethods := []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) {
			return []string{"1", "2"}, nil
		}),
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods, "")
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods, sftpSrvAddr2222)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	authMethods = []ssh.AuthMethod{
		ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) {
			return []string{"1", "2"}, nil
		}),
		ssh.PublicKeys(signer),
	}
	_, _, err = getCustomAuthSftpClient(user, authMethods, "")
	assert.Error(t, err, "multi step auth login with wrong order must fail")
	authMethods = []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.Password(defaultPassword),
	}
	_, _, err = getCustomAuthSftpClient(user, authMethods, "")
	assert.Error(t, err, "multi step auth login with wrong method must fail")
	user.Filters.DeniedLoginMethods = nil
	user.Filters.DeniedLoginMethods = append(user.Filters.DeniedLoginMethods, []string{
		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}...)
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	_, _, err = getCustomAuthSftpClient(user, authMethods, sftpSrvAddr2222)
	assert.Error(t, err)
	conn, client, err = getCustomAuthSftpClient(user, authMethods, "")
	if assert.NoError(t, err) {
		assert.NoError(t, checkBasicSFTP(client))
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMultiStepLoginCertAndPwd checks multi-step certificate + password login.
// A valid user cert succeeds; a cert restricted to another source address fails.
func TestMultiStepLoginCertAndPwd(t *testing.T) {
	u := getTestUser(true)
	u.Password = defaultPassword
	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}...)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	signer, err := getSignerForUserCert([]byte(testCertValid))
	assert.NoError(t, err)
	authMethods := []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.Password(defaultPassword),
	}
	conn, client, err := getCustomAuthSftpClient(user, authMethods, "")
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	signer, err = getSignerForUserCert([]byte(testCertOtherSourceAddress))
	assert.NoError(t, err)
	authMethods = []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.Password(defaultPassword),
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods, "")
	if !assert.Error(t, err) {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginUserStatus checks that a successful login updates LastLogin and that
// login fails once the user's Status is set to 0.
func TestLoginUserStatus(t *testing.T) {
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin)
	}
	user.Status = 0
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if !assert.Error(t, err, "login for a disabled user must fail") {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginUserExpiration checks that login fails with an expiration date two
// minutes in the past and works again with one two minutes in the future.
func TestLoginUserExpiration(t *testing.T) {
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin)
	}
	user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) - 120000
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if !assert.Error(t, err, "login for an expired user must fail") {
		client.Close()
		conn.Close()
	}
	user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) + 120000
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginWithDatabaseCredentials checks that GCS credentials submitted as a
// plain secret are returned with secretbox status and a non-empty payload, and
// that the user can still log in.
func TestLoginWithDatabaseCredentials(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "testbucket"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account", "private_key": " ", "client_email": "example@iam.gserviceaccount.com" }`)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.NotEmpty(t, user.FsConfig.GCSConfig.Credentials.GetPayload())
	assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData())
	assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey())
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginInvalidFs checks that login fails when the user's GCS credentials
// are not valid JSON.
func TestLoginInvalidFs(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "test"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if !assert.Error(t, err, "login must fail, the user has an invalid filesystem config") {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestDeniedProtocols checks that SFTP login fails when SSH is a denied
// protocol and works when only FTP and WebDAV are denied.
func TestDeniedProtocols(t *testing.T) {
	u := getTestUser(true)
	u.Filters.DeniedProtocols = []string{common.ProtocolSSH}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, true)
	if !assert.Error(t, err, "SSH protocol is disabled, authentication must fail") {
		client.Close()
		conn.Close()
	}
	user.Filters.DeniedProtocols = []string{common.ProtocolFTP, common.ProtocolWebDAV}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, true)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestDeniedLoginMethods checks that public key and password logins are
// rejected exactly when their login method is in DeniedLoginMethods.
func TestDeniedLoginMethods(t *testing.T) {
	u := getTestUser(true)
	u.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.LoginMethodPassword}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, true)
	if !assert.Error(t, err, "public key login is disabled, authentication must fail") {
		client.Close()
		conn.Close()
	}
	user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.LoginMethodPassword}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, true)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	user.Password = defaultPassword
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, false)
	if !assert.Error(t, err, "password login is disabled, authentication must fail") {
		client.Close()
		conn.Close()
	}
	user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodPublicKey}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, false)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginWithIPFilters checks the per-user IP filters. Login from the test
// client works when its address is not denied and when 127.0.0.0/8 is allowed.
// It fails when only 172.19.0.0/16 is allowed.
func TestLoginWithIPFilters(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Filters.DeniedIP = []string{"192.167.0.0/24", "172.18.0.0/16"}
	u.Filters.AllowedIP = []string{}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin)
	}
	user.Filters.AllowedIP = []string{"127.0.0.0/8"}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	user.Filters.AllowedIP = []string{"172.19.0.0/16"}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if !assert.Error(t, err, "login from an not allowed IP must fail") {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginEmptyPassword checks that password login fails for a user that was
// created without a password.
func TestLoginEmptyPassword(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.Password = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// NOTE(review): "empty" is presumably a sentinel the helper maps to an
	// empty password — confirm against getSftpClient
	user.Password = "empty"
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginAnonymousUser checks that an anonymous user is stored with list and
// download permissions on "/", SSH and HTTP denied, and the listed SSH/TLS login
// methods denied. SFTP login must fail.
func TestLoginAnonymousUser(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.Password = ""
	u.Filters.IsAnonymous = true
	// AddUser reports an error, yet the user is created (see the lookup below).
	// NOTE(review): presumably because the stored user differs from the request.
	_, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.Error(t, err)
	user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.IsAnonymous)
	assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"])
	assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols)
	assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate,
		dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods)
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestAnonymousGroupInheritance checks that a user whose primary group is
// anonymous cannot log in over SFTP.
func TestAnonymousGroupInheritance(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.IsAnonymous = true
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	usePubKey := false
	u := getTestUser(usePubKey)
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginAfterUserUpdateEmptyPwd checks that updating a user with an empty
// password keeps the existing password, so password login still works.
func TestLoginAfterUserUpdateEmptyPwd(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	user.Password = "" // password should remain unchanged
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginKeyboardInteractiveAuth checks keyboard-interactive login driven by
// the external script at keyIntAuthPath. It succeeds for an enabled user. It
// fails for a disabled user, for a script result of -1, and for script
// variants generated with the bad-JSON flag set.
func TestLoginKeyboardInteractiveAuth(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	user, _, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
	assert.NoError(t, err)
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm)
	assert.NoError(t, err)
	conn, client, err := getKeyboardInteractiveSftpClient(user, []string{"1", "2"})
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	user.Status = 0
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"})
	if !assert.Error(t, err, "keyboard interactive auth must fail the user is disabled") {
		client.Close()
		conn.Close()
	}
	user.Status = 1
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, -1), os.ModePerm)
	assert.NoError(t, err)
	conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"})
	if !assert.Error(t, err, "keyboard interactive auth must fail the script returned -1") {
		client.Close()
		conn.Close()
	}
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, true, 1), os.ModePerm)
	assert.NoError(t, err)
	conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"})
	if !assert.Error(t, err, "keyboard interactive auth must fail the script returned bad json") {
		client.Close()
		conn.Close()
	}
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 5, true, 1), os.ModePerm)
	assert.NoError(t, err)
	conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"})
	if !assert.Error(t, err, "keyboard interactive auth must fail the script returned bad json") {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestInteractiveLoginWithPasscode checks keyboard-interactive login that uses
// the builtin password check, and later also a TOTP passcode check. Correct
// credentials succeed and a wrong password fails. Correct credentials also fail
// when the script's final result is 0. A TOTP passcode cannot be reused.
func TestInteractiveLoginWithPasscode(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	user, _, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
	assert.NoError(t, err)
	// test password check
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(false, 1), os.ModePerm)
	assert.NoError(t, err)
	conn, client, err := getKeyboardInteractiveSftpClient(user, []string{defaultPassword})
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	// wrong password
	_, _, err = getKeyboardInteractiveSftpClient(user, []string{"wrong_password"})
	assert.Error(t, err)
	// correct password but the script returns an error: send the real password
	// so the failure comes from the script result, not from the password check
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(false, 0), os.ModePerm)
	assert.NoError(t, err)
	_, _, err = getKeyboardInteractiveSftpClient(user, []string{defaultPassword})
	assert.Error(t, err)
	// add multi-factor authentication
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	passcode, err := totp.GenerateCodeCustom(key.Secret(), time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: otp.AlgorithmSHA1,
	})
	assert.NoError(t, err)
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(true, 1), os.ModePerm)
	assert.NoError(t, err)
	passwordAsked := false
	passcodeAsked := false
	authMethods := []ssh.AuthMethod{
		ssh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
			var answers []string
			if strings.HasPrefix(questions[0], "Password") {
				answers = append(answers, defaultPassword)
				passwordAsked = true
			} else {
				answers = append(answers, passcode)
				passcodeAsked = true
			}
			return answers, nil
		}),
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods, "")
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	assert.True(t, passwordAsked)
	assert.True(t, passcodeAsked)
	// the same passcode cannot be reused
	_, _, err = getCustomAuthSftpClient(user, authMethods, "")
	assert.Error(t, err)
	// correct passcode but the script returns an error
	configName, key, _, err = mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	passcode, err = totp.GenerateCodeCustom(key.Secret(), time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: otp.AlgorithmSHA1,
	})
	assert.NoError(t, err)
	err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(true, 0), os.ModePerm)
	assert.NoError(t, err)
	passwordAsked = false
	passcodeAsked = false
	_, _, err = getCustomAuthSftpClient(user, authMethods, "")
	assert.Error(t, err)
	authMethods = []ssh.AuthMethod{
		ssh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
			var answers []string
			if strings.HasPrefix(questions[0], "Password") {
				answers = append(answers, defaultPassword)
				passwordAsked = true
			} else {
				answers = append(answers, passcode)
				passcodeAsked = true
			}
			return answers, nil
		}),
	}
	_, _, err = getCustomAuthSftpClient(user, authMethods, "")
	assert.Error(t, err)
	assert.True(t, passwordAsked)
	assert.True(t, passcodeAsked)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMustChangePasswordRequirement checks that a user flagged with
// RequirePasswordChange can log in with a public key but not with a password
// until the password has been updated.
func TestMustChangePasswordRequirement(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Filters.RequirePasswordChange = true
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// public key auth works even if the user must change password
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	// password auth does not work
	_, _, err = getSftpClient(user, false)
	assert.Error(t, err)
	// change password
	err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, false)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSecondFactorRequirement checks that a user requiring two-factor auth for
// SSH cannot log in with a public key until a TOTP config is enabled for SSH.
func TestSecondFactorRequirement(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestNamingRules re-initializes the data provider with NamingRules = 7. It
// checks that "useR@user.com " is stored as "user@user.com" and that public key
// and password logins work with the original, non-normalized username. The
// default provider configuration is restored at the end.
func TestNamingRules(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.NamingRules = 7
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Username = "useR@user.com "
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, "user@user.com", user.Username)
	conn, client, err := getSftpClient(u, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	u.Password = defaultPassword
	_, _, err = httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(u, false)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	_, err = httpdtest.RemoveUser(u, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(u.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

func TestPreLoginScript(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	folderName := filepath.Base(mappedPath)
	folderMountPath := "/vpath"
	u.VirtualFolders = append(u.VirtualFolders,
vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: folderMountPath, }) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) providerConf.PreLoginHook = preLoginPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err = httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) testFilePath := filepath.Join(homeBasePath, testFileName) testData := []byte("test data") err = os.WriteFile(testFilePath, testData, 0666) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(folderMountPath, testFileName), int64(len(testData)), client) assert.NoError(t, err) info, err := os.Stat(filepath.Join(mappedPath, testFileName)) if assert.NoError(t, err) { assert.Greater(t, info.Size(), int64(len(testData))) } } err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "pre-login script returned a non json response, login must fail") { client.Close() conn.Close() } // now disable the the hook user.Filters.Hooks.PreLoginDisabled = true user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer 
conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user.Filters.Hooks.PreLoginDisabled = false user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user.Status = 0 err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "pre-login script returned a disabled user, login must fail") { client.Close() conn.Close() } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(preLoginPath) assert.NoError(t, err) } func TestPreLoginUserCreation(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) u.Permissions["/list"] = []string{"list", "download"} err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) providerConf.PreLoginHook = preLoginPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err := 
httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Permissions, 2) assert.Empty(t, user.Description) u.Description = "some desc" delete(u.Permissions, "/list") err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) // The user should be updated and list permission removed conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Permissions, 1) assert.NotEmpty(t, user.Description) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(preLoginPath) assert.NoError(t, err) } func TestPreLoginHookPreserveMFAConfig(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) providerConf.PreLoginHook = preLoginPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } // add multi-factor authentication user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) 
assert.Len(t, user.Filters.RecoveryCodes, 0) assert.False(t, user.Filters.TOTPConfig.Enabled) configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) assert.NoError(t, err) user.Password = defaultPassword user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), Protocols: []string{common.ProtocolSSH}, } for i := 0; i < 12; i++ { user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))), }) } err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Filters.RecoveryCodes, 12) assert.True(t, user.Filters.TOTPConfig.Enabled) assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) err = os.WriteFile(extAuthPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Filters.RecoveryCodes, 12) assert.True(t, user.Filters.TOTPConfig.Enabled) assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) _, err = 
httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(preLoginPath) assert.NoError(t, err) } func TestPreDownloadHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } oldExecuteOn := common.Config.Actions.ExecuteOn oldHook := common.Config.Actions.Hook common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} common.Config.Actions.Hook = preDownloadPath usePubKey := true u := getTestUser(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) testFileSize := int64(131072) testFilePath := filepath.Join(homeBasePath, testFileName) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) } remoteSCPDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) assert.NoError(t, err) err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Remove(testFileName) assert.NoError(t, err) err = 
sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.ErrorIs(t, err, os.ErrPermission) } err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) assert.Error(t, err) common.Config.Actions.Hook = "http://127.0.0.1:8080/web/admin/login" conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Remove(testFileName) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) } err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) assert.NoError(t, err) common.Config.Actions.Hook = "http://127.0.0.1:8080/" conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Remove(testFileName) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.ErrorIs(t, err, os.ErrPermission) } err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook } func TestPreUploadHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } oldExecuteOn := common.Config.Actions.ExecuteOn oldHook := common.Config.Actions.Hook common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} 
common.Config.Actions.Hook = preUploadPath usePubKey := true u := getTestUser(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(preUploadPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) testFileSize := int64(131072) testFilePath := filepath.Join(homeBasePath, testFileName) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) } remoteSCPUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) err = scpUpload(testFilePath, remoteSCPUpPath, true, false) assert.NoError(t, err) err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.ErrorIs(t, err, os.ErrPermission) err = sftpUploadFile(testFilePath, testFileName+"1", testFileSize, client) assert.ErrorIs(t, err, os.ErrPermission) } err = scpUpload(testFilePath, remoteSCPUpPath, true, false) assert.Error(t, err) common.Config.Actions.Hook = "http://127.0.0.1:8080/web/client/login" conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) err 
= sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) } err = scpUpload(testFilePath, remoteSCPUpPath, true, false) assert.NoError(t, err) common.Config.Actions.Hook = "http://127.0.0.1:8080/web" conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.ErrorIs(t, err, os.ErrPermission) err = sftpUploadFile(testFilePath, testFileName+"1", testFileSize, client) assert.ErrorIs(t, err, os.ErrPermission) } err = scpUpload(testFilePath, remoteSCPUpPath, true, false) assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook } func TestPostConnectHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } common.Config.PostConnectHook = postConnectPath usePubKey := true u := getTestUser(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err) { client.Close() conn.Close() } common.Config.PostConnectHook = "http://127.0.0.1:8080/healthz" conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) 
} common.Config.PostConnectHook = "http://127.0.0.1:8080/notfound" conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err) { client.Close() conn.Close() } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.PostConnectHook = "" } func TestCheckPwdHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 1000 err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(checkPwdPath, getCheckPwdScriptsContents(2, defaultPassword), os.ModePerm) assert.NoError(t, err) providerConf.CheckPasswordHook = checkPwdPath providerConf.CheckPasswordScope = 1 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { err = checkBasicSFTP(client) assert.NoError(t, err) client.Close() conn.Close() } err = os.WriteFile(checkPwdPath, getCheckPwdScriptsContents(0, defaultPassword), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if !assert.Error(t, err) { client.Close() conn.Close() } // now disable the the hook user.Filters.Hooks.CheckPasswordDisabled = true user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { err = checkBasicSFTP(client) assert.NoError(t, err) client.Close() conn.Close() } // enable the hook again user.Filters.Hooks.CheckPasswordDisabled = false user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) err = os.WriteFile(checkPwdPath, getCheckPwdScriptsContents(1, ""), os.ModePerm) 
assert.NoError(t, err) user.Password = defaultPassword + "1" conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { err = checkBasicSFTP(client) assert.NoError(t, err) client.Close() conn.Close() } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) providerConf.CheckPasswordScope = 6 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) user, _, err = httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user.Password = defaultPassword + "1" conn, client, err = getSftpClient(user, usePubKey) if !assert.Error(t, err) { client.Close() conn.Close() } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(checkPwdPath) assert.NoError(t, err) } func TestLoginExternalAuthPwdAndPubKey(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := true u := getTestUser(usePubKey) u.QuotaFiles = 1000 err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) testFileSize := int64(65535) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, 
testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } u.Username = defaultUsername + "1" conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "external auth login with invalid user must fail") { client.Close() conn.Close() } usePubKey = false u = getTestUser(usePubKey) u.PublicKeys = []string{} err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, len(user.PublicKeys)) assert.Equal(t, testFileSize, user.UsedQuotaSize) assert.Equal(t, 1, user.UsedQuotaFiles) u.Status = 0 err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err) { client.Close() conn.Close() } // now disable the the hook user.Filters.Hooks.ExternalAuthDisabled = true user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthMultiStepLoginKeyAndPwd(t *testing.T) { if 
runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser(true) u.Password = defaultPassword u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{ dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.SSHLoginMethodPublicKey, dataprovider.LoginMethodPassword, dataprovider.SSHLoginMethodKeyboardInteractive, }...) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) assert.NoError(t, err) authMethods := []ssh.AuthMethod{ ssh.PublicKeys(signer), ssh.Password(defaultPassword), } conn, client, err := getCustomAuthSftpClient(u, authMethods, "") if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } // wrong sequence should fail authMethods = []ssh.AuthMethod{ ssh.Password(defaultPassword), ssh.PublicKeys(signer), } _, _, err = getCustomAuthSftpClient(u, authMethods, "") assert.Error(t, err) // public key only auth must fail _, _, err = getSftpClient(u, true) assert.Error(t, err) // password only auth must fail _, _, err = getSftpClient(u, false) assert.Error(t, err) _, err = httpdtest.RemoveUser(u, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(u.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthEmptyResponse(t 
*testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 1000 err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) testFileSize := int64(65535) // the user will be created conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, len(user.PublicKeys)) assert.Equal(t, testFileSize, user.UsedQuotaSize) assert.Equal(t, 1, user.UsedQuotaFiles) // now modify the user user.MaxSessions = 10 user.QuotaFiles = 100 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 10, user.MaxSessions) assert.Equal(t, 100, user.QuotaFiles) // the auth script accepts any password and returns an empty response, the // user password must be updated 
u.Password = defaultUsername conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = checkBasicSFTP(client) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthDifferentUsername(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false extAuthUsername := "common_user" u := getTestUser(usePubKey) u.QuotaFiles = 1000 err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, extAuthUsername), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) // the user logins using "defaultUsername" and the external auth returns "extAuthUsername" testFileSize := int64(65535) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } // logins again to test that used quota is preserved conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = 
checkBasicSFTP(client) assert.NoError(t, err) } _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) assert.NoError(t, err) user, _, err := httpdtest.GetUserByUsername(extAuthUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, len(user.PublicKeys)) assert.Equal(t, testFileSize, user.UsedQuotaSize) assert.Equal(t, 1, user.UsedQuotaFiles) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestLoginExternalAuth(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } mappedPath := filepath.Join(os.TempDir(), "vdir1") folderName := filepath.Base(mappedPath) extAuthScopes := []int{1, 2} for _, authScope := range extAuthScopes { var usePubKey bool if authScope == 1 { usePubKey = false } else { usePubKey = true } u := getTestUser(usePubKey) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: "/vpath", QuotaFiles: 1 + authScope, QuotaSize: 10 + int64(authScope), }) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = authScope err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err = httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, 
err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } if !usePubKey { dbUser, err := dataprovider.UserExists(defaultUsername, "") assert.NoError(t, err) found, match := dataprovider.CheckCachedUserPassword(defaultUsername, defaultPassword, dbUser.Password) assert.True(t, found) assert.True(t, match) } u.Username = defaultUsername + "1" conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "external auth login with invalid user must fail") { client.Close() conn.Close() } usePubKey = !usePubKey u = getTestUser(usePubKey) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "external auth login with valid user but invalid auth scope must fail") { client.Close() conn.Close() } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) if assert.Len(t, user.VirtualFolders, 1) { folder := user.VirtualFolders[0] assert.Equal(t, folderName, folder.Name) assert.Equal(t, mappedPath, folder.MappedPath) assert.Equal(t, 1+authScope, folder.QuotaFiles) assert.Equal(t, 10+int64(authScope), folder.QuotaSize) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } } func TestLoginExternalAuthCache(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser(false) u.Filters.ExternalAuthCacheTime = 120 err := dataprovider.Close() assert.NoError(t, err) err = 
config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 1 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) conn, client, err := getSftpClient(u, false) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) lastLogin := user.LastLogin assert.Greater(t, lastLogin, int64(0)) assert.Equal(t, u.Filters.ExternalAuthCacheTime, user.Filters.ExternalAuthCacheTime) // the auth should be now cached so update the hook to return an error err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, true, false, ""), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, false) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, lastLogin, user.LastLogin) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestLoginExternalAuthInteractive(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := 
config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 4 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm) assert.NoError(t, err) conn, client, err := getKeyboardInteractiveSftpClient(u, []string{"1", "2"}) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } u.Username = defaultUsername + "1" conn, client, err = getKeyboardInteractiveSftpClient(u, []string{"1", "2"}) if !assert.Error(t, err, "external auth login with invalid user must fail") { client.Close() conn.Close() } usePubKey = true u = getTestUser(usePubKey) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "external auth login with valid user but invalid auth scope must fail") { client.Close() conn.Close() } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestLoginExternalAuthErrors(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := true u := getTestUser(usePubKey) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, true, false, 
""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if !assert.Error(t, err, "login must fail, external auth returns a non json response") { client.Close() conn.Close() } usePubKey = false u = getTestUser(usePubKey) conn, client, err = getSftpClient(u, usePubKey) if !assert.Error(t, err, "login must fail, external auth returns a non json response") { client.Close() conn.Close() } _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthReturningAnonymousUser(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) u.Filters.IsAnonymous = true u.Password = "" err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) _, _, err = getSftpClient(u, usePubKey) assert.Error(t, err) user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.True(t, user.Filters.IsAnonymous) assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) assert.Equal(t, []string{common.ProtocolSSH, 
common.ProtocolHTTP}, user.Filters.DeniedProtocols) assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) // test again, the user now exists _, _, err = getSftpClient(u, usePubKey) assert.Error(t, err) updatedUser, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) user.UpdatedAt = updatedUser.UpdatedAt user.LastPasswordChange = updatedUser.LastPasswordChange assert.Equal(t, user, updatedUser) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthPreserveMFAConfig(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } usePubKey := false u := getTestUser(usePubKey) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } // add multi-factor authentication user, _, err := 
httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Filters.RecoveryCodes, 0) assert.False(t, user.Filters.TOTPConfig.Enabled) configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) assert.NoError(t, err) user.Password = defaultPassword user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ Enabled: true, ConfigName: configName, Secret: kms.NewPlainSecret(key.Secret()), Protocols: []string{common.ProtocolSSH}, } for i := 0; i < 12; i++ { user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))), }) } err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) // login again and check that the MFA configs are preserved conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Filters.RecoveryCodes, 12) assert.True(t, user.Filters.TOTPConfig.Enabled) assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm) assert.NoError(t, err) conn, client, err = getSftpClient(u, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) } user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Len(t, user.Filters.RecoveryCodes, 12) assert.True(t, user.Filters.TOTPConfig.Enabled) assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) assert.Equal(t, 
[]string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestQuotaDisabledError(t *testing.T) { err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.TrackQuota = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName+"1", testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName+"1", testFileName+".rename") //nolint:goconst assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, 
configDir, true) assert.NoError(t, err) } //nolint:dupl func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) usePubKey := true user := getTestUser(usePubKey) err := dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) user.Password = "" conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { assert.NoError(t, checkBasicSFTP(client)) s, c, err := getSftpClient(user, usePubKey) if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") { c.Close() s.Close() } err = client.Close() assert.NoError(t, err) err = conn.Close() assert.NoError(t, err) } err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.MaxTotalConnections = oldValue } //nolint:dupl func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) usePubKey := true user := getTestUser(usePubKey) err := dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) user.Password = "" conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { assert.NoError(t, checkBasicSFTP(client)) s, c, err := getSftpClient(user, usePubKey) if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") { c.Close() s.Close() } err = client.Close() assert.NoError(t, err) err = conn.Close() assert.NoError(t, err) } err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.MaxPerHostConnections = oldValue } func 
TestMaxTransfers(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 2 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) usePubKey := true user := getTestUser(usePubKey) err := dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) user.Password = "" conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { assert.NoError(t, checkBasicSFTP(client)) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) f1, err := client.Create("file1") assert.NoError(t, err) f2, err := client.Create("file2") assert.NoError(t, err) _, err = f1.Write([]byte(" ")) assert.NoError(t, err) _, err = f2.Write([]byte(" ")) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error()) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") err = scpUpload(testFilePath, remoteUpPath, false, false) assert.Error(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error()) remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) err = scpDownload(localDownloadPath, remoteDownPath, false, false) assert.Error(t, err) err = f1.Close() assert.NoError(t, err) err = f2.Close() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = client.Close() assert.NoError(t, err) err = conn.Close() assert.NoError(t, err) } err = dataprovider.DeleteUser(user.Username, "", 
"", "") assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return common.Connections.GetTotalTransfers() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) common.Config.MaxPerHostConnections = oldValue } func TestMaxSessions(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Username += "1" u.MaxSessions = 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) s, c, err := getSftpClient(user, usePubKey) if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") { c.Close() s.Close() } } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSupportedExtensions(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() assert.NoError(t, checkBasicSFTP(client)) v, ok := client.HasExtension("statvfs@openssh.com") assert.Equal(t, "2", v) assert.True(t, ok) _, ok = client.HasExtension("hardlink@openssh.com") assert.False(t, ok) _, ok = client.HasExtension("posix-rename@openssh.com") assert.False(t, ok) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestQuotaFileReplace(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 1000 localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.QuotaFiles = 1000 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileSize := int64(65535) 
testFilePath := filepath.Join(homeBasePath, testFileName) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { //nolint:dupl defer conn.Close() defer client.Close() expectedQuotaSize := testFileSize expectedQuotaFiles := 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) // now replace the same file, the quota must not change err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) // now create a symlink, replace it with a file and check the quota // replacing a symlink is like uploading a new file err = client.Symlink(testFileName, testFileName+".link") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) expectedQuotaFiles++ expectedQuotaSize += testFileSize err = sftpUploadFile(testFilePath, testFileName+".link", testFileSize, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) } // now set a quota size restriction and upload the same file, upload should fail for space limit exceeded user.QuotaSize = testFileSize*2 - 1 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, 
testFileName, testFileSize, client) assert.Error(t, err, "quota size exceeded, file upload must fail") err = client.Remove(testFileName) assert.NoError(t, err) } if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 user.QuotaSize = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestQuotaRename(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.QuotaFiles = 1000 localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.QuotaFiles = 1000 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileSize := int64(65535) testFileSize1 := int64(65537) testFileName1 := "test_file1.dat" //nolint:goconst testFilePath := filepath.Join(homeBasePath, testFileName) testFilePath1 := filepath.Join(homeBasePath, testFileName1) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName, testFileName+".rename") assert.NoError(t, err) err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) assert.NoError(t, err) user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) err = client.Rename(testFileName1, testFileName+".rename") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize1, user.UsedQuotaSize) err = client.Symlink(testFileName+".rename", testFileName+".symlink") //nolint:goconst assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) // overwrite a symlink err = client.Rename(testFileName, testFileName+".symlink") assert.NoError(t, err) err = client.Mkdir("testdir") assert.NoError(t, err) err = client.Rename("testdir", "testdir1") assert.NoError(t, err) err = client.Mkdir("testdir") assert.NoError(t, err) err = client.Rename("testdir", "testdir1") assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) testDir := "tdir" err = client.Mkdir(testDir) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(testDir, testFileName1), testFileSize1, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 4, user.UsedQuotaFiles) assert.Equal(t, testFileSize*2+testFileSize1*2, user.UsedQuotaSize) err = client.Rename(testDir, testDir+"1") assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 4, user.UsedQuotaFiles) assert.Equal(t, testFileSize*2+testFileSize1*2, user.UsedQuotaSize) if user.Username == 
defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(testFilePath1) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestQuotaScan(t *testing.T) { usePubKey := false user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) testFileSize := int64(65535) expectedQuotaSize := user.UsedQuotaSize + testFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) // create user with the same home dir, so there is at least an untracked file user, _, err = httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) assert.NoError(t, err) assert.Eventually(t, func() bool { scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) if err == nil { return len(scans) == 0 } return false }, 1*time.Second, 50*time.Millisecond) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, 
user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestMultipleQuotaScans(t *testing.T) { res := common.QuotaScans.AddUserQuotaScan(defaultUsername, "") assert.True(t, res) res = common.QuotaScans.AddUserQuotaScan(defaultUsername, "") assert.False(t, res, "add quota must fail if another scan is already active") assert.True(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername)) activeScans := common.QuotaScans.GetUsersQuotaScans("") assert.Equal(t, 0, len(activeScans)) assert.False(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername)) } func TestQuotaLimits(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 1 localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.QuotaFiles = 1 sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileSize := int64(65535) testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) testFileSize1 := int64(131072) testFileName1 := "test_file1.dat" testFilePath1 := filepath.Join(homeBasePath, testFileName1) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) testFileSize2 := int64(32768) testFileName2 := "test_file2.dat" //nolint:goconst testFilePath2 := filepath.Join(homeBasePath, testFileName2) err = createTestFile(testFilePath2, testFileSize2) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { // test quota files conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client) //nolint:goconst assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName+".quota.1", 
testFileSize, client) if assert.Error(t, err, "user is over quota files, upload must fail") { assert.Contains(t, err.Error(), "SSH_FX_FAILURE") assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) } // rename should work err = client.Rename(testFileName+".quota", testFileName) assert.NoError(t, err) } // test quota size user.QuotaSize = testFileSize - 1 user.QuotaFiles = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName+".quota.1", testFileSize, client) if assert.Error(t, err, "user is over quota size, upload must fail") { assert.Contains(t, err.Error(), "SSH_FX_FAILURE") assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) } err = client.Rename(testFileName, testFileName+".quota") assert.NoError(t, err) err = client.Rename(testFileName+".quota", testFileName) assert.NoError(t, err) } // now test quota limits while uploading the current file, we have 1 bytes remaining user.QuotaSize = testFileSize + 1 user.QuotaFiles = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) if assert.Error(t, err) { assert.Contains(t, err.Error(), "SSH_FX_FAILURE") if user.Username == localUser.Username { assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) } } _, err = client.Stat(testFileName1) assert.Error(t, err) _, err = client.Lstat(testFileName1) assert.Error(t, err) // overwriting an existing file will work if the resulting size is lesser or equal than the current one err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath2, testFileName, 
testFileSize2, client)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath1, testFileName, testFileSize1, client)
			assert.Error(t, err)
			_, err := client.Stat(testFileName)
			assert.Error(t, err)
		}
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(testFilePath1)
	assert.NoError(t, err)
	err = os.Remove(testFilePath2)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestTransferQuotaLimits checks that the per-user upload and download data
// transfer limits are enforced over SFTP. It covers a transfer that crosses
// the limit while it is running and a transfer started after the limit has
// already been reached.
func TestTransferQuotaLimits(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	// NOTE(review): the final assertions compare the used transfer with
	// 1024*1024 bytes, so these limits are presumably expressed in MB - confirm.
	u.DownloadDataTransfer = 1
	u.UploadDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	// two transfers of this size together exceed 1024*1024 bytes
	testFileSize := int64(550000)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// the first upload and the first download fit within the limits
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		// error while download is active
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error())
		}
		// error before starting the download
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error())
		}
		// error while upload is active
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error())
		}
		// error before starting the upload
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error())
		}
	}
	// reload the user to read the updated transfer counters
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, user.UsedDownloadDataTransfer, int64(1024*1024))
	assert.Greater(t, user.UsedUploadDataTransfer, int64(1024*1024))
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(localDownloadPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUploadMaxSize checks that Filters.MaxUploadFileSize rejects uploads
// bigger than the limit, including uploads overwriting an existing file.
func TestUploadMaxSize(t *testing.T) {
	testFileSize := int64(65535)
	usePubKey := false
	u := getTestUser(usePubKey)
	// testFileSize fits the limit, testFileSize1 below does not
	u.Filters.MaxUploadFileSize = testFileSize + 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	testFileSize1 := int64(131072)
	testFileName1 := "test_file1.dat"
	testFilePath1 := filepath.Join(homeBasePath, testFileName1)
	err = createTestFile(testFilePath1, testFileSize1)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.Error(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		// now test overwrite an existing file with a size bigger than the allowed one
		err =
createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) assert.Error(t, err) } err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(testFilePath1) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestBandwidthAndConnections(t *testing.T) { usePubKey := false testFileSize := int64(524288) u := getTestUser(usePubKey) u.UploadBandwidth = 120 u.DownloadBandwidth = 100 wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth // 100 ms tolerance wantedUploadElapsed -= 100 wantedDownloadElapsed -= 100 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) startTime := time.Now() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) elapsed := time.Since(startTime).Nanoseconds() / 1000000 assert.GreaterOrEqual(t, elapsed, wantedUploadElapsed, "upload bandwidth throttling not respected") startTime = time.Now() localDownloadPath := filepath.Join(homeBasePath, testDLFileName) c := sftpDownloadNonBlocking(testFileName, localDownloadPath, testFileSize, client) waitForActiveTransfers(t) // wait some additional arbitrary time to wait for transfer activity to happen // it is need to reach all the code in CheckIdleConnections time.Sleep(100 * time.Millisecond) err = <-c assert.NoError(t, err) elapsed = time.Since(startTime).Nanoseconds() / 1000000 assert.GreaterOrEqual(t, elapsed, wantedDownloadElapsed, "download bandwidth throttling 
not respected") // test disconnection c = sftpUploadNonBlocking(testFilePath, testFileName+"_partial", testFileSize, client) waitForActiveTransfers(t) time.Sleep(100 * time.Millisecond) for _, stat := range common.Connections.GetStats("") { common.Connections.Close(stat.ConnectionID, "") } err = <-c assert.Error(t, err, "connection closed while uploading: the upload must fail") assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 10*time.Second, 200*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPatternsFilters(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileSize := int64(131072) testFilePath := filepath.Join(homeBasePath, testFileName) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName+".zip", testFileSize, client) assert.NoError(t, err) } user.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: "/", AllowedPatterns: []string{"*.zIp"}, DeniedPatterns: []string{}, }, } user.Filters.DisableFsChecks = true _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) conn, client, err = getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.Error(t, err) err = 
sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.Error(t, err)
		err = client.Rename(testFileName, testFileName+"1")
		assert.Error(t, err)
		err = client.Remove(testFileName)
		assert.Error(t, err)
		// names matching the allowed pattern can still be used
		err = sftpDownloadFile(testFileName+".zip", localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		err = client.Mkdir("dir.zip")
		assert.NoError(t, err)
		err = client.Rename("dir.zip", "dir1.zip")
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(localDownloadPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestVirtualFolders checks three things:
//   - a virtual folder mapped on a nested path (/vdir/subdir) is auto created
//     and usable for transfers;
//   - the virtual folder and its parent directories cannot be renamed,
//     removed, recreated or used as a symlink target;
//   - directory renames honor the per-directory permissions.
func TestVirtualFolders(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	folderName := filepath.Base(mappedPath)
	vdirPath := "/vdir/subdir"
	testDir := "/userDir"
	testDir1 := "/userDir1"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
	})
	// testDir only allows creating directories; testDir1 also allows uploads
	// and renames; testDir1/subdir only allows renames
	u.Permissions[testDir] = []string{dataprovider.PermCreateDirs}
	u.Permissions[testDir1] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermRename}
	u.Permissions[path.Join(testDir1, "subdir")] = []string{dataprovider.PermRename}
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// check virtual folder auto creation
		_, err = os.Stat(mappedPath)
		assert.NoError(t, err)
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		err = client.Rename(vdirPath, "new_name")
		assert.Error(t, err, "renaming a virtual folder must fail")
		err = client.RemoveDirectory(vdirPath)
		assert.Error(t, err, "removing a virtual folder must fail")
		err = client.Mkdir(vdirPath)
		assert.Error(t, err, "creating a virtual folder must fail")
		err = client.Symlink(path.Join(vdirPath, testFileName), vdirPath)
		assert.Error(t, err, "symlink to a virtual folder must fail")
		err = client.Rename("/vdir", "/vdir1")
		assert.Error(t, err, "renaming a directory with a virtual folder inside must fail")
		err = client.RemoveDirectory("/vdir")
		assert.Error(t, err, "removing a directory with a virtual folder inside must fail")
		err = client.Mkdir("vdir1")
		assert.NoError(t, err)
		// rename empty dir /vdir1, we have permission on /
		err = client.Rename("vdir1", "vdir2")
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join("vdir2", testFileName), testFileSize, client)
		assert.NoError(t, err)
		// we don't have rename permission in testDir and vdir2 contains a file
		err = client.Rename("vdir2", testDir)
		assert.Error(t, err)
		// the following steps move vdir2 to testDir1 and back, each time with
		// deeper contents
		err = client.Rename("vdir2", testDir1)
		assert.NoError(t, err)
		err = client.Rename(testDir1, "vdir2")
		assert.NoError(t, err)
		err = client.MkdirAll(path.Join("vdir2", "subdir"))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join("vdir2", "subdir", testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = client.Rename("vdir2", testDir1)
		assert.NoError(t, err)
		err = client.Rename(testDir1, "vdir2")
		assert.NoError(t, err)
		err = client.MkdirAll(path.Join("vdir2", "subdir", "subdir"))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join("vdir2", "subdir", "subdir", testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = client.Rename("vdir2", testDir1)
		assert.NoError(t, err)
		err = client.Rename(testDir1, "vdir3")
		assert.NoError(t, err)
		err = client.Remove(path.Join("vdir3", "subdir", "subdir", testFileName))
		assert.NoError(t, err)
		err = client.RemoveDirectory(path.Join("vdir3", "subdir", "subdir"))
		assert.NoError(t, err)
		err = client.Rename("vdir3", testDir1)
		assert.NoError(t, err)
		err = client.Rename(testDir1, "vdir2")
		assert.NoError(t, err)
		// a directory containing a symlink can be renamed as well
		err = client.Symlink(path.Join("vdir2", "subdir", testFileName), path.Join("vdir2", "subdir", "alink"))
		assert.NoError(t, err)
		err = client.Rename("vdir2", testDir1)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestVirtualFoldersQuotaLimit checks how user-level and folder-level quota
// limits interact for uploads and renames. The same scenario runs for two
// users (u1, u2) with different quota configurations.
func TestVirtualFoldersQuotaLimit(t *testing.T) {
	usePubKey := false
	u1 := getTestUser(usePubKey)
	u1.QuotaFiles = 1
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1" //nolint:goconst
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2" //nolint:goconst
	// NOTE(review): QuotaFiles/QuotaSize -1 presumably means the folder is
	// accounted within the user quota instead of having its own limits - confirm.
	u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  1,
		QuotaSize:   0,
	})
	testFileSize := int64(131072)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	err := createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	u2 :=
getTestUser(usePubKey)
	u2.QuotaSize = testFileSize + 1
	u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  0,
		QuotaSize:   testFileSize + 1,
	})
	users := []dataprovider.User{u1, u2}
	for _, u := range users {
		// the mapped paths and folders are recreated for each user and
		// removed again at the end of the iteration
		err = os.MkdirAll(mappedPath1, os.ModePerm)
		assert.NoError(t, err)
		err = os.MkdirAll(mappedPath2, os.ModePerm)
		assert.NoError(t, err)
		f1 := vfs.BaseVirtualFolder{
			Name:       folderName1,
			MappedPath: mappedPath1,
		}
		_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
		assert.NoError(t, err)
		f2 := vfs.BaseVirtualFolder{
			Name:       folderName2,
			MappedPath: mappedPath2,
		}
		_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
		assert.NoError(t, err)
		user, _, err := httpdtest.AddUser(u, http.StatusCreated)
		assert.NoError(t, err)
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			// NOTE(review): these defers run when the test function returns,
			// not at the end of this loop iteration, so the connection is
			// still open while the user and folders are removed below.
			defer conn.Close()
			defer client.Close()
			// one file in each virtual folder: after these uploads any
			// further upload must fail and must not leave a file behind
			err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.Error(t, err)
			_, err = client.Stat(testFileName)
			assert.Error(t, err)
			err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName+"1"), testFileSize, client)
			assert.Error(t, err)
			_, err = client.Stat(path.Join(vdirPath1, testFileName+"1"))
			assert.Error(t, err)
			err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+"1"), testFileSize, client)
			assert.Error(t, err)
			_, err = client.Stat(path.Join(vdirPath2, testFileName+"1"))
			assert.Error(t, err)
			// removing the file from vdirPath1 frees room for exactly one upload
			err = client.Remove(path.Join(vdirPath1, testFileName))
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
			assert.Error(t, err)
			// now test renames
			err = client.Rename(testFileName, path.Join(vdirPath1, testFileName))
			assert.NoError(t, err)
			err = client.Rename(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".rename"))
			assert.NoError(t, err)
			err = client.Rename(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".rename"))
			assert.NoError(t, err)
			err = client.Rename(path.Join(vdirPath2, testFileName+".rename"), testFileName+".rename")
			assert.Error(t, err)
			err = client.Rename(path.Join(vdirPath2, testFileName+".rename"), path.Join(vdirPath1, testFileName))
			assert.Error(t, err)
			err = client.Rename(path.Join(vdirPath1, testFileName+".rename"), path.Join(vdirPath2, testFileName))
			assert.Error(t, err)
			err = client.Rename(path.Join(vdirPath1, testFileName+".rename"), testFileName)
			assert.NoError(t, err)
		}
		_, err = httpdtest.RemoveUser(user, http.StatusOK)
		assert.NoError(t, err)
		_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
		assert.NoError(t, err)
		_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
		assert.NoError(t, err)
		err = os.RemoveAll(user.GetHomeDir())
		assert.NoError(t, err)
		err = os.RemoveAll(mappedPath1)
		assert.NoError(t, err)
		err = os.RemoveAll(mappedPath2)
		assert.NoError(t, err)
	}
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
}

// TestSFTPLoopSimple checks that SFTP filesystem configurations forming a
// loop make the login fail. It covers two users pointing at each other and a
// user pointing at itself.
func TestSFTPLoopSimple(t *testing.T) {
	usePubKey := false
	user1 := getTestSFTPUser(usePubKey)
	user2 := getTestSFTPUser(usePubKey)
	user1.Username += "1"
	user2.Username += "2"
	user1.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user2.FsConfig.Provider = sdk.SFTPFilesystemProvider
	// user1 uses user2 as SFTP backend and vice versa
	user1.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username:
user2.Username,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}
	user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username: user1.Username,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}
	user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	// user1 and user2 use each other as SFTP backend: both logins must fail
	_, _, err = getSftpClient(user1, usePubKey)
	assert.Error(t, err)
	_, _, err = getSftpClient(user2, usePubKey)
	assert.Error(t, err)
	// now point user1 to itself: the login must still fail
	user1.FsConfig.SFTPConfig.Username = user1.Username
	user1.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	_, _, err = httpdtest.UpdateUser(user1, http.StatusOK, "")
	assert.NoError(t, err)
	_, _, err = getSftpClient(user1, usePubKey)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user2.GetHomeDir())
	assert.NoError(t, err)
}

// TestSFTPLoopVirtualFolders checks that an SFTP virtual folder whose chain of
// SFTP backends leads back to the connecting user does not prevent the login.
// Only the looping virtual folder (/vdir) must be inaccessible.
func TestSFTPLoopVirtualFolders(t *testing.T) {
	usePubKey := false
	sftpFolderName := "sftp"
	user1 := getTestUser(usePubKey)
	user2 := getTestSFTPUser(usePubKey)
	user3 := getTestSFTPUser(usePubKey)
	user1.Username += "1"
	user2.Username += "2"
	user3.Username += "3"

	// user1 is a local account with a virtual SFTP folder to user2
	// user2 has user1 as SFTP fs
	user1.VirtualFolders = append(user1.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: sftpFolderName,
		},
		VirtualPath: "/vdir",
	})

	user2.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username: user1.Username,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}
	// user3 also has user1 as SFTP fs
	user3.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user3.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username: user1.Username,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}
	// the shared folder used by user1 as /vdir connects as user2
	f := vfs.BaseVirtualFolder{
		Name: sftpFolderName,
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint:          sftpServerAddr,
					Username:          user2.Username,
					EqualityCheckMode: 1,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	user3, resp, err = httpdtest.AddUser(user3, http.StatusCreated)
	assert.NoError(t, err, string(resp))

	// login will work but /vdir will not be accessible
	conn, client, err := getSftpClient(user1, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		_, err = client.ReadDir("/vdir")
		assert.Error(t, err)
	}

	// now make user2 a local account with an SFTP virtual folder to user3.
	// So we have:
	// user1 -> local account with the SFTP virtual folder /vdir to user2
	// user2 -> local account with the SFTP virtual folder /vdir2 to user3
	// user3 -> sftp user with user1 as fs
	// NOTE(review): /vdir2 reuses the name of the folder created above, so
	// whether this inline FsConfig is applied depends on how the data
	// provider handles existing folder names - confirm.
	user2.FsConfig.Provider = sdk.LocalFilesystemProvider
	user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
	user2.VirtualFolders = append(user2.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: sftpFolderName,
			FsConfig: vfs.Filesystem{
				Provider: sdk.SFTPFilesystemProvider,
				SFTPConfig: vfs.SFTPFsConfig{
					BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
						Endpoint: sftpServerAddr,
						Username: user3.Username,
					},
					Password: kms.NewPlainSecret(defaultPassword),
				},
			},
		},
		VirtualPath: "/vdir2",
	})
	user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "")
	assert.NoError(t, err)

	// login will work but /vdir will not be accessible
	conn, client, err = getSftpClient(user1, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		_, err = client.ReadDir("/vdir")
		assert.Error(t, err)
	}

	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user2.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user3, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user3.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: sftpFolderName}, http.StatusOK)
	assert.NoError(t, err)
}

// TestNestedVirtualFolders checks quota accounting for a user created by
// getTestSFTPUser that has three virtual folders: a crypt folder, a local
// folder, and a local folder nested inside the crypt one.
func TestNestedVirtualFolders(t *testing.T) {
	usePubKey := true
	baseUser, resp, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestSFTPUser(usePubKey)
	u.QuotaFiles = 1000
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders =
append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
		QuotaFiles:  100,
	})
	mappedPath := filepath.Join(os.TempDir(), "local")
	folderName := filepath.Base(mappedPath)
	vdirPath := "/vdir/local"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	mappedPathNested := filepath.Join(os.TempDir(), "nested")
	folderNameNested := filepath.Base(mappedPathNested)
	// this folder is mounted inside the crypt virtual folder
	vdirNestedPath := "/vdir/crypt/nested"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameNested,
		},
		VirtualPath: vdirNestedPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	f1 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	f3 := vfs.BaseVirtualFolder{
		Name:       folderNameNested,
		MappedPath: mappedPathNested,
	}
	_, _, err = httpdtest.AddFolder(f3, http.StatusCreated)
	assert.NoError(t, err)
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// expected user quota: every file written outside the crypt folder
		expectedQuotaSize := int64(0)
		expectedQuotaFiles := 0
		fileSize := int64(32765)
		err = writeSFTPFile(testFileName, fileSize, client)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		fileSize = 38764
		err = writeSFTPFile(path.Join("/vdir", testFileName), fileSize, client)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		fileSize = 18769
		err = writeSFTPFile(path.Join(vdirPath, testFileName), fileSize, client)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		fileSize = 27658
		err = writeSFTPFile(path.Join(vdirNestedPath, testFileName), fileSize, client)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		// the file written inside the crypt folder is not added to the
		// expected user quota
		fileSize = 39765
		err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), fileSize, client)
		assert.NoError(t, err)
		userGet, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize)
		// the crypt folder tracks its own usage; the stored size is larger than
		// the written size (NOTE(review): presumably encryption overhead)
		folderGet, _, err := httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, folderGet.UsedQuotaSize, fileSize)
		assert.Equal(t, 1, folderGet.UsedQuotaFiles)
		// the local and nested folders are accounted in the user quota only
		folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
		assert.Equal(t, 0, folderGet.UsedQuotaFiles)
		folderGet, _, err = httpdtest.GetFolderByName(folderNameNested, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), folderGet.UsedQuotaSize)
		assert.Equal(t, 0, folderGet.UsedQuotaFiles)
		files, err := client.ReadDir("/")
		if assert.NoError(t, err) {
			assert.Len(t, files, 2)
		}
		info, err := client.Stat("vdir")
		if assert.NoError(t, err) {
			assert.True(t, info.IsDir())
		}
		files, err = client.ReadDir("/vdir")
		if assert.NoError(t, err) {
			assert.Len(t, files, 3)
		}
		files, err = client.ReadDir(vdirCryptPath)
		if assert.NoError(t, err) {
			assert.Len(t, files, 2)
		}
		info, err = client.Stat(vdirNestedPath)
		if assert.NoError(t, err) {
			assert.True(t, info.IsDir())
		}
		// finally add some files directly using os method and then check quota
		fName := "testfile"
		fileSize = 123456
		// NOTE(review): the root of this user is presumably backed by
		// baseUser's home (via getTestSFTPUser), which is why this file is
		// counted in the user quota - confirm.
		err = createTestFile(filepath.Join(baseUser.HomeDir, fName), fileSize)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		fileSize = 8765
		err = createTestFile(filepath.Join(mappedPath, fName), fileSize)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		fileSize = 98751
		err = createTestFile(filepath.Join(mappedPathNested, fName), fileSize)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		// same size as the nested file; checked below by the folder scan
		err = createTestFile(filepath.Join(mappedPathCrypt, fName), fileSize)
		assert.NoError(t, err)
		_, err = httpdtest.StartQuotaScan(user, http.StatusAccepted)
		assert.NoError(t, err)
		// wait for the scan to complete: the scans list becomes empty
		assert.Eventually(t, func() bool {
			scans, _, err := httpdtest.GetQuotaScans(http.StatusOK)
			if err == nil {
				return len(scans) == 0
			}
			return false
		}, 1*time.Second, 50*time.Millisecond)
		userGet, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize)
		// the crypt folder is not included within user quota so we need to do a separate scan
		_, err = httpdtest.StartFolderQuotaScan(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusAccepted)
		assert.NoError(t, err)
		assert.Eventually(t, func() bool {
			scans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK)
			if err == nil {
				return len(scans) == 0
			}
			return false
		}, 1*time.Second, 50*time.Millisecond)
		folderGet, _, err = httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, folderGet.UsedQuotaSize, int64(39765+98751))
		assert.Equal(t, 2, folderGet.UsedQuotaFiles)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathNested)
	assert.NoError(t, err)
}

// TestBufferedUser checks transfers, truncation and quota accounting for a
// user whose local filesystem and crypt virtual folder are both configured
// with explicit read/write buffer sizes.
func TestBufferedUser(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 1000
	u.FsConfig.OSConfig = sdk.OSFsConfig{
		WriteBufferSize: 2,
		ReadBufferSize:  1,
	}
	vdirPath := "/crypted"
	mappedPath := filepath.Join(os.TempDir(), util.GenerateUniqueID())
	folderName := filepath.Base(mappedPath)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				OSFsConfig: sdk.OSFsConfig{
					WriteBufferSize: 3,
					ReadBufferSize:  2,
				},
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		expectedQuotaSize := int64(0)
		expectedQuotaFiles := 0
		fileSize := int64(32768)
		err = writeSFTPFile(testFileName, fileSize, client)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		err = writeSFTPFile(path.Join(vdirPath, testFileName), fileSize, client)
		assert.NoError(t, err)
		expectedQuotaSize += fileSize
		expectedQuotaFiles++
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		// Greater rather than Equal: NOTE(review) presumably the encrypted
		// copy in the crypt folder takes more space than the plain data
		assert.Greater(t, user.UsedQuotaSize, expectedQuotaSize)
		localDownloadPath :=
filepath.Join(homeBasePath, testDLFileName)
		// both files can be read back through the configured read buffers
		err = sftpDownloadFile(testFileName, localDownloadPath, fileSize, client)
		assert.NoError(t, err)
		err = sftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, fileSize, client)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = client.Remove(path.Join(vdirPath, testFileName))
		assert.NoError(t, err)
		data := []byte("test data")
		// write, truncate and seek on a handle that stays open. The quota size
		// follows the truncations right away, while the file count is only
		// updated once the handle is closed.
		f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE)
		if assert.NoError(t, err) {
			n, err := f.Write(data)
			assert.NoError(t, err)
			assert.Equal(t, len(data), n)
			err = f.Truncate(2)
			assert.NoError(t, err)
			expectedQuotaSize := int64(2)
			expectedQuotaFiles := 0
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
			_, err = f.Seek(expectedQuotaSize, io.SeekStart)
			assert.NoError(t, err)
			n, err = f.Write(data)
			assert.NoError(t, err)
			assert.Equal(t, len(data), n)
			err = f.Truncate(5)
			assert.NoError(t, err)
			expectedQuotaSize = int64(5)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
			_, err = f.Seek(expectedQuotaSize, io.SeekStart)
			assert.NoError(t, err)
			n, err = f.Write(data)
			assert.NoError(t, err)
			assert.Equal(t, len(data), n)
			err = f.Close()
			assert.NoError(t, err)
			// after close: 5 truncated bytes plus the last write, one file
			expectedQuotaSize = int64(5) + int64(len(data))
			expectedQuotaFiles = 1
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
			assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		}
		// now truncate by path
		err = client.Truncate(testFileName, 5)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, int64(5), user.UsedQuotaSize)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

func TestTruncateQuotaLimits(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaSize = 20
	mappedPath := filepath.Join(os.TempDir(), "mapped")
	folderName := filepath.Base(mappedPath)
	err := os.MkdirAll(mappedPath, os.ModePerm)
	assert.NoError(t, err)
	vdirPath := "/vmapped"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  10,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err = httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.QuotaSize = 20
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			data := []byte("test data")
			f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE)
			if assert.NoError(t, err) {
				n, err := f.Write(data)
				assert.NoError(t, err)
				assert.Equal(t, len(data), n)
				err = f.Truncate(2)
				assert.NoError(t, err)
				expectedQuotaFiles := 0
				expectedQuotaSize := int64(2)
				user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
				assert.NoError(t, err)
				assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
				assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
				_, err = f.Seek(expectedQuotaSize, io.SeekStart)
assert.NoError(t, err) n, err = f.Write(data) assert.NoError(t, err) assert.Equal(t, len(data), n) err = f.Truncate(5) assert.NoError(t, err) expectedQuotaSize = int64(5) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) _, err = f.Seek(expectedQuotaSize, io.SeekStart) assert.NoError(t, err) n, err = f.Write(data) assert.NoError(t, err) assert.Equal(t, len(data), n) err = f.Close() assert.NoError(t, err) expectedQuotaFiles = 1 expectedQuotaSize = int64(5) + int64(len(data)) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) } // now truncate by path err = client.Truncate(testFileName, 5) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(5), user.UsedQuotaSize) // now open an existing file without truncate it, quota should not change f, err = client.OpenFile(testFileName, os.O_WRONLY) if assert.NoError(t, err) { err = f.Close() assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(5), user.UsedQuotaSize) } // open the file truncating it f, err = client.OpenFile(testFileName, os.O_WRONLY|os.O_TRUNC) if assert.NoError(t, err) { err = f.Close() assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(0), user.UsedQuotaSize) } // now test max write size f, err = client.OpenFile(testFileName, os.O_WRONLY) if assert.NoError(t, err) { n, err := f.Write(data) assert.NoError(t, err) 
assert.Equal(t, len(data), n) err = f.Truncate(11) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(11), user.UsedQuotaSize) _, err = f.Seek(int64(11), io.SeekStart) assert.NoError(t, err) n, err = f.Write(data) assert.NoError(t, err) assert.Equal(t, len(data), n) err = f.Truncate(5) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(5), user.UsedQuotaSize) _, err = f.Seek(int64(5), io.SeekStart) assert.NoError(t, err) n, err = f.Write(data) assert.NoError(t, err) assert.Equal(t, len(data), n) err = f.Truncate(12) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(12), user.UsedQuotaSize) _, err = f.Seek(int64(12), io.SeekStart) assert.NoError(t, err) _, err = f.Write(data) if assert.Error(t, err) { assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) } err = f.Close() assert.Error(t, err) // the file is deleted user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 0, user.UsedQuotaFiles) assert.Equal(t, int64(0), user.UsedQuotaSize) } if user.Username == defaultUsername { // basic test inside a virtual folder vfileName := path.Join(vdirPath, testFileName) f, err = client.OpenFile(vfileName, os.O_WRONLY|os.O_CREATE) if assert.NoError(t, err) { n, err := f.Write(data) assert.NoError(t, err) assert.Equal(t, len(data), n) err = f.Truncate(2) assert.NoError(t, err) expectedQuotaFiles := 0 expectedQuotaSize := int64(2) fold, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) assert.Equal(t, expectedQuotaFiles, 
fold.UsedQuotaFiles) err = f.Close() assert.NoError(t, err) expectedQuotaFiles = 1 fold, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) } err = client.Truncate(vfileName, 1) assert.NoError(t, err) fold, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(1), fold.UsedQuotaSize) assert.Equal(t, 1, fold.UsedQuotaFiles) // cleanup err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 user.QuotaSize = 0 _, resp, err := httpdtest.AddUser(user, http.StatusCreated) assert.NoError(t, err, string(resp)) } } } _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) } func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { usePubKey := true testFileSize := int64(131072) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize1 := int64(65537) testFileName1 := "test_file1.dat" testFilePath1 := filepath.Join(homeBasePath, testFileName1) err := createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) u := getTestUser(usePubKey) u.QuotaFiles = 0 u.QuotaSize = 0 mappedPath1 := filepath.Join(os.TempDir(), "vdir1") folderName1 := filepath.Base(mappedPath1) vdirPath1 := "/vdir1" mappedPath2 := filepath.Join(os.TempDir(), "vdir2") folderName2 := filepath.Base(mappedPath2) vdirPath2 := "/vdir2" mappedPath3 := 
filepath.Join(os.TempDir(), "vdir3") folderName3 := filepath.Base(mappedPath3) vdirPath3 := "/vdir3" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: vdirPath1, QuotaFiles: 2, QuotaSize: 0, }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, }, VirtualPath: vdirPath2, QuotaFiles: 0, QuotaSize: testFileSize + testFileSize1 + 1, }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName3, }, VirtualPath: vdirPath3, QuotaFiles: 2, QuotaSize: testFileSize * 2, }) err = os.MkdirAll(mappedPath1, os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(mappedPath2, os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(mappedPath3, os.ModePerm) assert.NoError(t, err) f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, } _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) f3 := vfs.BaseVirtualFolder{ Name: folderName3, MappedPath: mappedPath3, } _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testFileName1), testFileSize1, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testFileName1), testFileSize1, 
client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath3, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath3, testFileName+"1"), testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName, path.Join(vdirPath1, testFileName+".rename")) assert.Error(t, err) // we overwrite an existing file and we have unlimited size err = client.Rename(testFileName, path.Join(vdirPath1, testFileName)) assert.NoError(t, err) // we have no space and we try to overwrite a bigger file with a smaller one, this should succeed err = client.Rename(testFileName1, path.Join(vdirPath2, testFileName)) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) // we have no space and we try to overwrite a smaller file with a bigger one, this should fail err = client.Rename(testFileName, path.Join(vdirPath2, testFileName1)) assert.Error(t, err) fi, err := client.Stat(path.Join(vdirPath1, testFileName1)) if assert.NoError(t, err) { assert.Equal(t, testFileSize1, fi.Size()) } // we are overquota inside vdir3 size 2/2 and size 262144/262144 err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName1+".rename")) assert.Error(t, err) // we overwrite an existing file and we have enough size err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName)) assert.NoError(t, err) testFileName2 := "test_file2.dat" testFilePath2 := filepath.Join(homeBasePath, testFileName2) err = createTestFile(testFilePath2, testFileSize+testFileSize1) assert.NoError(t, err) err = 
sftpUploadFile(testFilePath2, testFileName2, testFileSize+testFileSize1, client) assert.NoError(t, err) // we overwrite an existing file and we haven't enough size err = client.Rename(testFileName2, path.Join(vdirPath3, testFileName)) assert.Error(t, err) err = os.Remove(testFilePath2) assert.NoError(t, err) // now remove a file from vdir3, create a dir with 2 files and try to rename it in vdir3 // this will fail since the rename will result in 3 files inside vdir3 and quota limits only // allow 2 total files there err = client.Remove(path.Join(vdirPath3, testFileName+"1")) assert.NoError(t, err) aDir := "a dir" err = client.Mkdir(aDir) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(aDir, testFileName1), testFileSize1, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(aDir, testFileName1+"1"), testFileSize1, client) assert.NoError(t, err) err = client.Rename(aDir, path.Join(vdirPath3, aDir)) assert.Error(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath1) assert.NoError(t, err) err = os.RemoveAll(mappedPath2) assert.NoError(t, err) err = os.RemoveAll(mappedPath3) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(testFilePath1) assert.NoError(t, err) } func TestVirtualFoldersQuotaValues(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 100 mappedPath1 := filepath.Join(os.TempDir(), "vdir1") vdirPath1 := "/vdir1" //nolint:goconst folderName1 := filepath.Base(mappedPath1) mappedPath2 := 
filepath.Join(os.TempDir(), "vdir2") vdirPath2 := "/vdir2" //nolint:goconst folderName2 := filepath.Base(mappedPath2) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: vdirPath1, // quota is included in the user's one QuotaFiles: -1, QuotaSize: -1, }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, }, VirtualPath: vdirPath2, // quota is unlimited and excluded from user's one QuotaFiles: 0, QuotaSize: 0, }) err := os.MkdirAll(mappedPath1, os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(mappedPath2, os.ModePerm) assert.NoError(t, err) f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, } _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileSize := int64(131072) testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) // we copy the same file two times to test quota update on file overwrite err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) 
assert.NoError(t, err) expectedQuotaFiles := 2 expectedQuotaSize := testFileSize * 2 user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize, f.UsedQuotaSize) assert.Equal(t, 1, f.UsedQuotaFiles) err = client.Remove(path.Join(vdirPath1, testFileName)) assert.NoError(t, err) err = client.Remove(path.Join(vdirPath2, testFileName)) assert.NoError(t, err) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath1) assert.NoError(t, err) err = os.RemoveAll(mappedPath2) assert.NoError(t, err) } func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.QuotaFiles = 100 mappedPath1 := filepath.Join(os.TempDir(), "vdir1") vdirPath1 := "/vdir1" folderName1 := filepath.Base(mappedPath1) mappedPath2 := filepath.Join(os.TempDir(), "vdir2") vdirPath2 := "/vdir2" folderName2 := 
filepath.Base(mappedPath2) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: vdirPath1, // quota is included in the user's one QuotaFiles: -1, QuotaSize: -1, }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, }, VirtualPath: vdirPath2, // quota is unlimited and excluded from user's one QuotaFiles: 0, QuotaSize: 0, }) f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, } _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) err = os.MkdirAll(mappedPath1, os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(mappedPath2, os.ModePerm) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileName1 := "test_file1.dat" testFileSize := int64(131072) testFileSize1 := int64(65535) testFilePath := filepath.Join(homeBasePath, testFileName) testFilePath1 := filepath.Join(homeBasePath, testFileName1) dir1 := "dir1" //nolint:goconst dir2 := "dir2" //nolint:goconst err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath1, dir1)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath1, dir2)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath2, dir1)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath2, dir2)) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, 
path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) assert.Equal(t, 2, f.UsedQuotaFiles) // initial files: // - vdir1/dir1/testFileName // - vdir1/dir2/testFileName1 // - vdir2/dir1/testFileName // - vdir2/dir2/testFileName1 // // rename a file inside vdir1 it is included inside user quota, so we have: // - vdir1/dir1/testFileName.rename // - vdir1/dir2/testFileName1 // - vdir2/dir1/testFileName // - vdir2/dir2/testFileName1 err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath1, dir1, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) // rename a file inside vdir2, it isn't included inside user quota, so we have: // - vdir1/dir1/testFileName.rename // - vdir1/dir2/testFileName1 // - vdir2/dir1/testFileName.rename // - vdir2/dir2/testFileName1 err = client.Rename(path.Join(vdirPath2, dir1, 
testFileName), path.Join(vdirPath2, dir1, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) assert.Equal(t, 2, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) // rename a file inside vdir2 overwriting an existing, we now have: // - vdir1/dir1/testFileName.rename // - vdir1/dir2/testFileName1 // - vdir2/dir1/testFileName.rename (initial testFileName1) err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1, f.UsedQuotaSize) assert.Equal(t, 1, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) // rename a file inside vdir1 overwriting an existing, we now have: // - vdir1/dir1/testFileName.rename (initial testFileName1) // - vdir2/dir1/testFileName.rename (initial testFileName1) err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath1, dir1, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, 
testFileSize1, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1, f.UsedQuotaSize) assert.Equal(t, 1, f.UsedQuotaFiles) // rename a directory inside the same virtual folder, quota should not change err = client.RemoveDirectory(path.Join(vdirPath1, dir2)) assert.NoError(t, err) err = client.RemoveDirectory(path.Join(vdirPath2, dir2)) assert.NoError(t, err) err = client.Rename(path.Join(vdirPath1, dir1), path.Join(vdirPath1, dir2)) assert.NoError(t, err) err = client.Rename(path.Join(vdirPath2, dir1), path.Join(vdirPath2, dir2)) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize1, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1, f.UsedQuotaSize) assert.Equal(t, 1, f.UsedQuotaFiles) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(testFilePath1) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(mappedPath1) assert.NoError(t, err) err = os.RemoveAll(mappedPath2) assert.NoError(t, err) } func TestQuotaRenameBetweenVirtualFolder(t *testing.T) { usePubKey := true 
u := getTestUser(usePubKey) u.QuotaFiles = 100 mappedPath1 := filepath.Join(os.TempDir(), "vdir1") folderName1 := filepath.Base(mappedPath1) vdirPath1 := "/vdir1" mappedPath2 := filepath.Join(os.TempDir(), "vdir2") folderName2 := filepath.Base(mappedPath2) vdirPath2 := "/vdir2" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName1, }, VirtualPath: vdirPath1, // quota is included in the user's one QuotaFiles: -1, QuotaSize: -1, }) u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName2, }, VirtualPath: vdirPath2, // quota is unlimited and excluded from user's one QuotaFiles: 0, QuotaSize: 0, }) err := os.MkdirAll(mappedPath1, os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(mappedPath2, os.ModePerm) assert.NoError(t, err) f1 := vfs.BaseVirtualFolder{ Name: folderName1, MappedPath: mappedPath1, } _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName2, MappedPath: mappedPath2, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileName1 := "test_file1.dat" testFileSize := int64(131072) testFileSize1 := int64(65535) testFilePath := filepath.Join(homeBasePath, testFileName) testFilePath1 := filepath.Join(homeBasePath, testFileName1) dir1 := "dir1" dir2 := "dir2" err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath1, dir1)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath1, dir2)) assert.NoError(t, err) err = client.Mkdir(path.Join(vdirPath2, dir1)) assert.NoError(t, err) err = 
client.Mkdir(path.Join(vdirPath2, dir2)) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) assert.NoError(t, err) // initial files: // - vdir1/dir1/testFileName // - vdir1/dir2/testFileName1 // - vdir2/dir1/testFileName // - vdir2/dir2/testFileName1 // // rename a file from vdir1 to vdir2, vdir1 is included inside user quota, so we have: // - vdir1/dir1/testFileName // - vdir2/dir1/testFileName // - vdir2/dir2/testFileName1 // - vdir2/dir1/testFileName1.rename err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName1+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize, user.UsedQuotaSize) f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize) assert.Equal(t, 3, f.UsedQuotaFiles) // rename a file from vdir2 to vdir1, vdir2 is not included inside user quota, so we have: // - vdir1/dir1/testFileName // - vdir1/dir2/testFileName.rename // - vdir2/dir2/testFileName1 // - vdir2/dir1/testFileName1.rename err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath1, dir2, testFileName+".rename")) assert.NoError(t, err) user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize*2, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1*2, f.UsedQuotaSize) assert.Equal(t, 2, f.UsedQuotaFiles) // rename a file from vdir1 to vdir2 overwriting an existing file, vdir1 is included inside user quota, so we have: // - vdir1/dir2/testFileName.rename // - vdir2/dir2/testFileName1 (is the initial testFileName) // - vdir2/dir1/testFileName1.rename err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath2, dir2, testFileName1)) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize) assert.Equal(t, 2, f.UsedQuotaFiles) // rename a file from vdir2 to vdir1 overwriting an existing file, vdir2 is not included inside user quota, so we have: // - vdir1/dir2/testFileName.rename (is the initial testFileName1) // - vdir2/dir2/testFileName1 (is the initial testFileName) err = client.Rename(path.Join(vdirPath2, dir1, testFileName1+".rename"), path.Join(vdirPath1, dir2, testFileName+".rename")) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize1, 
user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize, f.UsedQuotaSize) assert.Equal(t, 1, f.UsedQuotaFiles) err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir2, testFileName), testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName), testFileSize1, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName+"1.dupl"), testFileSize1, client) assert.NoError(t, err) err = client.RemoveDirectory(path.Join(vdirPath1, dir1)) assert.NoError(t, err) err = client.RemoveDirectory(path.Join(vdirPath2, dir1)) assert.NoError(t, err) // - vdir1/dir2/testFileName.rename (initial testFileName1) // - vdir1/dir2/testFileName // - vdir2/dir2/testFileName1 (initial testFileName) // - vdir2/dir2/testFileName (initial testFileName1) // - vdir2/dir2/testFileName1.dupl // rename directories between the two virtual folders err = client.Rename(path.Join(vdirPath2, dir2), path.Join(vdirPath1, dir1)) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, 5, user.UsedQuotaFiles) assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize) f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) assert.NoError(t, err) assert.Equal(t, int64(0), f.UsedQuotaSize) assert.Equal(t, 0, f.UsedQuotaFiles) // now move on vpath2 err = client.Rename(path.Join(vdirPath1, dir2), path.Join(vdirPath2, dir1)) assert.NoError(t, err) user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestQuotaRenameFromVirtualFolder renames files and directories out of two
// virtual folders into the user's home directory and, after every step, checks
// the user's used quota and the used quota stored on each folder.
// vdir1 (QuotaFiles/QuotaSize -1) is accounted in the user's quota, so its
// folder usage is expected to stay at zero; vdir2 (QuotaFiles/QuotaSize 0) is
// excluded from the user's quota and its usage is checked on the folder itself.
func TestQuotaRenameFromVirtualFolder(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		dir1 := "dir1"
		dir2 := "dir2"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - vdir1/dir1/testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		//
		// rename a file from vdir1 to the user home dir, vdir1 is included in user quota so we have:
		// - testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		// rename a file from vdir2 to the user home dir, vdir2 is not included in user quota so we have:
		// - testFileName
		// - testFileName1
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from vdir1 to the user home dir overwriting an existing file, vdir1 is included in user quota so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1
		// - vdir2/dir1/testFileName
		err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from vdir2 to the user home dir overwriting an existing file, vdir2 is not included in user quota so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// dir rename: upload two files into vdir1/dir1 and two files into vdir2/dir1
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// move vdir2/dir1 to /dir1: its two files leave vdir2 (excluded from the
		// user quota) and are charged to the user, so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		// - vdir1/dir1/testFileName
		// - vdir1/dir1/testFileName1
		// - dir1/testFileName
		// - dir1/testFileName1
		err = client.Rename(path.Join(vdirPath2, dir1), dir1)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// move vdir1/dir1 to /dir2: vdir1 is already included in the user quota,
		// so the user totals are unchanged and we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		// - dir2/testFileName
		// - dir2/testFileName1
		// - dir1/testFileName
		// - dir1/testFileName1
		err = client.Rename(path.Join(vdirPath1, dir1), dir2)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestQuotaRenameToVirtualFolder renames files and directories from the
// user's home directory into two virtual folders and, after every step, checks
// the user's used quota and the used quota stored on each folder.
// vdir1 (QuotaFiles/QuotaSize -1) is accounted in the user's quota, so its
// folder usage is expected to stay at zero; vdir2 (QuotaFiles/QuotaSize 0) is
// excluded from the user's quota and its usage is checked on the folder itself.
func TestQuotaRenameToVirtualFolder(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		dir1 := "dir1"
		dir2 := "dir2"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - testFileName
		// - testFileName1
		//
		// rename a file from user home dir to vdir1, vdir1 is included in user quota so we have:
		// - testFileName
		// - /vdir1/dir1/testFileName1
		err = client.Rename(testFileName1, path.Join(vdirPath1, dir1, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
		// - /vdir2/dir1/testFileName
		// - /vdir1/dir1/testFileName1
		err = client.Rename(testFileName, path.Join(vdirPath2, dir1, testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// upload two new files to the user home dir so we have:
		// - testFileName
		// - testFileName1
		// - /vdir1/dir1/testFileName1
		// - /vdir2/dir1/testFileName
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
		// rename a file from user home dir to vdir1 overwriting an existing file, vdir1 is included in user quota so we have:
		// - testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName
		err = client.Rename(testFileName, path.Join(vdirPath1, dir1, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from user home dir to vdir2 overwriting an existing file, vdir2 is not included in user quota so we have:
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		err = client.Rename(testFileName1, path.Join(vdirPath2, dir1, testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// dir rename: create /dir1 in the user home dir and upload two files into it
		err = client.Mkdir(dir1)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// - /dir1/testFileName
		// - /dir1/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// move /dir1 to /vdir1/adir: vdir1 is included in the user quota, so the
		// user totals are unchanged and we have:
		// - /vdir1/adir/testFileName
		// - /vdir1/adir/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		err = client.Rename(dir1, path.Join(vdirPath1, "adir"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// recreate /dir1 with two new files, then move it into vdir2
		err = client.Mkdir(dir1)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// move /dir1 to /vdir2/adir: its two files leave the user quota and are
		// charged to vdir2, so we have:
		// - /vdir1/adir/testFileName
		// - /vdir1/adir/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		// - /vdir2/adir/testFileName
		// - /vdir2/adir/testFileName1
		err = client.Rename(dir1, path.Join(vdirPath2, "adir"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 3, f.UsedQuotaFiles)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestVirtualFoldersLink checks that a symlink can be created when its target
// and the link are both in the user's home directory, or both inside the same
// virtual folder (including a subdirectory of it). Every symlink that crosses
// the boundary between the home directory and a virtual folder, or between
// the two virtual folders, must fail.
func TestVirtualFoldersLink(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testDir := "adir"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// place the same file in the home dir, in vdir1 and in vdir2
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		// target and link in the same root (home, vdir1 or vdir2): allowed
		err = client.Symlink(testFileName, testFileName+".link")
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".link"))
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testDir, testFileName+".link"))
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".link"))
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testDir, testFileName+".link"))
		assert.NoError(t, err)
		// home dir target, link inside a virtual folder: must fail
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath1, testFileName+".link1")) //nolint:goconst
		assert.Error(t, err)
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath1, testDir, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath2, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath2, testDir, testFileName+".link1"))
		assert.Error(t, err)
		// virtual folder target, link in the home dir: must fail
		err = client.Symlink(path.Join(vdirPath1, testFileName), testFileName+".link1")
		assert.Error(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), testFileName+".link1")
		assert.Error(t, err)
		// target and link in different virtual folders: must fail
		err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath2, testDir, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath1, testFileName+".link1"))
		assert.Error(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestVirtualFolderQuotaScan creates a folder whose mapped path already holds
// one file and starts a folder quota scan. It then waits until no folder quota
// scan is reported as active and checks that the folder's used quota matches
// that single file.
func TestVirtualFolderQuotaScan(t *testing.T) {
	mappedPath := filepath.Join(os.TempDir(), "mapped_dir")
	folderName := filepath.Base(mappedPath)
	err := os.MkdirAll(mappedPath, os.ModePerm)
	assert.NoError(t, err)
	testFileSize := int64(65535)
	testFilePath := filepath.Join(mappedPath, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	expectedQuotaSize := testFileSize
	expectedQuotaFiles := 1
	folder, _, err := httpdtest.AddFolder(vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}, http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.StartFolderQuotaScan(folder, http.StatusAccepted)
	assert.NoError(t, err)
	// poll every 50ms, for at most 1s, until the list of active folder scans is
	// empty; a failing request is treated as "not done yet"
	assert.Eventually(t, func() bool {
		scans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK)
		if err == nil {
			return len(scans) == 0
		}
		return false
	}, 1*time.Second, 50*time.Millisecond)
	folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles, folder.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize, folder.UsedQuotaSize)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestVFolderMultipleQuotaScan checks the bookkeeping in common.QuotaScans:
// a second scan for a folder name that already has one registered is
// rejected, and removing a scan succeeds only while it is registered.
func TestVFolderMultipleQuotaScan(t *testing.T) {
	folderName := "folder_name"
	res := common.QuotaScans.AddVFolderQuotaScan(folderName)
	assert.True(t, res)
	res = common.QuotaScans.AddVFolderQuotaScan(folderName)
	assert.False(t, res)
	res = common.QuotaScans.RemoveVFolderQuotaScan(folderName)
	assert.True(t, res)
	activeScans := common.QuotaScans.GetVFoldersQuotaScans()
	assert.Len(t, activeScans, 0)
	res = common.QuotaScans.RemoveVFolderQuotaScan(folderName)
	assert.False(t, res)
}

// TestVFolderQuotaSize checks quota enforcement for uploads into virtual
// folders. The user allows 1 file and testFileSize+1 bytes; vdir1 shares the
// user's limits, while vdir2 has its own limits (1 file, 2*testFileSize bytes).
// A second user then maps the same folder as vdir2 with higher limits
// (10 files, 2*testFileSize+1 bytes) and is expected to hit the size limit
// rather than the files limit.
func TestVFolderQuotaSize(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	testFileSize := int64(131072)
	u.QuotaFiles = 1
	u.QuotaSize = testFileSize + 1
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vpath1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vpath2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  1,
		QuotaSize:   testFileSize * 2,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	// NOTE(review): testFilePath is created in homeBasePath but this test never
	// removes it; confirm whether that is intentional.
	testFilePath := filepath.Join(homeBasePath, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		// vdir1 is included in the user quota so upload must fail
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.Error(t, err)
		// upload to vdir2 must work, it has its own quota
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		// now vdir2 is over quota
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+".quota"), testFileSize, client)
		assert.Error(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		// remove a file
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, user.UsedQuotaFiles)
		assert.Equal(t, int64(0), user.UsedQuotaSize)
		// upload to vdir1 must work now
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
	}
	// now create another user with the same shared folder but a different quota limit
	u.Username = defaultUsername + "1"
	u.VirtualFolders = nil
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name:       folderName2,
			MappedPath: mappedPath2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  10,
		QuotaSize:   testFileSize*2 + 1,
	})
	user1, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user1, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+".quota"), testFileSize, client)
		assert.NoError(t, err)
		// the folder is now over quota for size but not for files
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+".quota1"), testFileSize, client)
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestMissingFile checks that downloading a file that does not exist on the
// server fails.
func TestMissingFile(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpDownloadFile("missing_file", localDownloadPath, 0, client)
		assert.Error(t, err, "download missing file must fail")
		// the local file is expected to exist even though the download failed;
		// presumably sftpDownloadFile creates it before the remote open - confirm
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestOpenError removes local filesystem permissions from the user's home
// directory, from a file and from a subdirectory, and checks that the matching
// SFTP operations (read dir, download, upload, lstat, readlink, remove,
// rename) fail. Where asserted, the error must match the permission error.
// Skipped on Windows.
func TestOpenError(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	usePubKey := false
	u := getTestUser(usePubKey)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// the home dir is deleted before login; the Chmod calls below rely on it
	// existing again afterwards (presumably recreated at login - confirm)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// mode 0001 leaves the owner without read/write/execute permissions
		err = os.Chmod(user.GetHomeDir(), 0001)
		assert.NoError(t, err)
		_, err = client.ReadDir(".")
		assert.Error(t, err, "read dir must fail if we have no filesystem read permissions")
		err = os.Chmod(user.GetHomeDir(), os.ModePerm)
		assert.NoError(t, err)
		testFileSize := int64(65535)
		testFilePath := filepath.Join(user.GetHomeDir(), testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		_, err = client.Stat(testFileName)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		err = os.Chmod(testFilePath, 0001)
		assert.NoError(t, err)
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.Error(t, err, "file download must fail if we have no filesystem read permissions")
		err = sftpUploadFile(localDownloadPath, testFileName, testFileSize, client)
		assert.Error(t, err, "upload must fail if we have no filesystem write permissions")
		testDir := "test"
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = createTestFile(filepath.Join(user.GetHomeDir(), testDir, testFileName), testFileSize)
		assert.NoError(t, err)
		// with no permissions at all on the home dir, operations inside it fail
		err = os.Chmod(user.GetHomeDir(), 0000)
		assert.NoError(t, err)
		_, err = client.Lstat(testFileName)
		assert.Error(t, err, "file stat must fail if we have no filesystem read permissions")
		err = sftpUploadFile(localDownloadPath, path.Join(testDir, testFileName), testFileSize, client)
		assert.ErrorIs(t, err, os.ErrPermission)
		_, err = client.ReadLink(testFileName)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Remove(testFileName)
		assert.ErrorIs(t, err, os.ErrPermission)
		err = os.Chmod(user.GetHomeDir(), os.ModePerm)
		assert.NoError(t, err)
		// renaming into a subdirectory with no permissions must fail
		err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), 0000)
		assert.NoError(t, err)
		err = client.Rename(testFileName, path.Join(testDir, testFileName))
		// NOTE(review): equivalent to assert.ErrorIs(t, err, fs.ErrPermission) as
		// used above, but reports a less informative failure message
		assert.True(t, errors.Is(err, fs.ErrPermission))
		err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), os.ModePerm)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestOverwriteDirWithFile checks that an existing directory can be replaced
// by a file neither through an upload nor through a rename.
func TestOverwriteDirWithFile(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(65535)
		testDirName := "test_dir" //nolint:goconst
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = client.Mkdir(testDirName)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testDirName, testFileSize, client)
		assert.Error(t, err, "copying a file over an existing dir must fail")
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = client.Rename(testFileName, testDirName)
		assert.Error(t, err, "rename a file over an existing dir must fail")
		err = client.RemoveDirectory(testDirName)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHashedPasswords creates one user for each pre-hashed password in
// pwdMapping (hash -> clear-text password). For each user it checks that:
//   - updating the user with an empty password keeps the stored hash unchanged;
//   - login works with the clear-text password and fails with the hash itself;
//   - after the successful login the stored password has the "$2a$" (bcrypt)
//     prefix;
//   - another update with an empty password leaves the user unchanged apart
//     from LastLogin/UpdatedAt, and login with the clear-text password still works.
func TestHashedPasswords(t *testing.T) {
	usePubKey := false
	plainPwd := "password"
	pwdMapping := make(map[string]string)
	pwdMapping["$argon2id$v=19$m=65536,t=3,p=2$xtcO/oRkC8O2Tn+mryl2mw$O7bn24f2kuSGRMi9s5Cm61Wqd810px1jDsAasrGWkzQ"] = plainPwd
	pwdMapping["$pbkdf2-sha1$150000$DveVjgYUD05R$X6ydQZdyMeOvpgND2nqGR/0GGic="] = plainPwd
	pwdMapping["$pbkdf2-sha256$150000$E86a9YMX3zC7$R5J62hsSq+pYw00hLLPKBbcGXmq7fj5+/M0IFoYtZbo="] = plainPwd
	pwdMapping["$pbkdf2-sha512$150000$dsu7T5R3IaVQ$1hFXPO1ntRBcoWkSLKw+s4sAP09Xtu4Ya7CyxFq64jM9zdUg8eRJVr3NcR2vQgb0W9HHvZaILHsL4Q/Vr6arCg=="] = plainPwd
	pwdMapping["$1$b5caebda$VODr/nyhGWgZaY8sJ4x05."] = plainPwd
	pwdMapping["$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK"] = "secret"
	pwdMapping["$6$459ead56b72e44bc$uog86fUxscjt28BZxqFBE2pp2QD8P/1e98MNF75Z9xJfQvOckZnQ/1YJqiq1XeytPuDieHZvDAMoP7352ELkO1"] = "secret"
	pwdMapping["$5$h4Aalt0fJdGX8sgv$Rd2ew0fvgzUN.DzAVlKa9QL4q/DZWo4SsKpB9.3AyZ/"] = plainPwd
	pwdMapping["$apr1$OBWLeSme$WoJbB736e7kKxMBIAqilb1"] = plainPwd
	pwdMapping["{MD5}5f4dcc3b5aa765d61d8327deb882cf99"] = plainPwd
	pwdMapping["{SHA256}5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8"] = plainPwd
	pwdMapping["{SHA512}b109f3bbbc244eb82441917ed06d618b9008dd09b3befd1b5e07394c706a8bb980b1d7785e5976ec049b46df5f1326af5a2ea6d103fd07c95385ffab0cacbc86"] = plainPwd

	for pwd, clearPwd := range pwdMapping {
		u := getTestUser(usePubKey)
		u.Password = pwd
		user, _, err := httpdtest.AddUser(u, http.StatusCreated)
		assert.NoError(t, err)
		// an update with an empty password must keep the stored hash as is
		user.Password = ""
		userGetInitial, _, err := httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		user, err = dataprovider.UserExists(user.Username, "")
		assert.NoError(t, err)
		assert.Equal(t, pwd, user.Password)
		user.Password = clearPwd
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err, "unable to login with password %q", pwd) {
			assert.NoError(t, checkBasicSFTP(client))
			conn.Close()
			client.Close()
		}
		// using the hash itself as the password must not authenticate
		user.Password = pwd
		conn, client, err = getSftpClient(user, usePubKey)
		if !assert.Error(t, err, "login with wrong password must fail") {
			client.Close()
			conn.Close()
		}
		// the password must be converted to bcrypt and we should still be able to login
		user, err = dataprovider.UserExists(user.Username, "")
		assert.NoError(t, err)
		assert.True(t, strings.HasPrefix(user.Password, "$2a$"))
		// update the user to invalidate the cached password and force a new check
		user.Password = ""
		userGet, _, err := httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		userGetInitial.LastLogin = userGet.LastLogin
		userGetInitial.UpdatedAt = userGet.UpdatedAt
		assert.Equal(t, userGetInitial, userGet)
		// login should still work
		user.Password = clearPwd
		conn, client, err = getSftpClient(user, usePubKey)
		if assert.NoError(t, err, "unable to login with password %q", pwd) {
			assert.NoError(t, checkBasicSFTP(client))
			conn.Close()
			client.Close()
		}
		_, err = httpdtest.RemoveUser(user, http.StatusOK)
		assert.NoError(t, err)
		err = os.RemoveAll(user.GetHomeDir())
		assert.NoError(t, err)
	}
}

// TestPasswordsHashPbkdf2Sha256_389DS converts a "{PBKDF2_SHA256}" hash (389 DS
// format) into the "$pbkdf2-b64salt-sha256$<iterations>$<salt>$<key>" format,
// then checks that login works with the clear-text password and fails when the
// converted hash itself is used as the password.
func TestPasswordsHashPbkdf2Sha256_389DS(t *testing.T) {
	pbkdf389dsPwd := "{PBKDF2_SHA256}AAAIAMZIKG4ie44zJY4HOXI+upFR74PzWLUQV63jg+zzkbEjCK3N4qW583WF7EdcpeoOMQ4HY3aWEXB6lnXhXJixbJkU4vVSJkL6YCbU3TrD0qn1uUUVSkaIgAOtmZENitwbhYhiWfEzGyAtFqkFd75P5xhWJEog9XhQKYrR0f7S3WGGZq03JRcLJ460xpU97bE/sWRn7sshgkWzLuyrs0I+XRKmK7FJeaA9zd+1m44Y3IVmZ2YLdKATzjRHAIgpBC6i1TWOcpKJT1+feP1C9hrxH8vU9baw9thNiO8jSHaZlwb//KpJFe0ahVnG/1ubiG8cO0+CCqDqXVJR6Vr4QZxHP+4pwooW+4TP/L+HFdyA1y6z4gKfqYnBsmb3sD1R1TbxfH4btTdvgZAnBk9CmR3QASkFXxeTYsrmNd5+9IAHc6dm"
	// drop the 15-character "{PBKDF2_SHA256}" scheme prefix
	pbkdf389dsPwd = pbkdf389dsPwd[15:]
	hashBytes, err := base64.StdEncoding.DecodeString(pbkdf389dsPwd)
	assert.NoError(t, err)
	// decoded layout: 4-byte big-endian iteration count, 64-byte salt, derived key
	iterBytes := hashBytes[0:4]
	var iterations int32
	err = binary.Read(bytes.NewBuffer(iterBytes), binary.BigEndian, &iterations)
	assert.NoError(t, err)
	salt := hashBytes[4:68]
	targetKey := hashBytes[68:]
	key := base64.StdEncoding.EncodeToString(targetKey)
	pbkdf2Pwd := fmt.Sprintf("$pbkdf2-b64salt-sha256$%v$%v$%v", iterations, base64.StdEncoding.EncodeToString(salt), key)
	pbkdf2ClearPwd := "password"
	usePubKey := false
	u := getTestUser(usePubKey)
	u.Password = pbkdf2Pwd
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Password = pbkdf2ClearPwd
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	user.Password = pbkdf2Pwd
	conn, client, err = getSftpClient(user, usePubKey)
	if !assert.Error(t, err, "login with wrong password must fail") {
		client.Close()
		conn.Close()
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestPermList(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown,
dataprovider.PermChtimes} u.Permissions["/sub"] = []string{dataprovider.PermCreateSymlinks, dataprovider.PermListItems} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() _, err = client.ReadDir(".") assert.ErrorIs(t, err, os.ErrPermission, "read remote dir without permission should not succeed") _, err = client.Stat("test_file") assert.ErrorIs(t, err, os.ErrPermission, "stat remote file without permission should not succeed") _, err = client.Lstat("test_file") assert.ErrorIs(t, err, os.ErrPermission, "lstat remote file without permission should not succeed") _, err = client.ReadLink("test_link") assert.ErrorIs(t, err, os.ErrPermission, "read remote link without permission on source dir should not succeed") _, err = client.RealPath(".") assert.ErrorIs(t, err, os.ErrPermission, "real path without permission should not succeed") f, err := client.Create(testFileName) if assert.NoError(t, err) { _, err = f.Write([]byte("content")) assert.NoError(t, err) err = f.Close() assert.NoError(t, err) } err = client.Mkdir("sub") assert.NoError(t, err) err = client.Symlink(path.Join("/", testFileName), path.Join("/sub", testFileName)) assert.NoError(t, err) _, err = client.ReadLink(path.Join("/sub", testFileName)) assert.Error(t, err, "read remote link without permission on targe dir should not succeed") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermDownload(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := 
httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.Error(t, err, "file download without permission should not succeed") err = client.Remove(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermUpload(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.Error(t, err, "file upload without permission should not succeed") err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = 
os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermOverwrite(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.Error(t, err, "file overwrite without permission should not succeed") err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermDelete(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = 
sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Remove(testFileName) assert.Error(t, err, "delete without permission should not succeed") err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } //nolint:dupl func TestPermRename(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName, testFileName+".rename") assert.True(t, errors.Is(err, fs.ErrPermission)) _, err = client.Stat(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } //nolint:dupl func TestPermRenameOverwrite(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermChmod, dataprovider.PermRename, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := 
httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName, testFileName+".rename") assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName, testFileName+".rename") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Remove(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermCreateDirs(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("testdir") assert.Error(t, err, "mkdir without permission should not succeed") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } //nolint:dupl func TestPermSymlink(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, 
dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Symlink(testFilePath, testFilePath+".symlink") assert.Error(t, err, "symlink without permission should not succeed") err = client.Remove(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermChmod(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChown, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Chmod(testFileName, os.ModePerm) assert.Error(t, err, "chmod without permission should not succeed") err = client.Remove(testFileName) 
assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } //nolint:dupl func TestPermChown(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChtimes} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Chown(testFileName, os.Getuid(), os.Getgid()) assert.Error(t, err, "chown without permission should not succeed") err = client.Remove(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } //nolint:dupl func TestPermChtimes(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { 
defer conn.Close() defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Chtimes(testFileName, time.Now(), time.Now()) assert.Error(t, err, "chtimes without permission should not succeed") err = client.Remove(testFileName) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSubDirsUploads(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermAny} u.Permissions["/subdir"] = []string{dataprovider.PermChtimes, dataprovider.PermDownload, dataprovider.PermOverwrite} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("subdir") assert.NoError(t, err) testFileNameSub := "/subdir/test_file_dat" testSubFile := filepath.Join(user.GetHomeDir(), "subdir", "file.dat") testDir := "testdir" testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testSubFile, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileNameSub, testFileSize, client) assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Symlink(testFileName, testFileNameSub+".link") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Symlink(testFileName, testFileName+".link") assert.NoError(t, err) err = client.Rename(testFileName, 
testFileNameSub+".rename") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Rename(testFileName, testFileName+".rename") assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) // rename overwriting an existing file err = client.Rename(testFileName, testFileName+".rename") assert.NoError(t, err) // now try to overwrite a directory err = client.Mkdir(testDir) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Rename(testFileName, testDir) assert.Error(t, err) err = client.Remove(testFileName) assert.NoError(t, err) err = client.Remove(testDir) assert.NoError(t, err) err = client.Remove(path.Join("/subdir", "file.dat")) assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Remove(testFileName + ".rename") assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSubDirsOverwrite(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermAny} u.Permissions["/subdir"] = []string{dataprovider.PermOverwrite, dataprovider.PermListItems} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() testFileName := "/subdir/test_file.dat" //nolint:goconst testFilePath := filepath.Join(homeBasePath, "test_file.dat") testFileSFTPPath := filepath.Join(u.GetHomeDir(), "subdir", "test_file.dat") testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFileSFTPPath, 16384) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName+".new", testFileSize, client) assert.True(t, errors.Is(err, 
fs.ErrPermission)) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSubDirsDownloads(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermAny} u.Permissions["/subdir"] = []string{dataprovider.PermChmod, dataprovider.PermUpload, dataprovider.PermListItems} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("subdir") assert.NoError(t, err) testFileName := "/subdir/test_file.dat" testFilePath := filepath.Join(homeBasePath, "test_file.dat") testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) assert.True(t, errors.Is(err, fs.ErrPermission)) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Chtimes(testFileName, time.Now(), time.Now()) assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Rename(testFileName, testFileName+".rename") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Symlink(testFileName, testFileName+".link") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Remove(testFileName) assert.True(t, errors.Is(err, fs.ErrPermission)) err = os.Remove(localDownloadPath) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) 
assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermsSubDirsSetstat(t *testing.T) { // for setstat we check the parent dir permission if the requested path is a dir // otherwise the path permission usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermCreateDirs} u.Permissions["/subdir"] = []string{dataprovider.PermAny} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("subdir") assert.NoError(t, err) testFileName := "/subdir/test_file.dat" testFilePath := filepath.Join(homeBasePath, "test_file.dat") testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) assert.NoError(t, err) err = client.Chtimes("/subdir/", time.Now(), time.Now()) assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Chtimes("subdir/", time.Now(), time.Now()) assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Chtimes(testFileName, time.Now(), time.Now()) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestOpenUnhandledChannel(t *testing.T) { u := getTestUser(false) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) config := &ssh.ClientConfig{ User: user.Username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, Timeout: 5 * time.Second, } conn, err := ssh.Dial("tcp", sftpServerAddr, config) if assert.NoError(t, err) { _, _, err = conn.OpenChannel("unhandled", nil) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unknown 
channel type") } err = conn.Close() assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestAlgorithmNotNegotiated(t *testing.T) { u := getTestUser(false) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) config := &ssh.ClientConfig{ Config: ssh.Config{ Ciphers: []string{ssh.InsecureCipherRC4}, }, User: user.Username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, Timeout: 5 * time.Second, } _, err = ssh.Dial("tcp", sftpServerAddr, config) if assert.Error(t, err) { negotiationErr := &ssh.AlgorithmNegotiationError{} assert.ErrorAs(t, err, &negotiationErr) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPermsSubDirsCommands(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermAny} u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs} u.Permissions["/subdir/otherdir"] = []string{dataprovider.PermListItems, dataprovider.PermDownload} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Mkdir("subdir") assert.NoError(t, err) acmodTime := time.Now() err = client.Chtimes("/subdir", acmodTime, acmodTime) assert.NoError(t, err) _, err = client.Stat("/subdir") assert.NoError(t, err) _, err = client.ReadDir("/") assert.NoError(t, err) _, err = client.ReadDir("/subdir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.RemoveDirectory("/subdir/dir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Mkdir("/subdir/otherdir/dir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = 
client.Mkdir("/otherdir") assert.NoError(t, err) err = client.Mkdir("/subdir/otherdir") assert.NoError(t, err) err = client.Rename("/otherdir", "/subdir/otherdir/adir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Symlink("/otherdir", "/subdir/otherdir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Symlink("/otherdir", "/otherdir_link") assert.NoError(t, err) err = client.Rename("/otherdir", "/otherdir1") assert.NoError(t, err) err = client.RemoveDirectory("/otherdir1") assert.NoError(t, err) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestRootDirCommands(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermAny} u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermAny} u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() defer client.Close() err = client.Rename("/", "rootdir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.Symlink("/", "rootdir") assert.True(t, errors.Is(err, fs.ErrPermission)) err = client.RemoveDirectory("/") assert.True(t, errors.Is(err, fs.ErrPermission)) } if user.Username == defaultUsername { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 user.CreatedAt = 0 user.Permissions = make(map[string][]string) user.Permissions["/"] = 
[]string{dataprovider.PermAny}
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestRelativePaths checks GetRelativePath for the local backend and, outside
// Windows, for the S3, GCS and SFTP backends: filesystem paths at or above the
// user's home directory must map to "/", paths below it to the matching
// virtual path.
func TestRelativePaths(t *testing.T) {
	user := getTestUser(true)
	filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "", nil)}
	keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/"

	// constructor errors are ignored on purpose: only the path mapping is under test
	s3Fs, _ := vfs.NewS3Fs("", user.GetHomeDir(), "", vfs.S3FsConfig{
		BaseS3FsConfig: sdk.BaseS3FsConfig{
			KeyPrefix: keyPrefix,
		},
	})
	gcsFs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", vfs.GCSFsConfig{
		BaseGCSFsConfig: sdk.BaseGCSFsConfig{
			KeyPrefix: keyPrefix,
		},
	})
	sftpFsConfig := vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username: defaultUsername,
			Prefix:   keyPrefix,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}
	sftpFs, _ := vfs.NewSFTPFs("", "", os.TempDir(), []string{user.Username}, sftpFsConfig)
	if runtime.GOOS != osWindows {
		filesystems = append(filesystems, s3Fs, gcsFs, sftpFs)
	}

	const rootPath = "/"
	// each suffix is joined to the home dir with filepath.Join, which cleans the
	// result: the first eight collapse to the home dir itself or to a parent of it
	expectations := []struct {
		suffix string
		want   string
	}{
		{"/", rootPath},
		{"//", rootPath},
		{"../..", rootPath},
		{"../../../../../", rootPath},
		{"/..", rootPath},
		{"/../../../..", rootPath},
		{"", rootPath},
		{".", rootPath},
		{"somedir", "/somedir"},
		{"/somedir/subdir", "/somedir/subdir"},
	}
	for _, fsys := range filesystems {
		for _, exp := range expectations {
			got := fsys.GetRelativePath(filepath.Join(user.HomeDir, exp.suffix))
			assert.Equal(t, exp.want, got)
		}
	}
}

// TestResolvePaths checks that ResolvePath maps client supplied paths (absolute,
// relative, "." and paths with ".." components) to backend paths below the
// user's home directory; ".." components must not escape the home directory.
func TestResolvePaths(t *testing.T) {
	user := getTestUser(true)
	filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "", nil)}
	keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/"
	s3Config := vfs.S3FsConfig{
		BaseS3FsConfig: sdk.BaseS3FsConfig{
			KeyPrefix: keyPrefix,
			Bucket:    "bucket",
			Region:    "us-east-1",
		},
	}
	err := os.MkdirAll(user.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	s3Fs, err := vfs.NewS3Fs("", user.GetHomeDir(), "", s3Config)
	assert.NoError(t, err)
	gcsFs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", vfs.GCSFsConfig{
		BaseGCSFsConfig: sdk.BaseGCSFsConfig{
			KeyPrefix: keyPrefix,
		},
	})
	if runtime.GOOS != osWindows {
		filesystems = append(filesystems, s3Fs, gcsFs)
	}

	testCases := []struct {
		input string
		want  string
		// when false the error returned by ResolvePath is not inspected
		mustSucceed bool
	}{
		{"/", "/", false},
		{".", "/", false},
		{"test/sub", "/test/sub", false},
		{"../test/sub", "/test/sub", true},
		{"../../../test/../sub", "/sub", true},
	}
	for _, fsys := range filesystems {
		for _, tc := range testCases {
			resolved, resolveErr := fsys.ResolvePath(filepath.ToSlash(tc.input))
			if tc.mustSucceed {
				assert.NoError(t, resolveErr)
			}
			assert.Equal(t, fsys.Join(user.GetHomeDir(), tc.want), resolved)
		}
	}
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestVirtualRelativePaths checks GetRelativePath for two local filesystems: one
// rooted at the user's home directory and one rooted at <tempdir>/mdir and
// mounted at /vdir. Paths outside a filesystem's root are expected to map to
// that filesystem's mount point ("/" or "/vdir").
func TestVirtualRelativePaths(t *testing.T) {
	user := getTestUser(true)
	mappedPath := filepath.Join(os.TempDir(), "mdir")
	vdirPath := "/vdir" //nolint:goconst
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath,
		},
		VirtualPath: vdirPath,
	})
	err := os.MkdirAll(mappedPath, os.ModePerm)
	assert.NoError(t, err)
	fsRoot := vfs.NewOsFs("", user.GetHomeDir(), "", nil)
	fsVdir := vfs.NewOsFs("", mappedPath, vdirPath, nil)

	checks := []struct {
		fsys   vfs.Fs
		fsPath string
		want   string
	}{
		{fsVdir, mappedPath, vdirPath},
		{fsRoot, filepath.Join(mappedPath, ".."), "/"},
		// path outside home and virtual dir
		{fsRoot, filepath.Join(mappedPath, "../vdir1"), "/"},
		{fsVdir, filepath.Join(mappedPath, "../vdir1"), "/vdir"},
		{fsVdir, filepath.Join(mappedPath, "file.txt"), "/vdir/file.txt"},
		{fsRoot, filepath.Join(user.HomeDir, "vdir1/file.txt"), "/vdir1/file.txt"},
	}
	for _, check := range checks {
		assert.Equal(t, check.want, check.fsys.GetRelativePath(check.fsPath))
	}
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestUserPerms checks HasPerm with per-directory permissions: relative and
// unclean paths are accepted, and paths with no entry of their own get the
// permissions of the nearest configured parent.
func TestUserPerms(t *testing.T) {
	user := getTestUser(true)
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermListItems}
user.Permissions["/p"] = []string{dataprovider.PermDelete}
	user.Permissions["/p/1"] = []string{dataprovider.PermDownload, dataprovider.PermUpload}
	user.Permissions["/p/2"] = []string{dataprovider.PermCreateDirs}
	user.Permissions["/p/3"] = []string{dataprovider.PermChmod}
	user.Permissions["/p/3/4"] = []string{dataprovider.PermChtimes}
	user.Permissions["/tmp"] = []string{dataprovider.PermRename}
	assert.True(t, user.HasPerm(dataprovider.PermListItems, "/"))
	assert.True(t, user.HasPerm(dataprovider.PermListItems, "."))
	assert.True(t, user.HasPerm(dataprovider.PermListItems, ""))
	assert.True(t, user.HasPerm(dataprovider.PermListItems, "../"))
	// path p and /p are the same
	assert.True(t, user.HasPerm(dataprovider.PermDelete, "/p"))
	assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/1"))
	assert.True(t, user.HasPerm(dataprovider.PermCreateDirs, "p/2"))
	assert.True(t, user.HasPerm(dataprovider.PermChmod, "/p/3"))
	assert.True(t, user.HasPerm(dataprovider.PermChtimes, "p/3/4/"))
	assert.True(t, user.HasPerm(dataprovider.PermChtimes, "p/3/4/../4"))
	// undefined paths have permissions of the nearest path
	assert.True(t, user.HasPerm(dataprovider.PermListItems, "/p34"))
	assert.True(t, user.HasPerm(dataprovider.PermListItems, "/p34/p1/file.dat"))
	assert.True(t, user.HasPerm(dataprovider.PermChtimes, "/p/3/4/5/6"))
	assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/1/test/file.dat"))
}

// TestWildcardPermissions checks that permission keys containing "*" act as
// patterns, and that an exact directory entry ("/pa", "/p/2", "/p/3/4")
// overrides a wildcard entry matching the same path.
func TestWildcardPermissions(t *testing.T) {
	user := getTestUser(true)
	user.Permissions = map[string][]string{
		"/":      {dataprovider.PermListItems},
		"/p*":    {dataprovider.PermDelete},
		"/p/*":   {dataprovider.PermDownload, dataprovider.PermUpload},
		"/p/2":   {dataprovider.PermCreateDirs},
		"/pa":    {dataprovider.PermChmod},
		"/p/3/4": {dataprovider.PermChtimes},
	}
	checks := []struct {
		perm        string
		virtualPath string
		granted     bool
	}{
		{dataprovider.PermListItems, "/", true},
		{dataprovider.PermDelete, "/p1", true},
		{dataprovider.PermDelete, "/ppppp", true},
		{dataprovider.PermDelete, "/pa", false},
		{dataprovider.PermChmod, "/pa", true},
		{dataprovider.PermUpload, "/p/1", true},
		{dataprovider.PermUpload, "/p/p", true},
		{dataprovider.PermUpload, "/p/2", false},
		{dataprovider.PermDownload, "/p/3", true},
		{dataprovider.PermDownload, "/p/a/a/a", true},
		{dataprovider.PermDownload, "/p/3/4", false},
		{dataprovider.PermChtimes, "/p/3/4", true},
		{dataprovider.PermDelete, "/pb/a/a/a", true},
		{dataprovider.PermDelete, "/abc/a/a/a", false},
		{dataprovider.PermListItems, "/abc/a/a/a/b", true},
	}
	for _, c := range checks {
		assert.Equal(t, c.granted, user.HasPerm(c.perm, c.virtualPath), "%s on %s", c.perm, c.virtualPath)
	}
}

// TestRootWildcardPerms checks that a "/*" entry covers every path below the root.
// Paths not matched by a more specific entry get its permissions rather than
// those configured for "/" itself.
func TestRootWildcardPerms(t *testing.T) {
	user := getTestUser(true)
	user.Permissions = map[string][]string{
		"/":      {dataprovider.PermListItems},
		"/*":     {dataprovider.PermDelete},
		"/p/*":   {dataprovider.PermDownload, dataprovider.PermUpload},
		"/p/2":   {dataprovider.PermCreateDirs},
		"/pa":    {dataprovider.PermChmod},
		"/p/3/4": {dataprovider.PermChtimes},
	}
	checks := []struct {
		perm        string
		virtualPath string
		granted     bool
	}{
		{dataprovider.PermListItems, "/", true},
		{dataprovider.PermDelete, "/p1", true},
		{dataprovider.PermDelete, "/ppppp", true},
		{dataprovider.PermDelete, "/pa", false},
		{dataprovider.PermChmod, "/pa", true},
		{dataprovider.PermUpload, "/p/1", true},
		{dataprovider.PermUpload, "/p/p", true},
		{dataprovider.PermUpload, "/p/2", false},
		{dataprovider.PermCreateDirs, "/p/2", true},
		{dataprovider.PermCreateDirs, "/p/2/a", true},
		{dataprovider.PermDownload, "/p/3", true},
		{dataprovider.PermDownload, "/p/a/a/a", true},
		{dataprovider.PermDownload, "/p/3/4", false},
		{dataprovider.PermChtimes, "/p/3/4", true},
		{dataprovider.PermDelete, "/pb/a/a/a", true},
		{dataprovider.PermDelete, "/abc/a/a/a", true},
		{dataprovider.PermListItems, "/abc/a/a/a/b", false},
		{dataprovider.PermDelete, "/abc/a/a/a/b", true},
	}
	for _, c := range checks {
		assert.Equal(t, c.granted, user.HasPerm(c.perm, c.virtualPath), "%s on %s", c.perm, c.virtualPath)
	}
}

// TestFilterFilePatterns checks IsFileAllowed against per-directory allowed and
// denied file patterns. Pattern matching ignores case, the filter of the nearest
// configured directory applies, and a non-empty allow list rejects file names it
// does not match.
func TestFilterFilePatterns(t *testing.T) {
	user := getTestUser(true)
	// the closure reads user.Filters at call time, so later reassignments are seen
	expectAllowed := func(virtualPath string, allowed bool) {
		t.Helper()
		ok, _ := user.IsFileAllowed(virtualPath)
		assert.Equal(t, allowed, ok, virtualPath)
	}

	filters := dataprovider.UserFilters{
		BaseUserFilters: sdk.BaseUserFilters{
			FilePatterns: []sdk.PatternsFilter{
				{
					Path:            "/test",
					AllowedPatterns: []string{"*.jpg", "*.png"},
					DeniedPatterns:  []string{"*.pdf"},
				},
			},
		},
	}
	user.Filters = filters
	expectAllowed("/test/test.jPg", true)
	expectAllowed("/test/test.pdf", false)
	expectAllowed("/test.pDf", true)

	filters.FilePatterns = append(filters.FilePatterns, sdk.PatternsFilter{
		Path:            "/",
		AllowedPatterns: []string{"*.zip", "*.rar", "*.pdf"},
		DeniedPatterns:  []string{"*.gz"},
	})
	user.Filters = filters
	expectAllowed("/test1/test.gz", false)
	expectAllowed("/test1/test.zip", true)
	expectAllowed("/test/sub/test.pdf", false)
	expectAllowed("/test1/test.png", false)

	filters.FilePatterns = append(filters.FilePatterns, sdk.PatternsFilter{
		Path:           "/test/sub",
		DeniedPatterns: []string{"*.tar"},
	})
	user.Filters = filters
	expectAllowed("/test/sub/sub/test.tar", false)
	expectAllowed("/test/sub/test.gz", true)
	expectAllowed("/test/test.zip", false)
}

// TestUserAllowedLoginMethods checks GetAllowedLoginMethods. Denying every valid
// method leaves none. Denying password, public key and keyboard-interactive leaves
// four methods, among them key-and-keyboard-interactive and key-and-password.
func TestUserAllowedLoginMethods(t *testing.T) {
	user := getTestUser(true)
	user.Filters.DeniedLoginMethods = dataprovider.ValidLoginMethods
	assert.Equal(t, 0, len(user.GetAllowedLoginMethods()))

	user.Filters.DeniedLoginMethods = []string{
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}
	allowed := user.GetAllowedLoginMethods()
	assert.Equal(t, 4, len(allowed))
	for _, method := range []string{dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.SSHLoginMethodKeyAndPassword} {
		assert.True(t, slices.Contains(allowed, method))
	}
}

// TestUserPartialAuth checks IsPartialAuth. It must be true only when password
// (generic or SSH specific), public key and keyboard-interactive are all denied.
func TestUserPartialAuth(t *testing.T) {
	user := getTestUser(true)
	cases := []struct {
		denied  []string
		partial bool
	}{
		{
			[]string{
				dataprovider.LoginMethodPassword,
				dataprovider.SSHLoginMethodPublicKey,
				dataprovider.SSHLoginMethodKeyboardInteractive,
			},
			true,
		},
		{
			[]string{
				dataprovider.LoginMethodPassword,
				dataprovider.SSHLoginMethodKeyboardInteractive,
			},
			false,
		},
		{
			[]string{
				dataprovider.LoginMethodPassword,
				dataprovider.SSHLoginMethodPublicKey,
			},
			false,
		},
		{
			[]string{
				dataprovider.SSHLoginMethodPassword,
				dataprovider.SSHLoginMethodPublicKey,
				dataprovider.SSHLoginMethodKeyboardInteractive,
			},
			true,
		},
	}
	for _, c := range cases {
		user.Filters.DeniedLoginMethods = c.denied
		assert.Equal(t, c.partial, user.IsPartialAuth())
	}
}

// TestUserGetNextAuthMethods checks the follow-up methods returned by
// GetNextAuthMethods. Password is offered unless key-and-password is denied, and
// keyboard-interactive is offered unless key-and-keyboard-interactive is denied.
func TestUserGetNextAuthMethods(t *testing.T) {
	user := getTestUser(true)
	nextMethodsWhenDenied := func(denied ...string) []string {
		user.Filters.DeniedLoginMethods = denied
		return user.GetNextAuthMethods()
	}

	methods := nextMethodsWhenDenied(
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	)
	require.Len(t, methods, 2)
	assert.Equal(t, dataprovider.LoginMethodPassword, methods[0])
	assert.Equal(t, dataprovider.SSHLoginMethodKeyboardInteractive, methods[1])

	methods = nextMethodsWhenDenied(
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyboardInteractive,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
	)
	require.Len(t, methods, 1)
	assert.Equal(t, dataprovider.LoginMethodPassword, methods[0])

	methods = nextMethodsWhenDenied(
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyboardInteractive,
		dataprovider.SSHLoginMethodKeyAndPassword,
	)
	require.Len(t, methods, 1)
	assert.Equal(t, dataprovider.SSHLoginMethodKeyboardInteractive, methods[0])

	methods = nextMethodsWhenDenied(
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
	)
	require.Len(t, methods, 0)
}

// TestUserIsLoginMethodAllowed checks two cases. Denying LoginMethodPassword
// blocks password logins over SSH, FTP and WebDAV. Denying only
// SSHLoginMethodPassword blocks password logins over SSH alone.
func TestUserIsLoginMethodAllowed(t *testing.T) {
	user := getTestUser(true)
	user.Filters.DeniedLoginMethods = []string{
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}
	assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolSSH))
	assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolFTP))
	assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolWebDAV))
	assert.False(t, user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPublicKey, common.ProtocolSSH))
	assert.False(t, user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodKeyboardInteractive, common.ProtocolSSH))
	user.Filters.DeniedLoginMethods = []string{
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}
	assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolSSH))
	user.Filters.DeniedLoginMethods = []string{
		dataprovider.SSHLoginMethodPassword,
	}
	assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP))
assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolFTP))
	assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolWebDAV))
	assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolSSH))
}

// TestUserEmptySubDirPerms checks that a directory configured with an empty
// permission list grants none of the valid permissions.
func TestUserEmptySubDirPerms(t *testing.T) {
	user := getTestUser(true)
	user.Permissions = make(map[string][]string)
	user.Permissions["/emptyperms"] = []string{}
	for _, p := range dataprovider.ValidPerms {
		assert.False(t, user.HasPerm(p, "/emptyperms"))
	}
}

// TestUserFiltersIPMaskConditions checks IsLoginFromAddrAllowed with allowed and
// denied CIDR lists. With no filters every address, even an unparsable one, is
// allowed. An address in both lists is allowed. A non-empty allow list rejects
// addresses outside it. The final assertion shows an unparsable address is still
// allowed when filters are set.
func TestUserFiltersIPMaskConditions(t *testing.T) {
	user := getTestUser(true)
	// with no filter login must be allowed even if the remoteIP is invalid
	assert.True(t, user.IsLoginFromAddrAllowed("192.168.1.5"))
	assert.True(t, user.IsLoginFromAddrAllowed("invalid"))

	user.Filters.DeniedIP = append(user.Filters.DeniedIP, "192.168.1.0/24")
	assert.False(t, user.IsLoginFromAddrAllowed("192.168.1.5"))
	assert.True(t, user.IsLoginFromAddrAllowed("192.168.2.6"))

	user.Filters.AllowedIP = append(user.Filters.AllowedIP, "192.168.1.5/32")
	// if the same ip/mask is both denied and allowed then login must be allowed
	assert.True(t, user.IsLoginFromAddrAllowed("192.168.1.5"))
	assert.False(t, user.IsLoginFromAddrAllowed("192.168.1.3"))
	assert.False(t, user.IsLoginFromAddrAllowed("192.168.3.6"))

	user.Filters.DeniedIP = []string{}
	assert.True(t, user.IsLoginFromAddrAllowed("192.168.1.5"))
	assert.False(t, user.IsLoginFromAddrAllowed("192.168.1.6"))

	user.Filters.DeniedIP = []string{"192.168.0.0/16", "172.16.0.0/16"}
	user.Filters.AllowedIP = []string{}
	assert.False(t, user.IsLoginFromAddrAllowed("192.168.5.255"))
	assert.False(t, user.IsLoginFromAddrAllowed("172.16.1.2"))
	assert.True(t, user.IsLoginFromAddrAllowed("172.18.2.1"))

	user.Filters.AllowedIP = []string{"10.4.4.0/24"}
	assert.False(t, user.IsLoginFromAddrAllowed("10.5.4.2"))
	assert.True(t, user.IsLoginFromAddrAllowed("10.4.4.2"))
	assert.True(t, user.IsLoginFromAddrAllowed("invalid"))
}

// TestGetVirtualFolderForPath checks that GetVirtualFolderForPath returns the
// deepest virtual folder containing the requested path, and an error for paths
// outside every virtual folder. Each folder has its own mapped path, so the
// assertions can tell which folder was actually returned.
func TestGetVirtualFolderForPath(t *testing.T) {
	user := getTestUser(true)
	mappedPath1 := filepath.Join(os.TempDir(), "vpath1")
	// must differ from mappedPath1, otherwise a wrong match would go unnoticed
	mappedPath2 := filepath.Join(os.TempDir(), "vpath2")
	mappedPath3 := filepath.Join(os.TempDir(), "vpath3")
	vdirPath := "/vdir/sub"
	vSubDirPath := path.Join(vdirPath, "subdir", "subdir")
	vSubDir1Path := path.Join(vSubDirPath, "subdir", "subdir")
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath1,
		},
		VirtualPath: vdirPath,
	})
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath2,
		},
		VirtualPath: vSubDir1Path,
	})
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath3,
		},
		VirtualPath: vSubDirPath,
	})
	folder, err := user.GetVirtualFolderForPath(path.Join(vSubDirPath, "file"))
	assert.NoError(t, err)
	assert.Equal(t, folder.MappedPath, mappedPath3)
	// the deepest folder wins over its ancestors
	folder, err = user.GetVirtualFolderForPath(path.Join(vSubDir1Path, "file"))
	assert.NoError(t, err)
	assert.Equal(t, mappedPath2, folder.MappedPath)
	_, err = user.GetVirtualFolderForPath("/file")
	assert.Error(t, err)
	folder, err = user.GetVirtualFolderForPath(path.Join(vdirPath, "/file"))
	assert.NoError(t, err)
	assert.Equal(t, folder.MappedPath, mappedPath1)
	// ".../subdir1" is not below ".../subdir/subdir", only vdirPath contains it
	folder, err = user.GetVirtualFolderForPath(path.Join(vSubDirPath+"1", "file"))
	assert.NoError(t, err)
	assert.Equal(t, folder.MappedPath, mappedPath1)
	_, err = user.GetVirtualFolderForPath("/vdir/sub1/file")
	assert.Error(t, err)
	folder, err = user.GetVirtualFolderForPath(vdirPath)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath1, folder.MappedPath)
}

// TestStatVFS checks the values returned by the SFTP statvfs request.
// Without quota limits, positive values are reported for a real path and a
// missing path fails with fs.ErrNotExist. With a files quota, Files/Ffree follow
// QuotaFiles and the used files. With a size quota that is already exceeded,
// Ffree and Bfree drop to 0. Each connection is closed when the test returns,
// through the deferred Close calls.
func TestStatVFS(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testFileSize := int64(65535)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))

		_, err = client.StatVFS("missing-path")
		assert.Error(t, err)
		assert.True(t, errors.Is(err, fs.ErrNotExist))
	}

	user.QuotaFiles = 100
	user.Filters.DisableFsChecks = true
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)

		_, err = client.StatVFS("missing-path")
		assert.Error(t, err)
		assert.ErrorIs(t, err, fs.ErrNotExist)

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		// one file was uploaded out of a 100 files quota
		assert.Equal(t, uint64(100), stat.Files)
		assert.Equal(t, uint64(99), stat.Ffree)
	}

	// the uploaded file (65535 bytes) already exceeds this size quota
	user.QuotaSize = 8192
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		assert.Equal(t, uint64(100), stat.Files)
		assert.Equal(t, uint64(0), stat.Ffree)
		assert.Equal(t, uint64(2), stat.Blocks)
		assert.Equal(t, uint64(0), stat.Bfree)
	}

	user.QuotaFiles = 0
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		assert.Greater(t, stat.Files, uint64(0))
		assert.Equal(t, uint64(0), stat.Ffree)
		assert.Equal(t, uint64(2), stat.Blocks)
		assert.Equal(t, uint64(0), stat.Bfree)
	}

	user.QuotaSize = 1
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Equal(t, uint64(1), stat.Blocks)
		assert.Equal(t, uint64(1), stat.Bsize)
		assert.Greater(t, stat.Files, uint64(0))
		assert.Equal(t, uint64(0), stat.Ffree)
		assert.Equal(t, uint64(0), stat.Bfree)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestStatVFSCloudBackend checks statvfs for an Azure Blob backed user. The used
// quota stored through UpdateUserQuota (100 files, 8192 bytes) is added on top of
// fixed free amounts: 1000000 files and 2147483648 blocks.
func TestStatVFSCloudBackend(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	u.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("https://myaccount.blob.core.windows.net/sasurl")
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		err = dataprovider.UpdateUserQuota(&user, 100, 8192, true)
		assert.NoError(t, err)
		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		assert.Equal(t, uint64(1000000+100), stat.Files)
		assert.Equal(t, uint64(2147483648+2), stat.Blocks)
		assert.Equal(t, uint64(1000000), stat.Ffree)
		assert.Equal(t, uint64(2147483648), stat.Bfree)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSSHCommands checks the SSH commands on a fresh user. "ls" is rejected,
// "cd" and "pwd" succeed, and the hash commands return the digest of empty input.
func TestSSHCommands(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t,
err)
	_, err = runSSHCommand("ls", user, usePubKey)
	assert.Error(t, err, "unsupported ssh command must fail")
	_, err = runSSHCommand("cd", user, usePubKey)
	assert.NoError(t, err)
	out, err := runSSHCommand("pwd", user, usePubKey)
	if assert.NoError(t, err) {
		assert.Equal(t, "/\n", string(out))
	}
	out, err = runSSHCommand("md5sum", user, usePubKey)
	assert.NoError(t, err)
	// echo -n '' | md5sum
	assert.Contains(t, string(out), "d41d8cd98f00b204e9800998ecf8427e")
	out, err = runSSHCommand("sha1sum", user, usePubKey)
	assert.NoError(t, err)
	assert.Contains(t, string(out), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
	out, err = runSSHCommand("sha256sum", user, usePubKey)
	assert.NoError(t, err)
	assert.Contains(t, string(out), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
	out, err = runSSHCommand("sha384sum", user, usePubKey)
	assert.NoError(t, err)
	assert.Contains(t, string(out), "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSSHFileHash checks "sha512sum" for a local, an SFTP backed and an
// encrypted-filesystem user. The command fails without list permission and for
// a missing path. Otherwise it returns the SHA-512 of the uploaded file.
func TestSSHFileHash(t *testing.T) {
	usePubKey := true
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUserWithCryptFs(usePubKey)
	u.Username = u.Username + "_crypt"
	cryptUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser, cryptUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			// deferred inside the loop: every connection stays open until the test returns
			defer conn.Close()
			defer client.Close()

			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			// temporarily restrict the user to upload only
			user.Permissions = make(map[string][]string)
			user.Permissions["/"] = []string{dataprovider.PermUpload}
			_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
			assert.NoError(t, err)
			_, err = runSSHCommand("sha512sum "+testFileName, user, usePubKey)
			assert.Error(t, err, "hash command with no list permission must fail")
			user.Permissions["/"] = []string{dataprovider.PermAny}
			_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
			assert.NoError(t, err)
			initialHash, err := computeHashForFile(sha512.New(), testFilePath)
			assert.NoError(t, err)
			out, err := runSSHCommand("sha512sum "+testFileName, user, usePubKey)
			if assert.NoError(t, err) {
				assert.Contains(t, string(out), initialHash)
			}
			_, err = runSSHCommand("sha512sum invalid_path", user, usePubKey)
			assert.Error(t, err, "hash for an invalid path must fail")
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
		}
	}

	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(cryptUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestSSHCopy checks the "sftpgo-copy" SSH command and its quota accounting.
// It covers copies from and to two virtual folders: vdir1 is included in the user
// quota (QuotaFiles/QuotaSize -1), vdir2 has its own quota. It also checks
// argument validation, symlink sources, denied file patterns and, outside
// Windows, filesystem permission errors.
func TestSSHCopy(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1/subdir"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2/subdir"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  100,
		QuotaSize:   0,
	})
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.denied"},
		},
	}
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testDir := "adir"
	testDir1 := "adir1"
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileName1 := "test_file1.dat"
		testFileSize1 := int64(65537)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir1))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = client.Symlink(path.Join(testDir, testFileName), testFileName)
		assert.NoError(t, err)
		// 4 files count against the user (home dir + vdir1), 2 against vdir2 only
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 4, user.UsedQuotaFiles)
		assert.Equal(t, 2*testFileSize+2*testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		_, err = client.Stat(testDir1)
		assert.Error(t, err)
		// invalid invocations: missing destination, no arguments, symlink source
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s", path.Join(vdirPath1, testDir1)), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-copy", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileName+".linkcopy"), user, usePubKey)
		assert.Error(t, err)
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join(vdirPath1, testDir1), "."), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			fi, err := client.Stat(testDir1)
			if assert.NoError(t, err) {
				assert.True(t, fi.IsDir())
			}
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 6, user.UsedQuotaFiles)
			assert.Equal(t, 3*testFileSize+3*testFileSize1, user.UsedQuotaSize)
		}
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", "missing\\ dir", "."), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1), "."), user, usePubKey)
		if assert.NoError(t, err) {
			// all files are overwritten, quota must remain unchanged
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 6, user.UsedQuotaFiles)
			assert.Equal(t, 3*testFileSize+3*testFileSize1, user.UsedQuotaSize)
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir1, testFileName), testFileName+".copy"), //nolint:goconst
			user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			fi, err := client.Stat(testFileName + ".copy") //nolint:goconst
			if assert.NoError(t, err) {
				assert.True(t, fi.Mode().IsRegular())
			}
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 7, user.UsedQuotaFiles)
			assert.Equal(t, 4*testFileSize+3*testFileSize1, user.UsedQuotaSize)
		}
		// copying into vdir2 updates the folder quota, not the user quota
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1), path.Join(vdirPath2, testDir1+"copy")), //nolint:goconst
			user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			fi, err := client.Stat(path.Join(vdirPath2, testDir1+"copy"))
			if assert.NoError(t, err) {
				assert.True(t, fi.IsDir())
			}
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 7, user.UsedQuotaFiles)
			assert.Equal(t, 4*testFileSize+3*testFileSize1, user.UsedQuotaSize)
			f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, testFileSize*2+testFileSize1*2, f.UsedQuotaSize)
			assert.Equal(t, 4, f.UsedQuotaFiles)
		}
		// copying inside vdir1 updates the user quota, vdir1 keeps no usage of its own
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1), path.Join(vdirPath1, testDir1+"copy")),
			user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			// stat the actual destination: the vdir2 copy already existed from the previous step
			_, err := client.Stat(path.Join(vdirPath1, testDir1+"copy"))
			assert.NoError(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 9, user.UsedQuotaFiles)
			assert.Equal(t, 5*testFileSize+4*testFileSize1, user.UsedQuotaSize)
			f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, int64(0), f.UsedQuotaSize)
			assert.Equal(t, 0, f.UsedQuotaFiles)
		}
		// cross folder copy
		newDir := "newdir"
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, ".."), newDir), user, usePubKey)
		assert.NoError(t, err)
		_, err = client.Stat(newDir)
		assert.NoError(t, err)
		// denied pattern
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(testDir, testFileName), testFileName+".denied"), user, usePubKey)
		assert.Error(t, err)
		if runtime.GOOS != osWindows {
			subPath := filepath.Join(mappedPath1, testDir1, "asubdir", "anothersub", "another")
			err = os.MkdirAll(subPath, os.ModePerm)
			assert.NoError(t, err)
			err = os.Chmod(subPath, 0001)
			assert.NoError(t, err)
			// listing contents for subdirs with no permissions will fail
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", vdirPath1, "newdir1"), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(subPath, os.ModePerm)
			assert.NoError(t, err)
			// a read-only destination directory must make the copy fail
			err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir1), 0555)
			assert.NoError(t, err)
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1, testFileName), path.Join(testDir1, "anewdir")), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir1), os.ModePerm)
			assert.NoError(t, err)
			err = os.Chmod(user.GetHomeDir(), os.ModePerm)
			assert.NoError(t, err)
		}

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSSHCopyPermissions checks that "sftpgo-copy" enforces per-directory
// permissions. Copying a file into /dir3 (no upload) and a directory into /dir1
// (no create dirs) fail, while /dir2 accepts the copy. Symlinks in the source are
// skipped, so a copy into a directory without create symlinks permission still
// succeeds.
func TestSSHCopyPermissions(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.Permissions["/dir1"] = []string{dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermListItems}
	u.Permissions["/dir2"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermDownload,
		dataprovider.PermListItems, dataprovider.PermCopy}
	u.Permissions["/dir3"] = []string{dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermDownload,
		dataprovider.PermListItems}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		testDir := "tDir"
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join("/", testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		// test copy file with no permission
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir, testFileName), path.Join("/dir3", testFileName)),
			user, usePubKey)
		assert.Error(t, err)
		// test copy dir with no create dirs perm
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir1/"), user, usePubKey)
		assert.Error(t, err)
		// dir2 has the needed permissions
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir2/"), user, usePubKey)
		assert.NoError(t, err)
		info, err := client.Stat(path.Join("/dir2", testDir))
		if assert.NoError(t, err) {
			assert.True(t, info.IsDir())
		}
		info, err = client.Stat(path.Join("/dir2", testDir, testFileName))
		if assert.NoError(t, err) {
			assert.True(t, info.Mode().IsRegular())
		}
		// now create a symlink, dir2 has no create symlink permission, but symlinks will be ignored
		err = client.Symlink(path.Join("/", testDir, testFileName), path.Join("/", testDir, testFileName+".link"))
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir2/sub"), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/newdir"), user, usePubKey)
		assert.NoError(t, err)
		// now delete the file and copy inside /dir3
		err = client.Remove(path.Join("/", testDir, testFileName))
		assert.NoError(t, err)
		// the symlink will be skipped, so no errors
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir3"), user, usePubKey)
		assert.NoError(t, err)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestSSHCopyQuotaLimits(t *testing.T) {
	usePubKey := true
	testFileSize := int64(131072)
	testFileSize1 := int64(65536)
	testFileSize2 := int64(32768)
	u := getTestUser(usePubKey)
	u.QuotaFiles = 3
	u.QuotaSize = testFileSize + testFileSize1 + 1
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  3,
		QuotaSize:   testFileSize + testFileSize1 + 1,
	})
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.denied"},
		},
	}
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 :=
vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testDir := "testDir"
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileName1 := "test_file1.dat"
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		testFileName2 := "test_file2.dat"
		testFilePath2 := filepath.Join(homeBasePath, testFileName2)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = createTestFile(testFilePath2, testFileSize2)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(testDir, testFileName2), testFileSize2, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(testDir, testFileName2+".dupl"), testFileSize2, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(vdirPath2, testDir, testFileName2), testFileSize2, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(vdirPath2, testDir, testFileName2+".dupl"), testFileSize2, client)
		assert.NoError(t, err)
		// used quota is now: user 2 files and 32768*2 bytes, folder2 2 files and 32768*2 bytes.
		// Duplicating testDir would result in 4 files (over the 3 files quota) and
		// 32768*4 bytes (still within the size quota), so both copies must fail
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", testDir, testDir+"_copy"), user, usePubKey) //nolint:goconst
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir), path.Join(vdirPath2, testDir+"_copy")), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
		assert.NoError(t, err)
		// remove partially copied dirs
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir+"_copy"), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir+"_copy")), user, usePubKey)
		assert.NoError(t, err)
		// with everything removed the used quota must be back to zero everywhere
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, user.UsedQuotaFiles)
		assert.Equal(t, int64(0), user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// vdir1 is included in user quota, file limit will be exceeded
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir), "/"), user, usePubKey)
		assert.Error(t, err)
		// vdir2 size limit will be exceeded
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir, testFileName), vdirPath2+"/"), user,
usePubKey)
		assert.Error(t, err)
		// now decrease the limits
		user.QuotaFiles = 1
		user.QuotaSize = testFileSize * 10
		for idx, f := range user.VirtualFolders {
			if f.Name == folderName2 {
				user.VirtualFolders[idx].QuotaSize = testFileSize
				user.VirtualFolders[idx].QuotaFiles = 10
			}
		}
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		assert.Equal(t, 1, user.QuotaFiles)
		assert.Equal(t, testFileSize*10, user.QuotaSize)
		if assert.Len(t, user.VirtualFolders, 2) {
			for _, f := range user.VirtualFolders {
				if f.Name == folderName2 {
					assert.Equal(t, testFileSize, f.QuotaSize)
					assert.Equal(t, 10, f.QuotaFiles)
				}
			}
		}
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir), path.Join(vdirPath2, testDir+".copy")), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir), testDir+".copy"), user, usePubKey)
		assert.Error(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
		err = os.Remove(testFilePath2)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSSHRemove exercises the sftpgo-remove SSH command. It covers:
//   - removal of a symlink, a file and a directory inside a virtual folder, each
//     answered with "OK\n" and followed by the matching quota update;
//   - the errors returned for a virtual folder root, "/", a missing argument
//     and a missing file;
//   - filesystem permission failures (non-Windows only);
//   - an attempt made without the delete permission.
func TestSSHRemove(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1/sub"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2/sub"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  100,
		QuotaSize:   0,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileName1 := "test_file1.dat"
		testFileSize1 := int64(65537)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		testDir := "testdir"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = client.Symlink(testFileName, testFileName+".link")
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName+".link"), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand("sftpgo-remove /vdir1", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove /", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove", user, usePubKey)
		assert.Error(t, err)
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			_, err := client.Stat(testFileName)
			assert.Error(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 3, user.UsedQuotaFiles)
			assert.Equal(t, testFileSize+2*testFileSize1, user.UsedQuotaSize)
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, testDir)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			// stat the removed directory itself: vdirPath1/testFileName was never
			// created, so checking it would make this assertion vacuous
			_, err := client.Stat(path.Join(vdirPath1, testDir))
			assert.Error(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 1, user.UsedQuotaFiles)
			assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		}
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", vdirPath1), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove /", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove missing_file", user, usePubKey)
		assert.Error(t, err)
		if runtime.GOOS != osWindows {
			err = os.Chmod(filepath.Join(mappedPath2, testDir), 0555)
			assert.NoError(t, err)
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(filepath.Join(mappedPath2, testDir), 0001)
			assert.NoError(t, err)
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(filepath.Join(mappedPath2, testDir), os.ModePerm)
			assert.NoError(t, err)
		}
		// the local source files were previously left behind in homeBasePath
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	// test remove dir with no delete perm
	user.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermListItems}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// NOTE(review): "adir" is never created by this test, so this call may fail
		// because the path is missing rather than because of the missing delete
		// permission; consider targeting an existing path.
		_, err = runSSHCommand("sftpgo-remove adir", user, usePubKey)
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSSHRemoveCryptFs runs sftpgo-remove for a user created by
// getTestUserWithCryptFs. The user has a folder with only a mapped path on
// /vdir1/sub and a CryptedFilesystemProvider folder on /vdir2/sub. The test then
// checks the remaining user quota usage.
func TestSSHRemoveCryptFs(t *testing.T) {
	usePubKey := false
	u := getTestUserWithCryptFs(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1/sub"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2/sub"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  100,
		QuotaSize:   0,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testDir := "tdir"
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		testFileSize := int64(32768)
		testFileSize1 := int64(65536)
		testFileName1 := "test_file1.dat"
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		_, err = runSSHCommand("sftpgo-remove /vdir2", user, usePubKey)
		assert.Error(t, err)
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			_, err := client.Stat(testFileName)
assert.Error(t, err)
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, testDir)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir, testFileName)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		// only testFileName1 in the root directory should still count toward the
		// user quota; its stored size is expected to exceed testFileSize1
		// (NOTE(review): presumably because of the encryption overhead)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Greater(t, user.UsedQuotaSize, testFileSize1)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSSHCommandMaxTransfers sets common.Config.MaxPerHostConnections to 2 and
// checks three things: two uploads can be open at the same time on one SFTP
// connection, a third Create is refused, and no transfer is still tracked once
// everything is closed.
func TestSSHCommandMaxTransfers(t *testing.T) {
	// NOTE(review): the body below uses neither git nor ssh; this skip condition
	// looks inherited from another test, confirm before relaxing it.
	if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows {
		t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test")
	}
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 2
	// the setting is global and shared with the other tests: restore it however
	// this test ends
	defer func() {
		common.Config.MaxPerHostConnections = oldValue
	}()
	usePubKey := true
	u := getTestUser(usePubKey)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	repoName := "testrepo" //nolint:goconst
	clonePath := filepath.Join(homeBasePath, repoName)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Join(homeBasePath, repoName))
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		f1, err := client.Create("file1")
		assert.NoError(t, err)
		f2, err := client.Create("file2")
		assert.NoError(t, err)
		// guard against a failed Create so the test reports the error instead of
		// dereferencing a nil file and panicking the whole test binary
		if f1 != nil && f2 != nil {
			_, err = f1.Write([]byte(" "))
			assert.NoError(t, err)
			_, err = f2.Write([]byte(" "))
			assert.NoError(t, err)
			// two uploads are in progress, a third one must be refused
			_, err = client.Create("file3")
			assert.Error(t, err)
		}
		if f1 != nil {
			err = f1.Close()
			assert.NoError(t, err)
		}
		if f2 != nil {
			err = f2.Close()
			assert.NoError(t, err)
		}
		err = client.Close()
		assert.NoError(t, err)
		err = conn.Close()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(clonePath)
	assert.NoError(t, err)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// Start SCP tests

// TestSCPBasicHandling uploads and downloads a single file via scp, once for a
// user from getTestUser and once for a user from getTestSFTPUser. It checks that
// a missing file cannot be downloaded, that FirstUpload/FirstDownload are set
// only after the first transfer of each kind, and the resulting quota usage.
func TestSCPBasicHandling(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaSize = 6553600
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.QuotaSize = 6553600
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(131074)
	expectedQuotaSize := testFileSize
	expectedQuotaFiles := 1
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
		remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
		localPath := filepath.Join(homeBasePath, "scp_download.dat")
		// test to download a missing file
		err = scpDownload(localPath, remoteDownPath, false, false)
assert.Error(t, err, "downloading a missing file via scp must fail")
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), user.FirstUpload)
		assert.Equal(t, int64(0), user.FirstDownload)
		err = scpUpload(testFilePath, remoteUpPath, false, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.FirstUpload, int64(0))
		assert.Equal(t, int64(0), user.FirstDownload)
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.FirstUpload, int64(0))
		assert.Greater(t, user.FirstDownload, int64(0))
		fi, err := os.Stat(localPath)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		err = os.Remove(localPath)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		// recreate the default user with an empty home directory
		// (NOTE(review): presumably needed by the SFTP user of the next iteration)
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestSCPUploadFileOverwrite checks that overwriting an existing file via scp
// leaves the quota usage unchanged (1 file, testFileSize bytes). It also checks
// that uploading via scp over a symlink created via SFTP adds a second file to
// the quota usage.
func TestSCPUploadFileOverwrite(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 1000
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.QuotaFiles = 1000
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(32760)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
		err = scpUpload(testFilePath, remoteUpPath, true, false)
		assert.NoError(t, err)
		// test a new upload that must overwrite the existing file
		err = scpUpload(testFilePath, remoteUpPath, true, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
		localPath := filepath.Join(homeBasePath, "scp_download.dat")
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.NoError(t, err)
		fi, err := os.Stat(localPath)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		// now create a symlink via SFTP, replace the symlink with a file via SCP and check quota usage
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			err = client.Symlink(testFileName, testFileName+".link")
			assert.NoError(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, testFileSize, user.UsedQuotaSize)
			assert.Equal(t, 1, user.UsedQuotaFiles)
			// close explicitly: a defer inside the loop would keep every
			// iteration's connection open until the whole test returns
			err = client.Close()
			assert.NoError(t, err)
			err = conn.Close()
			assert.NoError(t, err)
		}
		err = scpUpload(testFilePath, remoteUpPath+".link", true, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		err = os.Remove(localPath)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPRecursive checks recursive scp uploads, including overwriting an
// existing directory, and recursive downloads of a two-level directory tree.
// It also checks the errors for a missing remote directory, for a download
// without -r, and for an upload to a non-existent directory.
func TestSCPRecursive(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testBaseDirName := "test_dir"
	testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
	testBaseDirDownName := "test_dir_down" //nolint:goconst
	testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
	testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
	testFileSize := int64(131074)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName))
		// test to download a missing dir
		err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
		assert.Error(t, err, "downloading a missing dir via scp must fail")
		remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username,
"/")
		err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
		assert.NoError(t, err)
		// overwrite existing dir
		err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
		assert.NoError(t, err)
		err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
		assert.NoError(t, err)
		// test download without passing -r
		err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false)
		assert.Error(t, err, "recursive download without -r must fail")
		fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		// upload to a non existent dir
		remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir")
		err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
		assert.Error(t, err, "uploading via scp to a non existent dir must fail")
		err = os.RemoveAll(testBaseDirDownPath)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPStartDirectory checks that scp uploads with an empty remote path and
// downloads with a relative remote path are resolved against the user's start
// directory.
func TestSCPStartDirectory(t *testing.T) {
	// like the other SCP tests, this one needs the scp binary
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	// the start directory contains a space
	startDir := "/sta rt/dir"
	u := getTestUser(usePubKey)
	u.Filters.StartDirectory = startDir
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFileSize := int64(131072)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	// no remote path after ':', so the upload lands in the start directory
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:", user.Username)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.NoError(t, err)
	// check that the file is in the start directory
	_, err = os.Stat(filepath.Join(user.HomeDir, startDir, testFileName))
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(localPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSCPPatternsFilter checks that scp downloads and uploads are rejected once
// the user's file patterns for "/" only allow "*.zip" files.
func TestSCPPatternsFilter(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFileSize := int64(131072)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	user.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/",
			AllowedPatterns: []string{"*.zip"},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.Error(t, err, "scp download must fail")
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.Error(t, err, "scp upload must fail")
	_, err =
httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) _, err = os.Stat(localPath) if err == nil { err = os.Remove(localPath) assert.NoError(t, err) } err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSCPTransferQuotaLimits(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) u.DownloadDataTransfer = 1 u.UploadDataTransfer = 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(550000) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) err = scpUpload(testFilePath, remoteUpPath, false, false) assert.NoError(t, err) err = scpDownload(localDownloadPath, remoteDownPath, false, false) assert.NoError(t, err) // error while download is active err = scpDownload(localDownloadPath, remoteDownPath, false, false) assert.Error(t, err) // error before starting the download err = scpDownload(localDownloadPath, remoteDownPath, false, false) assert.Error(t, err) // error while upload is active err = scpUpload(testFilePath, remoteUpPath, false, false) assert.Error(t, err) // error before starting the upload err = scpUpload(testFilePath, remoteUpPath, false, false) assert.Error(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Greater(t, user.UsedDownloadDataTransfer, int64(1024*1024)) if !assert.Greater(t, user.UsedUploadDataTransfer, int64(1024*1024), user.UsedDownloadDataTransfer) { printLatestLogs(30) } err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) 
assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSCPUploadMaxSize(t *testing.T) { testFileSize := int64(65535) usePubKey := true u := getTestUser(usePubKey) u.Filters.MaxUploadFileSize = testFileSize + 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) testFileSize1 := int64(131072) testFileName1 := "test_file1.dat" testFilePath1 := filepath.Join(homeBasePath, testFileName1) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") err = scpUpload(testFilePath1, remoteUpPath, false, false) assert.Error(t, err) err = scpUpload(testFilePath, remoteUpPath, false, false) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(testFilePath1) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestSCPVirtualFolders(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true u := getTestUser(usePubKey) mappedPath := filepath.Join(os.TempDir(), "vdir") folderName := filepath.Base(mappedPath) vdirPath := "/vdir" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdirPath, }) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) err = os.MkdirAll(mappedPath, os.ModePerm) assert.NoError(t, err) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testBaseDirName := "test_dir" testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName) testBaseDirDownName := "test_dir_down" 
testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
	testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
	testFileSize := int64(131074)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath)
	// the local source is a directory, so the upload is recursive
	err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirDownPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestSCPNestedFolders uploads and downloads, via SCP, a tree spanning the
// user's home and two virtual folders nested under /vdir:
//   - /vdir/sftp, backed by an SFTP filesystem that logs into this same server
//     as baseUser;
//   - /vdir/crypt, backed by an encrypted filesystem stored in mappedPathCrypt.
//
// It then checks the failure paths: an unreadable file in the SFTP-backed
// folder, and SFTP backend credentials that no longer work.
func TestSCPNestedFolders(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	// password-based user that the SFTP-backed folder authenticates as
	baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	usePubKey := true
	u := getTestUser(usePubKey)
	u.HomeDir += "_folders"
	u.Username += "_folders"
	mappedPathSFTP := filepath.Join(os.TempDir(), "sftp")
	folderNameSFTP := filepath.Base(mappedPathSFTP)
	vdirSFTPPath := "/vdir/sftp"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameSFTP,
		},
		VirtualPath: vdirSFTPPath,
	})
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
	})
	// SFTP filesystem pointing back to this server, logging in as baseUser
	f1 := vfs.BaseVirtualFolder{
		Name: folderNameSFTP,
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: baseUser.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	// encrypted filesystem stored under mappedPathCrypt
	f2 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	baseDirDownPath := filepath.Join(os.TempDir(), "basedir-down")
	err = os.Mkdir(baseDirDownPath, os.ModePerm)
	assert.NoError(t, err)
	// local tree mirroring the virtual layout: basedir/vdir/{sftp,crypt}
	baseDir := filepath.Join(os.TempDir(), "basedir")
	err = os.Mkdir(baseDir, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(filepath.Join(baseDir, vdirSFTPPath), os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(filepath.Join(baseDir, vdirCryptPath), os.ModePerm)
	assert.NoError(t, err)
	// distinct sizes identify, after the upload, which file landed where
	err = createTestFile(filepath.Join(baseDir, vdirSFTPPath, testFileName), 32768)
	assert.NoError(t, err)
	err = createTestFile(filepath.Join(baseDir, vdirCryptPath, testFileName), 65535)
	assert.NoError(t, err)
	err = createTestFile(filepath.Join(baseDir, "vdir", testFileName), 65536)
	assert.NoError(t, err)
	remoteRootPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
	err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false)
	assert.NoError(t, err)
	// verify through SFTP that each file was stored in the expected folder
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		info, err := client.Stat(path.Join(vdirCryptPath, testFileName))
		assert.NoError(t, err)
		assert.Equal(t, int64(65535), info.Size())
		info, err = client.Stat(path.Join(vdirSFTPPath, testFileName))
		assert.NoError(t, err)
		assert.Equal(t, int64(32768), info.Size())
		info, err = client.Stat(path.Join("/vdir", testFileName))
		assert.NoError(t, err)
		assert.Equal(t, int64(65536), info.Size())
	}
	// recursive download of the whole home: the local copy is rooted at a
	// directory named after the user
	err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
	assert.NoError(t, err)
	assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, "vdir", testFileName))
	assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirCryptPath, testFileName))
	assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirSFTPPath, testFileName))
	if runtime.GOOS != osWindows {
		// presumably this is the file uploaded into the SFTP-backed folder
		// (stored in baseUser's home); mode 0001 makes it unreadable, so the
		// recursive download must fail
		err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), 0001)
		assert.NoError(t, err)
		err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
		assert.Error(t, err)
		err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), os.ModePerm)
		assert.NoError(t, err)
	}
	// now change the password for the base user, so SFTP folder will not work
	baseUser.Password = defaultPassword + "_mod"
	_, _, err = httpdtest.UpdateUser(baseUser, http.StatusOK, "1")
	assert.NoError(t, err)
	err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false)
	assert.Error(t, err)
	err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathSFTP)
	assert.NoError(t, err)
	err = os.RemoveAll(baseDir)
	assert.NoError(t, err)
	err = os.RemoveAll(baseDirDownPath)
	assert.NoError(t, err)
}

// TestSCPVirtualFoldersQuota uploads the same tree into two virtual folders
// with different quota settings. It checks where the usage is accounted: vdir1
// has QuotaFiles/QuotaSize -1, vdir2 has QuotaFiles/QuotaSize 0.
func TestSCPVirtualFoldersQuota(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  0,
		QuotaSize:   0,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testBaseDirName := "test_dir"
	testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
	testBaseDirDownName := "test_dir_down"
	testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
	// two files of testFileSize bytes: test_dir/<file> and test_dir/test_dir/<file>
	testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
	testFileSize := int64(131074)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	remoteDownPath1 :=
fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", vdirPath1))
	remoteUpPath1 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath1)
	remoteDownPath2 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", vdirPath2))
	remoteUpPath2 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath2)
	// we upload two times to test overwrite
	err = scpUpload(testBaseDirPath, remoteUpPath1, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath1, true, true)
	assert.NoError(t, err)
	err = scpUpload(testBaseDirPath, remoteUpPath1, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath1, true, true)
	assert.NoError(t, err)
	err = scpUpload(testBaseDirPath, remoteUpPath2, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath2, true, true)
	assert.NoError(t, err)
	// each upload stores two files of testFileSize bytes; the overwrite in vdir1
	// must not count them twice
	expectedQuotaFiles := 2
	expectedQuotaSize := testFileSize * 2
	// vdir1 (quota -1): usage is expected on the user, not on the folder
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
	f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), f.UsedQuotaSize)
	assert.Equal(t, 0, f.UsedQuotaFiles)
	// vdir2 (quota 0): usage is expected on the folder itself
	f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaSize, f.UsedQuotaSize)
	assert.Equal(t, expectedQuotaFiles, f.UsedQuotaFiles)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirDownPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSCPPermsSubDirs checks per-path permissions over SCP. "/somedir" only
// grants list and upload, so a recursive download of that directory must fail.
// The test also checks that a file lacking OS-level read permission cannot be
// downloaded.
func TestSCPPermsSubDirs(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	u.Permissions["/somedir"] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	subPath := filepath.Join(user.GetHomeDir(), "somedir")
	testFileSize := int64(65535)
	err = os.MkdirAll(subPath, os.ModePerm)
	assert.NoError(t, err)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/somedir")
	err = scpDownload(localPath, remoteDownPath, false, true)
	assert.Error(t, err, "download a dir with no permissions must fail")
	// Replace the directory with a regular file of the same name. The test
	// expects this download to succeed; presumably the "/somedir" restrictions
	// apply to the directory's contents, not to a file named "somedir".
	err = os.Remove(subPath)
	assert.NoError(t, err)
	err = createTestFile(subPath, testFileSize)
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.NoError(t, err)
	if runtime.GOOS != osWindows {
		// mode 0001 removes read permission at the OS level
		err = os.Chmod(subPath, 0001)
		assert.NoError(t, err)
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.Error(t, err, "download a file with no system permissions must fail")
		err = os.Chmod(subPath, os.ModePerm)
		assert.NoError(t, err)
	}
	err = os.Remove(localPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPPermCreateDirs verifies that a user with only download and upload
// permissions cannot create directories via SCP. Both an upload into a missing
// directory and a recursive directory upload must fail.
func TestSCPPermCreateDirs(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(32760)
testBaseDirName := "test_dir" testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName) testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = createTestFile(testFilePath1, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp/") err = scpUpload(testFilePath, remoteUpPath, true, false) assert.Error(t, err, "scp upload must fail, the user cannot create files in a missing dir") err = scpUpload(testBaseDirPath, remoteUpPath, true, false) assert.Error(t, err, "scp upload must fail, the user cannot create new dirs") err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(testBaseDirPath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSCPPermUpload(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermCreateDirs} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65536) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp") err = scpUpload(testFilePath, remoteUpPath, true, false) assert.Error(t, err, "scp upload must fail, the user cannot upload") err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSCPPermOverwrite(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = 
[]string{dataprovider.PermUpload, dataprovider.PermCreateDirs} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65536) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp") err = scpUpload(testFilePath, remoteUpPath, true, false) assert.NoError(t, err) err = scpUpload(testFilePath, remoteUpPath, true, false) assert.Error(t, err, "scp upload must fail, the user cannot ovewrite existing files") err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSCPPermDownload(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true u := getTestUser(usePubKey) u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65537) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") err = scpUpload(testFilePath, remoteUpPath, true, false) assert.NoError(t, err) remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) localPath := filepath.Join(homeBasePath, "scp_download.dat") err = scpDownload(localPath, remoteDownPath, false, false) assert.Error(t, err, "scp download must fail, the user cannot download") err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSCPQuotaSize(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to 
execute this test") } usePubKey := true testFileSize := int64(65535) u := getTestUser(usePubKey) u.QuotaFiles = 1 u.QuotaSize = testFileSize + 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) testFileSize1 := int64(131072) testFileName1 := "test_file1.dat" testFilePath1 := filepath.Join(homeBasePath, testFileName1) err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) testFileSize2 := int64(32768) testFileName2 := "test_file2.dat" testFilePath2 := filepath.Join(homeBasePath, testFileName2) err = createTestFile(testFilePath2, testFileSize2) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) err = scpUpload(testFilePath, remoteUpPath, true, false) assert.NoError(t, err) err = scpUpload(testFilePath, remoteUpPath+".quota", true, false) assert.Error(t, err, "user is over quota scp upload must fail") // now test quota limits while uploading the current file, we have 1 bytes remaining user.QuotaSize = testFileSize + 1 user.QuotaFiles = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) err = scpUpload(testFilePath2, remoteUpPath+".quota", true, false) assert.Error(t, err, "user is over quota scp upload must fail") // overwriting an existing file will work if the resulting size is lesser or equal than the current one err = scpUpload(testFilePath1, remoteUpPath, true, false) assert.Error(t, err) err = scpUpload(testFilePath2, remoteUpPath, true, false) assert.NoError(t, err) err = scpUpload(testFilePath, remoteUpPath, true, false) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(testFilePath1) assert.NoError(t, err) err = os.Remove(testFilePath2) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = 
httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSCPEscapeHomeDir(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) assert.NoError(t, err) testDir := "testDir" linkPath := filepath.Join(homeBasePath, defaultUsername, testDir) err = os.Symlink(homeBasePath, linkPath) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDir, testDir)) err = scpUpload(testFilePath, remoteUpPath, false, false) assert.Error(t, err, "uploading to a dir with a symlink outside home dir must fail") remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir, testFileName)) localPath := filepath.Join(homeBasePath, "scp_download.dat") err = scpDownload(localPath, remoteDownPath, false, false) assert.Error(t, err, "scp download must fail, the requested file has a symlink outside user home") remoteDownPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir)) err = scpDownload(homeBasePath, remoteDownPath, false, true) assert.Error(t, err, "scp download must fail, the requested dir is a symlink outside user home") err = os.Remove(testFilePath) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestSCPUploadPaths(t *testing.T) { if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := 
int64(65535)
	testDirName := "testDir"
	testDirPath := filepath.Join(user.GetHomeDir(), testDirName)
	err = os.MkdirAll(testDirPath, os.ModePerm)
	assert.NoError(t, err)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	// relative remote paths: the test expects them to resolve inside the user's home
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testDirName)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testFileName))
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.NoError(t, err)
	// upload a file to a missing dir
	remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testDirName, testFileName))
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.Error(t, err, "scp upload to a missing dir must fail")
	// remove the local source file, as the other SCP tests do
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.Remove(localPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPOverwriteDirWithFile verifies that an SCP upload cannot replace an
// existing remote directory with a file of the same name.
func TestSCPOverwriteDirWithFile(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	// a directory named like the file that is going to be uploaded
	testDirPath := filepath.Join(user.GetHomeDir(), testFileName)
	err = os.MkdirAll(testDirPath, os.ModePerm)
	assert.NoError(t, err)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.Error(t, err, "copying a file over an existing dir must fail")
	// remove the local source file, as the other SCP tests do
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPRemoteToRemote copies a file between two users of this server with
// scp -3, i.e. routing the data through the local host.
func TestSCPRemoteToRemote(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	if runtime.GOOS == osWindows {
		t.Skip("scp between remote hosts is not supported on Windows")
	}
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser(usePubKey)
	u.Username += "1"
	u.HomeDir += "1"
	user1, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
	remote1UpPath := fmt.Sprintf("%v@127.0.0.1:%v", user1.Username, path.Join("/", testFileName))
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	err = scpUpload(remoteUpPath, remote1UpPath, false, true)
	assert.NoError(t, err)
	// remove the local source file, as the other SCP tests do
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPErrors kills throttled SCP downloads and uploads while they are in
// progress. It verifies that the server drops the connections and leaves no
// active transfers behind.
func TestSCPErrors(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	u := getTestUser(true)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFileSize := int64(524288)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	// throttle transfers so they are still running when the client is killed
	user.UploadBandwidth = 512
	user.DownloadBandwidth = 512
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)

	// The process is started on the test goroutine, so cmd.Process is set
	// before it is read below; only Wait runs in the background. Its result is
	// received and asserted here before cmd is reused, so the background
	// goroutine never races with this one and never asserts after the fact.
	waitErr := make(chan error, 1)

	cmd := getScpDownloadCommand(localPath, remoteDownPath, false, false)
	err = cmd.Start()
	assert.NoError(t, err)
	go func() {
		waitErr <- cmd.Wait()
	}()
	waitForActiveTransfers(t)
	// wait a bit longer so that some transfer activity actually happens:
	// this is needed to reach all the code paths in CheckIdleConnections
	time.Sleep(100 * time.Millisecond)
	err = cmd.Process.Kill()
	assert.NoError(t, err)
	assert.Error(t, <-waitErr, "SCP download must fail")
	assert.Eventually(t, func() bool {
		return len(common.Connections.GetStats("")) == 0
	}, 2*time.Second, 100*time.Millisecond)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

	cmd = getScpUploadCommand(testFilePath, remoteUpPath, false, false)
	err = cmd.Start()
	assert.NoError(t, err)
	go func() {
		waitErr <- cmd.Wait()
	}()
	waitForActiveTransfers(t)
	// wait a bit longer so that some transfer activity actually happens:
	// this is needed to reach all the code paths in CheckIdleConnections
	time.Sleep(100 * time.Millisecond)
	err = cmd.Process.Kill()
	assert.NoError(t, err)
	assert.Error(t, <-waitErr, "SCP upload must fail")
	assert.Eventually(t, func() bool {
		return len(common.Connections.GetStats("")) == 0
	}, 2*time.Second, 100*time.Millisecond)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	// best effort: the killed download may or may not have created this file
	os.Remove(localPath)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// End SCP tests

// waitTCPListening blocks until a TCP connection to address succeeds. It
// retries every 100ms and has no upper bound on the wait.
func waitTCPListening(address string) {
	for {
		conn, err := net.Dial("tcp", address)
		if err != nil {
			logger.WarnToConsole("tcp server %v not listening: %v", address, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		logger.InfoToConsole("tcp server %v now listening", address)
		conn.Close()
		break
	}
}

// getTestGroup returns the default, not yet persisted, test group.
func getTestGroup() dataprovider.Group {
	return dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name:        "test_group",
			Description: "test group description",
		},
	}
}

// getTestUser returns the default, not yet persisted, test user with every
// permission on "/". With usePubKey the user authenticates with testPubKey
// and has an empty password; otherwise it uses defaultPassword.
func getTestUser(usePubKey bool) dataprovider.User {
	user :=
dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:       defaultUsername,
			Password:       defaultPassword,
			HomeDir:        filepath.Join(homeBasePath, defaultUsername),
			Status:         1,
			ExpirationDate: 0,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = allPerms
	if usePubKey {
		user.PublicKeys = []string{testPubKey}
		user.Password = ""
	}
	return user
}

// getTestSFTPUser returns a test user whose filesystem is an SFTP backend. The
// backend connects to this same server (sftpServerAddr) as defaultUsername with
// defaultPassword. With usePubKey it is also configured with testPrivateKey and
// the hostKeyFPs fingerprints.
func getTestSFTPUser(usePubKey bool) dataprovider.User {
	u := getTestUser(usePubKey)
	u.Username = defaultSFTPUsername
	u.FsConfig.Provider = sdk.SFTPFilesystemProvider
	u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
	u.FsConfig.SFTPConfig.Username = defaultUsername
	u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	if usePubKey {
		u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKey)
		u.FsConfig.SFTPConfig.Fingerprints = hostKeyFPs
	}
	return u
}

// runSSHCommand logs into sftpServerAddr as user.Username, runs command in a
// new SSH session and returns its standard output.
//
// With usePubKey the client authenticates with testPrivateKey, otherwise with
// defaultPassword (user.Password is not used). If the command fails, the
// returned error wraps the session error and includes the remote stderr as text.
func runSSHCommand(command string, user dataprovider.User, usePubKey bool) ([]byte, error) {
	config := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         5 * time.Second,
	}
	if usePubKey {
		key, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
		if err != nil {
			return nil, err
		}
		config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)}
	} else {
		config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
	}
	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	sshSession, err := conn.NewSession()
	if err != nil {
		return nil, err
	}
	defer sshSession.Close()
	var stdout, stderr bytes.Buffer
	sshSession.Stdout = &stdout
	sshSession.Stderr = &stderr
	err = sshSession.Run(command)
	if err != nil {
		// Render stderr as text: formatting the raw []byte with %v prints a list
		// of byte values. Wrap err so callers can still inspect the exit status.
		return nil, fmt.Errorf("failed to run command %v: %w, stderr: %s", command, err, stderr.String())
	}
	return stdout.Bytes(), nil
}

// getSignerForUserCert returns a signer that presents the SSH certificate in
// certBytes (authorized_keys format), backed by testPrivateKey. It returns an
// error if certBytes cannot be parsed or does not contain a certificate.
func getSignerForUserCert(certBytes []byte) (ssh.Signer, error) {
	signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
	if err != nil {
		return nil, err
	}
	pubKey, _, _, _, err := ssh.ParseAuthorizedKey(certBytes) //nolint:dogsled
	if err != nil {
		return nil, err
	}
	// ParseAuthorizedKey also accepts plain public keys: check the type
	// instead of letting an unchecked type assertion panic
	cert, ok := pubKey.(*ssh.Certificate)
	if !ok {
		return nil, fmt.Errorf("the provided key is not an SSH certificate: %T", pubKey)
	}
	return ssh.NewCertSigner(cert, signer)
}

// getSftpClientWithAddr opens an SSH connection to addr as user.Username and
// starts an SFTP client on it.
//
// Authentication:
//   - with usePubKey: testPrivateKey;
//   - otherwise user.Password, where the sentinel "empty" sends an empty
//     password and an unset password falls back to defaultPassword.
//
// If the SFTP client cannot be created, the SSH connection is closed but is
// still returned together with the error.
func getSftpClientWithAddr(user dataprovider.User, usePubKey bool, addr string) (*ssh.Client, *sftp.Client, error) {
	var sftpClient *sftp.Client
	config := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         5 * time.Second,
	}
	if usePubKey {
		signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
		if err != nil {
			return nil, nil, err
		}
		config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
	} else {
		if user.Password != "" {
			if user.Password == "empty" {
				config.Auth = []ssh.AuthMethod{ssh.Password("")}
			} else {
				config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)}
			}
		} else {
			config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
		}
	}
	conn, err := ssh.Dial("tcp", addr, config)
	if err != nil {
		return conn, sftpClient, err
	}
	sftpClient, err = sftp.NewClient(conn)
	if err != nil {
		conn.Close()
	}
	return conn, sftpClient, err
}

// getSftpClient is getSftpClientWithAddr against the default sftpServerAddr.
func getSftpClient(user dataprovider.User, usePubKey bool) (*ssh.Client, *sftp.Client, error) {
	return getSftpClientWithAddr(user, usePubKey, sftpServerAddr)
}

// getKeyboardInteractiveSftpClient connects to sftpServerAddr using
// keyboard-interactive authentication, replying to every challenge with
// answers, and starts an SFTP client on the connection.
func getKeyboardInteractiveSftpClient(user dataprovider.User, answers []string) (*ssh.Client, *sftp.Client, error) {
	var sftpClient *sftp.Client
	config := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth: []ssh.AuthMethod{
			ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) {
				return answers, nil
			}),
		},
		Timeout: 5 * time.Second,
	}
	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
	if err != nil {
		return nil, sftpClient, err
	}
	sftpClient, err = sftp.NewClient(conn)
	if err != nil {
		conn.Close()
	}
	return conn, sftpClient, err
}

// getCustomAuthSftpClient connects with the given authMethods to addr, or to
// sftpServerAddr when addr is empty, and starts an SFTP client on the
// connection.
func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMethod, addr string) (*ssh.Client, *sftp.Client, error) {
	var sftpClient *sftp.Client
	config := &ssh.ClientConfig{
		User:            user.Username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth:
authMethods,
		Timeout: 5 * time.Second,
	}
	var err error
	var conn *ssh.Client
	if addr != "" {
		conn, err = ssh.Dial("tcp", addr, config)
	} else {
		conn, err = ssh.Dial("tcp", sftpServerAddr, config)
	}
	if err != nil {
		return conn, sftpClient, err
	}
	sftpClient, err = sftp.NewClient(conn)
	if err != nil {
		conn.Close()
	}
	return conn, sftpClient, err
}

// createTestFile writes size random bytes to path, creating the parent
// directory if it does not exist. The file is written with mode os.ModePerm
// (before umask).
func createTestFile(path string, size int64) error {
	baseDir := filepath.Dir(path)
	if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) {
		err = os.MkdirAll(baseDir, os.ModePerm)
		if err != nil {
			return err
		}
	}
	content := make([]byte, size)
	_, err := rand.Read(content)
	if err != nil {
		return err
	}
	return os.WriteFile(path, content, os.ModePerm)
}

// appendToTestFile appends size random bytes to the existing file at path.
func appendToTestFile(path string, size int64) error {
	content := make([]byte, size)
	_, err := rand.Read(content)
	if err != nil {
		return err
	}
	f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, os.ModePerm)
	if err != nil {
		return err
	}
	written, err := io.Copy(f, bytes.NewReader(content))
	if err != nil {
		f.Close()
		return err
	}
	if written != size {
		f.Close()
		return fmt.Errorf("write error, written: %v/%v", written, size)
	}
	// the file was opened for writing: Close can report a deferred write
	// error, so return it instead of discarding it in a defer
	return f.Close()
}

// checkBasicSFTP is a smoke test for an SFTP session: it gets the working
// directory and lists ".".
func checkBasicSFTP(client *sftp.Client) error {
	_, err := client.Getwd()
	if err != nil {
		return err
	}
	_, err = client.ReadDir(".")
	return err
}

// writeSFTPFile creates name on the server with size random bytes, closes it,
// and then checks the remote size with Stat.
func writeSFTPFile(name string, size int64, client *sftp.Client) error {
	content := make([]byte, size)
	_, err := rand.Read(content)
	if err != nil {
		return err
	}
	f, err := client.Create(name)
	if err != nil {
		return err
	}
	_, err = io.Copy(f, bytes.NewBuffer(content))
	if err != nil {
		f.Close()
		return err
	}
	err = f.Close()
	if err != nil {
		return err
	}
	info, err := client.Stat(name)
	if err != nil {
		return err
	}
	if info.Size() != size {
		return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size())
	}
	return nil
}

// sftpUploadFile uploads localSourcePath to remoteDestPath. When expectedSize
// is > 0 it also checks the remote size once the upload is closed.
// NOTE(review): the error returned by the final destFile.Close() is not checked.
func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
	srcFile, err := os.Open(localSourcePath)
	if err != nil {
		return err
	}
	defer srcFile.Close()
	destFile, err := client.Create(remoteDestPath)
	if err != nil {
		return err
	}
	_, err = io.Copy(destFile, srcFile)
	if err != nil {
		destFile.Close()
		return err
	}
	// we need to close the file to trigger the server side close method
	// we cannot defer closing otherwise Stat will fail for upload atomic mode
	destFile.Close()
	if expectedSize > 0 {
		fi, err := client.Stat(remoteDestPath)
		if err != nil {
			return err
		}
		if fi.Size() != expectedSize {
			return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
		}
	}
	return err
}

// sftpUploadResumeFile resumes an upload. It appends the part of
// localSourcePath beyond the current size of remoteDestPath to the remote file.
//
// With invalidOffset neither file is seeked: the local file is read from the
// start and the remote handle stays at its initial offset. Presumably this
// exercises the server's handling of a wrong resume offset. When expectedSize
// is > 0 the final remote size is checked.
func sftpUploadResumeFile(localSourcePath string, remoteDestPath string, expectedSize int64, invalidOffset bool, //nolint:unparam
	client *sftp.Client) error {
	srcFile, err := os.Open(localSourcePath)
	if err != nil {
		return err
	}
	defer srcFile.Close()
	fi, err := client.Lstat(remoteDestPath)
	if err != nil {
		return err
	}
	if !invalidOffset {
		_, err = srcFile.Seek(fi.Size(), io.SeekStart)
		if err != nil {
			return err
		}
	}
	destFile, err := client.OpenFile(remoteDestPath, os.O_WRONLY|os.O_APPEND)
	if err != nil {
		return err
	}
	if !invalidOffset {
		_, err = destFile.Seek(fi.Size(), io.SeekStart)
		if err != nil {
			// do not leak the open remote handle
			destFile.Close()
			return err
		}
	}
	_, err = io.Copy(destFile, srcFile)
	if err != nil {
		destFile.Close()
		return err
	}
	// we need to close the file to trigger the server side close method
	// we cannot defer closing otherwise Stat will fail for upload atomic mode
	destFile.Close()
	if expectedSize > 0 {
		fi, err := client.Lstat(remoteDestPath)
		if err != nil {
			return err
		}
		if fi.Size() != expectedSize {
			return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
		}
	}
	return err
}

// sftpDownloadFile downloads remoteSourcePath into localDestPath (created or
// truncated) and syncs it to disk. When expectedSize is > 0 it checks the local
// size.
func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
	downloadDest, err := os.Create(localDestPath)
	if err != nil {
		return err
	}
	defer downloadDest.Close()
	sftpSrcFile, err := client.Open(remoteSourcePath)
	if err != nil {
		return err
	}
	defer sftpSrcFile.Close()
	_, err = io.Copy(downloadDest, sftpSrcFile)
	if err != nil {
		return err
	}
	err = downloadDest.Sync()
	if err != nil {
		return err
	}
	if expectedSize > 0 {
		fi, err := downloadDest.Stat()
		if err != nil {
			return err
		}
		if fi.Size() != expectedSize {
			return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
		}
	}
	return err
}

// sftpUploadNonBlocking runs sftpUploadFile in a goroutine. The returned
// channel (capacity 1) receives its single result.
func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
	c := make(chan error, 1)
	go func() {
		c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client)
	}()
	return c
}

// sftpDownloadNonBlocking runs sftpDownloadFile in a goroutine. The returned
// channel (capacity 1) receives its single result.
func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
	c := make(chan error, 1)
	go func() {
		c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client)
	}()
	return c
}

// scpUpload runs the command built by getScpUploadCommand and returns its
// error (non-nil on non-zero exit).
func scpUpload(localPath, remotePath string, preserveTime, remoteToRemote bool) error {
	cmd := getScpUploadCommand(localPath, remotePath, preserveTime, remoteToRemote)
	return cmd.Run()
}

// scpDownload runs the command built by getScpDownloadCommand and returns its
// error (non-nil on non-zero exit).
func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
	cmd := getScpDownloadCommand(localPath, remotePath, preserveTime, recursive)
	return cmd.Run()
}

// getScpDownloadCommand builds, without running it, an scp command that copies
// remotePath to localPath. It connects to port 2022 with privateKeyPath and
// disables strict host key checking. preserveTime adds -p and recursive adds
// -r. scpForce adds -O, OpenSSH's switch for the legacy SCP protocol.
func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive bool) *exec.Cmd {
	var args []string
	if preserveTime {
		args = append(args, "-p")
	}
	if recursive {
		args = append(args, "-r")
	}
	if scpForce {
		args = append(args, "-O")
	}
	args = append(args, "-P")
	args = append(args, "2022")
	args = append(args, "-o")
	args = append(args, "StrictHostKeyChecking=no")
	args = append(args, "-i")
	args = append(args, privateKeyPath)
	args = append(args, remotePath)
	args = append(args, localPath)
	return exec.Command(scpPath, args...)
} func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRemote bool) *exec.Cmd { var args []string if remoteToRemote { args = append(args, "-3") } if preserveTime { args = append(args, "-p") } fi, err := os.Stat(localPath) if err == nil { if fi.IsDir() { args = append(args, "-r") } } if scpForce { args = append(args, "-O") } args = append(args, "-P") args = append(args, "2022") args = append(args, "-o") args = append(args, "StrictHostKeyChecking=no") args = append(args, "-o") args = append(args, "HostKeyAlgorithms=+ssh-rsa") args = append(args, "-i") args = append(args, privateKeyPath) args = append(args, localPath) args = append(args, remotePath) return exec.Command(scpPath, args...) } func computeHashForFile(hasher hash.Hash, path string) (string, error) { hash := "" f, err := os.Open(path) if err != nil { return hash, err } defer f.Close() _, err = io.Copy(hasher, f) if err == nil { hash = fmt.Sprintf("%x", hasher.Sum(nil)) } return hash, err } func waitForActiveTransfers(t *testing.T) { assert.Eventually(t, func() bool { for _, stat := range common.Connections.GetStats("") { if len(stat.Transfers) > 0 { return true } } return false }, 1*time.Second, 50*time.Millisecond) } func checkSystemCommands() { var err error gitPath, err = exec.LookPath("git") if err != nil { logger.Warn(logSender, "", "unable to get git command. GIT tests will be skipped, err: %v", err) logger.WarnToConsole("unable to get git command. GIT tests will be skipped, err: %v", err) gitPath = "" } sshPath, err = exec.LookPath("ssh") if err != nil { logger.Warn(logSender, "", "unable to get ssh command. GIT tests will be skipped, err: %v", err) logger.WarnToConsole("unable to get ssh command. 
GIT tests will be skipped, err: %v", err) gitPath = "" } hookCmdPath, err = exec.LookPath("true") if err != nil { logger.Warn(logSender, "", "unable to get hook command: %v", err) logger.WarnToConsole("unable to get hook command: %v", err) } scpPath, err = exec.LookPath("scp") if err != nil { logger.Warn(logSender, "", "unable to get scp command. SCP tests will be skipped, err: %v", err) logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err) scpPath = "" } else { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() cmd := exec.CommandContext(ctx, scpPath, "-O") out, _ := cmd.CombinedOutput() scpForce = !strings.Contains(string(out), "option -- O") } } func getKeyboardInteractiveScriptForBuiltinChecks(addPasscode bool, result int) []byte { content := []byte("#!/bin/sh\n\n") echos := []bool{false} q, _ := json.Marshal([]string{"Password: "}) e, _ := json.Marshal(echos) content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v,\"check_password\":1}'\n", string(q), string(e)))...) content = append(content, []byte("read ANSWER\n\n")...) content = append(content, []byte("if test \"$ANSWER\" != \"OK\"; then\n")...) content = append(content, []byte("exit 1\n")...) content = append(content, []byte("fi\n\n")...) if addPasscode { q, _ := json.Marshal([]string{"Passcode: "}) content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v,\"check_password\":2}'\n", string(q), string(e)))...) content = append(content, []byte("read ANSWER\n\n")...) content = append(content, []byte("if test \"$ANSWER\" != \"OK\"; then\n")...) content = append(content, []byte("exit 1\n")...) content = append(content, []byte("fi\n\n")...) } content = append(content, []byte(fmt.Sprintf("echo '{\"auth_result\":%v}'\n", result))...) 
return content } func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJSONResponse bool, result int) []byte { content := []byte("#!/bin/sh\n\n") q, _ := json.Marshal(questions) echos := []bool{} for index := range questions { echos = append(echos, index%2 == 0) } e, _ := json.Marshal(echos) if nonJSONResponse { content = append(content, []byte(fmt.Sprintf("echo 'questions: %v echos: %v\n", string(q), string(e)))...) } else { content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...) } for index := range questions { content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...) } if sleepTime > 0 { content = append(content, []byte(fmt.Sprintf("sleep %v\n", sleepTime))...) } content = append(content, []byte(fmt.Sprintf("echo '{\"auth_result\":%v}'\n", result))...) return content } func getExtAuthScriptContent(user dataprovider.User, nonJSONResponse, emptyResponse bool, username string) []byte { extAuthContent := []byte("#!/bin/sh\n\n") if emptyResponse { return extAuthContent } extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%v\"; then\n", user.Username))...) if username != "" { user.Username = username } u, _ := json.Marshal(user) if nonJSONResponse { extAuthContent = append(extAuthContent, []byte("echo 'text response'\n")...) } else { extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) } extAuthContent = append(extAuthContent, []byte("else\n")...) if nonJSONResponse { extAuthContent = append(extAuthContent, []byte("echo 'text response'\n")...) } else { extAuthContent = append(extAuthContent, []byte("echo '{\"username\":\"\"}'\n")...) } extAuthContent = append(extAuthContent, []byte("fi\n")...) 
return extAuthContent } func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte { content := []byte("#!/bin/sh\n\n") if nonJSONResponse { content = append(content, []byte("echo 'text response'\n")...) return content } if len(user.Username) > 0 { u, _ := json.Marshal(user) content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) } return content } func getExitCodeScriptContent(exitCode int) []byte { content := []byte("#!/bin/sh\n\n") content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) return content } func getCheckPwdScriptsContents(status int, toVerify string) []byte { content := []byte("#!/bin/sh\n\n") content = append(content, []byte(fmt.Sprintf("echo '{\"status\":%v,\"to_verify\":\"%v\"}'\n", status, toVerify))...) if status > 0 { content = append(content, []byte("exit 0")...) } else { content = append(content, []byte("exit 1")...) } return content } func printLatestLogs(maxNumberOfLines int) { var lines []string f, err := os.Open(logFilePath) if err != nil { return } defer f.Close() scanner := bufio.NewScanner(f) for scanner.Scan() { lines = append(lines, scanner.Text()+"\r\n") for len(lines) > maxNumberOfLines { lines = lines[1:] } } if scanner.Err() != nil { logger.WarnToConsole("Unable to print latest logs: %v", scanner.Err()) return } for _, line := range lines { logger.DebugToConsole("%s", line) } } func getHostKeyFingerprint(name string) (string, error) { privateBytes, err := os.ReadFile(name) if err != nil { return "", err } private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { return "", err } return ssh.FingerprintSHA256(private.PublicKey()), nil } func getHostKeysFingerprints(hostKeys []string) { for _, k := range hostKeys { fp, err := getHostKeyFingerprint(filepath.Join(configDir, k)) if err != nil { logger.ErrorToConsole("unable to get fingerprint for host key %q: %v", k, err) os.Exit(1) } hostKeyFPs = append(hostKeyFPs, fp) } } func createInitialFiles(scriptArgs string) { 
pubKeyPath = filepath.Join(homeBasePath, "ssh_key.pub") privateKeyPath = filepath.Join(homeBasePath, "ssh_key") trustedCAUserKey = filepath.Join(homeBasePath, "ca_user_key") gitWrapPath = filepath.Join(homeBasePath, "gitwrap.sh") extAuthPath = filepath.Join(homeBasePath, "extauth.sh") preLoginPath = filepath.Join(homeBasePath, "prelogin.sh") postConnectPath = filepath.Join(homeBasePath, "postconnect.sh") checkPwdPath = filepath.Join(homeBasePath, "checkpwd.sh") preDownloadPath = filepath.Join(homeBasePath, "predownload.sh") preUploadPath = filepath.Join(homeBasePath, "preupload.sh") revokeUserCerts = filepath.Join(homeBasePath, "revoked_certs.json") err := os.WriteFile(pubKeyPath, []byte(testPubKey+"\n"), 0600) if err != nil { logger.WarnToConsole("unable to save public key to file: %v", err) } err = os.WriteFile(privateKeyPath, []byte(testPrivateKey+"\n"), 0600) if err != nil { logger.WarnToConsole("unable to save private key to file: %v", err) } err = os.WriteFile(gitWrapPath, []byte(fmt.Sprintf("%v -i %v -oStrictHostKeyChecking=no %v\n", sshPath, privateKeyPath, scriptArgs)), os.ModePerm) if err != nil { logger.WarnToConsole("unable to save gitwrap shell script: %v", err) } err = os.WriteFile(trustedCAUserKey, []byte(testCAUserKey), 0600) if err != nil { logger.WarnToConsole("unable to save trusted CA user key: %v", err) } err = os.WriteFile(revokeUserCerts, []byte(`[]`), 0644) if err != nil { logger.WarnToConsole("unable to save revoked user certs: %v", err) } } ================================================ FILE: internal/sftpd/ssh_cmd.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package sftpd

import (
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"errors"
	"fmt"
	"hash"
	"io"
	"runtime/debug"
	"slices"
	"strings"
	"time"

	"github.com/google/shlex"
	"golang.org/x/crypto/ssh"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	scpCmdName          = "scp"
	sshCommandLogSender = "SSHCommand"
)

// sshCommand describes a single SSH "exec" request: the command name, its
// parsed arguments, the connection it belongs to and when handling started.
type sshCommand struct {
	command    string
	args       []string
	connection *Connection
	startTime  time.Time
}

// processSSHCommand decodes an SSH "exec" request payload and, when the
// requested command is listed in enabledSSHCommands, starts handling it in a
// new goroutine and returns true. "scp" is only accepted with at least two
// arguments. In every other case (decode or parse failure, command not
// enabled, scp with too few arguments) the connection filesystem is closed
// and false is returned.
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
	var msg sshSubsystemExecMsg
	if err := ssh.Unmarshal(payload, &msg); err == nil {
		name, args, err := parseCommandPayload(msg.Command)
		connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
			name, args, len(args), connection.User.Username, err)
		if err == nil && slices.Contains(enabledSSHCommands, name) {
			connection.command = msg.Command
			if name == scpCmdName && len(args) >= 2 {
				connection.SetProtocol(common.ProtocolSCP)
				scpCommand := scpCommand{
					sshCommand: sshCommand{
						command:    name,
						connection: connection,
						startTime:  time.Now(),
						args:       args},
				}
				go scpCommand.handle() //nolint:errcheck
				return true
			}
			if name != scpCmdName {
				connection.SetProtocol(common.ProtocolSSH)
				sshCommand := sshCommand{
					command:    name,
					connection: connection,
					startTime:  time.Now(),
					args:       args,
				}
				go sshCommand.handle() //nolint:errcheck
				return true
			}
		} else {
			connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %q", name)
		}
	}
	// NOTE(review): this is reached for every rejected request, including a
	// parsed but disabled command, and err here is the CloseFS result, not the
	// unmarshal error, so the message below can be misleading.
	err := connection.CloseFS()
	connection.Log(logger.LevelError, "unable to unmarshal ssh command, close fs, err: %v", err)
	return false
}

// handle executes the command. The connection is registered in
// common.Connections for the duration of the command; if registration fails
// an error response is sent and the connection filesystem is closed. A panic
// is recovered, logged and reported as common.ErrGenericFailure. Hash
// commands, "cd", "pwd", "sftpgo-copy" and "sftpgo-remove" are handled here;
// any other command returns nil without sending an exit status.
func (c *sshCommand) handle() (err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handle ssh command: %q stack trace: %v", r, string(debug.Stack()))
			err = common.ErrGenericFailure
		}
	}()
	if err := common.Connections.Add(c.connection); err != nil {
		defer c.connection.CloseFS() //nolint:errcheck

		logger.Info(logSender, "", "unable to add SSH command connection: %v", err)
		return c.sendErrorResponse(err)
	}
	defer common.Connections.Remove(c.connection.GetID())

	c.connection.UpdateLastActivity()
	if slices.Contains(sshHashCommands, c.command) {
		return c.handleHashCommands()
	} else if c.command == "cd" {
		// "cd" is accepted as a no-op: only a successful exit status is sent
		c.sendExitStatus(nil)
	} else if c.command == "pwd" {
		// hard coded response to the start directory
		c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck
		c.sendExitStatus(nil)
	} else if c.command == "sftpgo-copy" {
		return c.handleSFTPGoCopy()
	} else if c.command == "sftpgo-remove" {
		return c.handleSFTPGoRemove()
	}
	return
}

// handleSFTPGoCopy implements "sftpgo-copy": it requires exactly two
// arguments (source and destination), copies source to destination through
// the connection and writes "OK" followed by a successful exit status.
func (c *sshCommand) handleSFTPGoCopy() error {
	sshSourcePath := c.getSourcePath()
	sshDestPath := c.getDestPath()
	if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 {
		return c.sendErrorResponse(errors.New("usage sftpgo-copy "))
	}
	c.connection.Log(logger.LevelDebug, "requested copy %q -> %q", sshSourcePath, sshDestPath)
	if err := c.connection.Copy(sshSourcePath, sshDestPath); err != nil {
		return c.sendErrorResponse(err)
	}
	c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
	c.sendExitStatus(nil)
	return nil
}

// handleSFTPGoRemove implements "sftpgo-remove": it removes the single path
// argument via connection.RemoveAll and writes "OK" followed by a successful
// exit status.
func (c *sshCommand) handleSFTPGoRemove() error {
	sshDestPath, err := c.getRemovePath()
	if err != nil {
		return c.sendErrorResponse(err)
	}
	if err := c.connection.RemoveAll(sshDestPath); err != nil {
		return c.sendErrorResponse(err)
	}
	c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
	c.sendExitStatus(nil)
	return nil
}

// handleHashCommands computes the digest selected by c.command (md5sum,
// sha1sum, sha256sum, sha384sum; any other hash command uses SHA-512).
// Without arguments it hashes data read from the channel and replies
// "<hex> -"; otherwise it hashes the file named by the last argument, after
// checking the user's file filters and list permission, and replies
// "<hex> <path>".
func (c *sshCommand) handleHashCommands() error {
	var h hash.Hash
	switch c.command {
	case "md5sum":
		h = md5.New()
	case "sha1sum":
		h = sha1.New()
	case "sha256sum":
		h = sha256.New()
	case "sha384sum":
		h = sha512.New384()
	default:
		h = sha512.New()
	}
	var response string
	if len(c.args) == 0 {
		// without args we need to read the string to hash from stdin
		// NOTE(review): only a single Read call is made, so at most 4096 bytes are
		// hashed, and a Read may legitimately return fewer bytes than available.
		buf := make([]byte, 4096)
		n, err := c.connection.channel.Read(buf)
		if err != nil && err != io.EOF {
			return c.sendErrorResponse(err)
		}
		h.Write(buf[:n]) //nolint:errcheck
		response = fmt.Sprintf("%x -\n", h.Sum(nil))
	} else {
		sshPath := c.getDestPath()
		if ok, policy := c.connection.User.IsFileAllowed(sshPath); !ok {
			c.connection.Log(logger.LevelInfo, "hash not allowed for file %q", sshPath)
			return c.sendErrorResponse(c.connection.GetErrorForDeniedFile(policy))
		}
		fs, fsPath, err := c.connection.GetFsAndResolvedPath(sshPath)
		if err != nil {
			return c.sendErrorResponse(err)
		}
		if !c.connection.User.HasPerm(dataprovider.PermListItems, sshPath) {
			return c.sendErrorResponse(c.connection.GetPermissionDeniedError())
		}
		hash, err := c.computeHashForFile(fs, h, fsPath)
		if err != nil {
			return c.sendErrorResponse(c.connection.GetFsError(fs, err))
		}
		response = fmt.Sprintf("%v %v\n", hash, sshPath)
	}
	c.connection.channel.Write([]byte(response)) //nolint:errcheck
	c.sendExitStatus(nil)
	return nil
}

// for the supported commands, the destination path, if any, is the last argument
func (c *sshCommand) getDestPath() string {
	if len(c.args) == 0 {
		return ""
	}
	return c.cleanCommandPath(c.args[len(c.args)-1])
}

// for the supported commands, the source path, if any, is the second-last argument
func (c *sshCommand) getSourcePath() string {
	if len(c.args) < 2 {
		return ""
	}
	return c.cleanCommandPath(c.args[len(c.args)-2])
}

// cleanCommandPath strips leading/trailing single quotes and then double
// quotes from name, resolves it through the user's GetCleanedPath and keeps a
// trailing "/" if name had one.
func (c *sshCommand) cleanCommandPath(name string) string {
	name = strings.Trim(name, "'")
	name = strings.Trim(name, "\"")
	result := c.connection.User.GetCleanedPath(name)
	if strings.HasSuffix(name, "/") && !strings.HasSuffix(result, "/") {
		result += "/"
	}
	return result
}

// getRemovePath returns the path for "sftpgo-remove". Exactly one argument is
// required; a trailing "/" is removed unless the path is a single character.
func (c *sshCommand) getRemovePath() (string, error) {
	sshDestPath := c.getDestPath()
	if sshDestPath == "" || len(c.args) != 1 {
		err := errors.New("usage sftpgo-remove ")
		return "", err
	}
	if len(sshDestPath) > 1 {
		sshDestPath = strings.TrimSuffix(sshDestPath, "/")
	}
	return sshDestPath, nil
}

// sendErrorResponse writes "<command>: <dest path> <err>" to the channel,
// sends a failing exit status and returns err unchanged.
func (c *sshCommand) sendErrorResponse(err error) error {
	errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), err)
	c.connection.channel.Write([]byte(errorString)) //nolint:errcheck
	c.sendExitStatus(err)
	return err
}

// sendExitStatus sends the "exit-status" request (0 on success, 1 if err is
// not nil) and closes the channel. For commands other than scp it also records
// the command metric, triggers the action notification and, on success,
// writes a command log entry. For "sftpgo-copy" the last argument is reported
// as the target path and the second-last as the command path.
func (c *sshCommand) sendExitStatus(err error) {
	status := uint32(0)
	vCmdPath := c.getDestPath()
	cmdPath := ""
	targetPath := ""
	vTargetPath := ""
	if c.command == "sftpgo-copy" {
		vTargetPath = vCmdPath
		vCmdPath = c.getSourcePath()
	}
	if err != nil {
		status = uint32(1)
		c.connection.Log(logger.LevelError, "command failed: %q args: %v user: %s err: %v",
			c.command, c.args, c.connection.User.Username, err)
	}
	exitStatus := sshSubsystemExitStatus{
		Status: status,
	}
	_, errClose := c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
	c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", errClose)
	c.connection.channel.Close()
	// for scp we notify single uploads/downloads
	if c.command != scpCmdName {
		// elapsed time in milliseconds
		elapsed := time.Since(c.startTime).Nanoseconds() / 1000000
		metric.SSHCommandCompleted(err)
		if vCmdPath != "" {
			_, p, errFs := c.connection.GetFsAndResolvedPath(vCmdPath)
			if errFs == nil {
				cmdPath = p
			}
		}
		if vTargetPath != "" {
			_, p, errFs := c.connection.GetFsAndResolvedPath(vTargetPath)
			if errFs == nil {
				targetPath = p
			}
		}
		common.ExecuteActionNotification(c.connection.BaseConnection, common.OperationSSHCmd, cmdPath, vCmdPath, //nolint:errcheck
			targetPath, vTargetPath, c.command, 0, err, elapsed, nil)
		if err == nil {
			logger.CommandLog(sshCommandLogSender, cmdPath, targetPath, c.connection.User.Username, "", c.connection.ID,
				common.ProtocolSSH, -1, -1, "", "", c.connection.command, -1, c.connection.GetLocalAddress(),
				c.connection.GetRemoteAddress(), elapsed)
		}
	}
}

// computeHashForFile opens path on fs, feeds its content into hasher and
// returns the lowercase hex digest. fs.Open returns either a file or a reader;
// whichever is non-nil is used and closed afterwards.
func (c *sshCommand) computeHashForFile(fs vfs.Fs, hasher hash.Hash, path string) (string, error) {
	hash := ""
	f, r, _, err := fs.Open(path, 0)
	if err != nil {
		return hash, err
	}
	var reader io.ReadCloser
	if f != nil {
		reader = f
	} else {
		reader = r
	}
	defer reader.Close()
	_, err = io.Copy(hasher, reader)
	if err == nil {
		hash = fmt.Sprintf("%x", hasher.Sum(nil))
	}
	return hash, err
}

// parseCommandPayload splits command with shell-like quoting rules (shlex) and
// returns the command name and its arguments. An empty command, or one that
// cannot be split, returns an error and an empty argument list.
func parseCommandPayload(command string) (string, []string, error) {
	parts, err := shlex.Split(command)
	if err == nil && len(parts) == 0 {
		err = fmt.Errorf("invalid command: %q", command)
	}
	if err != nil {
		return "", []string{}, err
	}
	if len(parts) < 2 {
		return parts[0], []string{}, nil
	}
	return parts[0], parts[1:], nil
}

================================================
FILE: internal/sftpd/transfer.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd import ( "fmt" "io" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/vfs" ) type writerAtCloser interface { io.WriterAt io.Closer } type readerAtCloser interface { io.ReaderAt io.Closer } type failingReader struct { innerReader readerAtCloser errRead error } func (r *failingReader) ReadAt(_ []byte, _ int64) (n int, err error) { return 0, r.errRead } func (r *failingReader) Close() error { if r.innerReader == nil { return nil } return r.innerReader.Close() } // transfer defines the transfer details. // It implements the io.ReaderAt and io.WriterAt interfaces to handle SFTP downloads and uploads type transfer struct { *common.BaseTransfer writerAt writerAtCloser readerAt readerAtCloser isFinished bool } func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader, errForRead error) *transfer { var writer writerAtCloser var reader readerAtCloser if baseTransfer.File != nil { writer = baseTransfer.File if errForRead == nil { reader = baseTransfer.File } else { reader = &failingReader{ innerReader: baseTransfer.File, errRead: errForRead, } } } else if pipeWriter != nil { writer = pipeWriter } else if pipeReader != nil { if errForRead == nil { reader = pipeReader } else { reader = &failingReader{ innerReader: pipeReader, errRead: errForRead, } } } if baseTransfer.File == nil && errForRead != nil && pipeReader == nil { reader = &failingReader{ innerReader: nil, errRead: errForRead, } } return &transfer{ BaseTransfer: baseTransfer, writerAt: writer, readerAt: reader, isFinished: false, } } // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent. 
// It handles download bandwidth throttling too func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) { t.Connection.UpdateLastActivity() n, err = t.readerAt.ReadAt(p, off) t.BytesSent.Add(int64(n)) if err == nil { err = t.CheckRead() } if err != nil && err != io.EOF { if t.GetType() == common.TransferDownload { t.TransferError(err) } err = t.ConvertError(err) return } t.HandleThrottle() return } // WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received. // It handles upload bandwidth throttling too func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) { t.Connection.UpdateLastActivity() if off < t.MinWriteOffset { err := fmt.Errorf("invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset) t.TransferError(err) return 0, err } n, err = t.writerAt.WriteAt(p, off) t.BytesReceived.Add(int64(n)) if err == nil { err = t.CheckWrite() } if err != nil { t.TransferError(err) err = t.ConvertError(err) return } t.HandleThrottle() return } // Close it is called when the transfer is completed. // It closes the underlying file, logs the transfer info, updates the user quota (for uploads) // and executes any defined action. 
// If there is an error no action will be executed and, in atomic mode, we try to delete // the temporary file func (t *transfer) Close() error { if err := t.setFinished(); err != nil { return err } err := t.closeIO() errBaseClose := t.BaseTransfer.Close() if errBaseClose != nil { err = errBaseClose } return t.Connection.GetFsError(t.Fs, err) } func (t *transfer) closeIO() error { var err error if t.File != nil { err = t.File.Close() } else if t.writerAt != nil { err = t.writerAt.Close() t.Lock() // we set ErrTransfer here so quota is not updated, in this case the uploads are atomic if err != nil && t.ErrTransfer == nil { t.ErrTransfer = err } t.Unlock() } else if t.readerAt != nil { err = t.readerAt.Close() if metadater, ok := t.readerAt.(vfs.Metadater); ok { t.SetMetadata(metadater.Metadata()) } } return err } func (t *transfer) setFinished() error { t.Lock() defer t.Unlock() if t.isFinished { return common.ErrTransferClosed } t.isFinished = true return nil } ================================================ FILE: internal/smtp/oauth2.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package smtp provides supports for sending emails package smtp import ( "context" "errors" "fmt" "slices" "sync" "time" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "golang.org/x/oauth2/microsoft" "github.com/drakkan/sftpgo/v2/internal/logger" ) // Supported OAuth2 providers const ( OAuth2ProviderGoogle = iota OAuth2ProviderMicrosoft ) var supportedOAuth2Providers = []int{OAuth2ProviderGoogle, OAuth2ProviderMicrosoft} // OAuth2Config defines OAuth2 settings type OAuth2Config struct { Provider int `json:"provider" mapstructure:"provider"` // Tenant for Microsoft provider, if empty "common" is used Tenant string `json:"tenant" mapstructure:"tenant"` // ClientID is the application's ID ClientID string `json:"client_id" mapstructure:"client_id"` // ClientSecret is the application's secret ClientSecret string `json:"client_secret" mapstructure:"client_secret"` // Token to use to get/renew access tokens RefreshToken string `json:"refresh_token" mapstructure:"refresh_token"` mu *sync.RWMutex config *oauth2.Config accessToken *oauth2.Token } // Validate validates and initializes the configuration func (c *OAuth2Config) Validate() error { if !slices.Contains(supportedOAuth2Providers, c.Provider) { return fmt.Errorf("smtp oauth2: unsupported provider %d", c.Provider) } if c.ClientID == "" { return errors.New("smtp oauth2: client id is required") } if c.ClientSecret == "" { return errors.New("smtp oauth2: client secret is required") } if c.RefreshToken == "" { return errors.New("smtp oauth2: refresh token is required") } c.initialize() return nil } func (c *OAuth2Config) isEqual(other *OAuth2Config) bool { if c.Provider != other.Provider { return false } if c.Tenant != other.Tenant { return false } if c.ClientID != other.ClientID { return false } if c.ClientSecret != other.ClientSecret { return false } if c.RefreshToken != other.RefreshToken { return false } return true } func (c *OAuth2Config) getAccessToken() (string, error) { c.mu.RLock() if 
c.accessToken.Expiry.After(time.Now().Add(30 * time.Second)) { accessToken := c.accessToken.AccessToken c.mu.RUnlock() return accessToken, nil } logger.Debug(logSender, "", "renew oauth2 token required, current token expires at %s", c.accessToken.Expiry) token := new(oauth2.Token) *token = *c.accessToken c.mu.RUnlock() ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() newToken, err := c.config.TokenSource(ctx, token).Token() if err != nil { logger.Error(logSender, "", "unable to get new token: %v", err) return "", err } accessToken := newToken.AccessToken refreshToken := newToken.RefreshToken if refreshToken != "" && refreshToken != token.RefreshToken { c.mu.Lock() c.RefreshToken = refreshToken c.accessToken = newToken c.mu.Unlock() logger.Debug(logSender, "", "oauth2 refresh token changed") go updateRefreshToken(refreshToken) } if accessToken != token.AccessToken { c.mu.Lock() c.accessToken = newToken c.mu.Unlock() logger.Debug(logSender, "", "new oauth2 token saved, expires at %s", c.accessToken.Expiry) } return accessToken, nil } func (c *OAuth2Config) initialize() { c.mu = new(sync.RWMutex) c.config = c.GetOAuth2() c.accessToken = &oauth2.Token{ TokenType: "Bearer", RefreshToken: c.RefreshToken, } } // GetOAuth2 returns the oauth2 configuration for the provided parameters. 
func (c *OAuth2Config) GetOAuth2() *oauth2.Config { var endpoint oauth2.Endpoint var scopes []string switch c.Provider { case OAuth2ProviderMicrosoft: endpoint = microsoft.AzureADEndpoint(c.Tenant) scopes = []string{"offline_access", "https://outlook.office.com/SMTP.Send"} default: endpoint = google.Endpoint scopes = []string{"https://mail.google.com/"} } return &oauth2.Config{ ClientID: c.ClientID, ClientSecret: c.ClientSecret, Scopes: scopes, Endpoint: endpoint, } } ================================================ FILE: internal/smtp/smtp.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package smtp provides supports for sending emails package smtp import ( "bytes" "context" "errors" "fmt" "html/template" "path/filepath" "sync" "time" "github.com/rs/xid" "github.com/wneessen/go-mail" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( logSender = "smtp" ) // EmailContentType defines the support content types for email body type EmailContentType int // Supported email body content type const ( EmailContentTypeTextPlain EmailContentType = iota EmailContentTypeTextHTML ) const ( templateEmailDir = "email" templatePasswordReset = "reset-password.html" templatePasswordExpiration = "password-expiration.html" dialTimeout = 10 * time.Second ) var ( config = &activeConfig{} initialConfig *Config emailTemplates = make(map[string]*template.Template) ) type activeConfig struct { sync.RWMutex config *Config } func (c *activeConfig) isEnabled() bool { c.RLock() defer c.RUnlock() return c.config != nil && c.config.Host != "" } func (c *activeConfig) Set(cfg *dataprovider.SMTPConfigs) { var config *Config if cfg != nil { config = &Config{ Host: cfg.Host, Port: cfg.Port, From: cfg.From, User: cfg.User, Password: cfg.Password.GetPayload(), AuthType: cfg.AuthType, Encryption: cfg.Encryption, Domain: cfg.Domain, Debug: cfg.Debug, OAuth2: OAuth2Config{ Provider: cfg.OAuth2.Provider, Tenant: cfg.OAuth2.Tenant, ClientID: cfg.OAuth2.ClientID, ClientSecret: cfg.OAuth2.ClientSecret.GetPayload(), RefreshToken: cfg.OAuth2.RefreshToken.GetPayload(), }, } config.OAuth2.initialize() } c.Lock() defer c.Unlock() if config != nil && config.Host != "" { if c.config != nil && c.config.isEqual(config) { return } c.config = config logger.Info(logSender, "", "activated new config, server %s:%d", c.config.Host, c.config.Port) } else { logger.Debug(logSender, "", "activating initial config") 
c.config = initialConfig if c.config == nil || c.config.Host == "" { logger.Debug(logSender, "", "configuration disabled, email capabilities will not be available") } } } func (c *activeConfig) getSMTPClientAndMsg(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File, ) (*mail.Client, *mail.Msg, error) { c.RLock() defer c.RUnlock() if c.config == nil || c.config.Host == "" { return nil, nil, errors.New("smtp: not configured") } return c.config.getSMTPClientAndMsg(to, bcc, subject, body, contentType, attachments...) } func (c *activeConfig) sendEmail(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error { client, msg, err := c.getSMTPClientAndMsg(to, bcc, subject, body, contentType, attachments...) if err != nil { return err } ctx, cancelFn := context.WithTimeout(context.Background(), dialTimeout) defer cancelFn() return client.DialAndSendWithContext(ctx, msg) } // IsEnabled returns true if an SMTP server is configured func IsEnabled() bool { return config.isEnabled() } // Activate sets the specified config as active func Activate(c *dataprovider.SMTPConfigs) { config.Set(c) } // Config defines the SMTP configuration to use to send emails type Config struct { // Location of SMTP email server. Leavy empty to disable email sending capabilities Host string `json:"host" mapstructure:"host"` // Port of SMTP email server Port int `json:"port" mapstructure:"port"` // From address, for example "SFTPGo ". // Many SMTP servers reject emails without a `From` header so, if not set, // SFTPGo will try to use the username as fallback, this may or may not be appropriate From string `json:"from" mapstructure:"from"` // SMTP username User string `json:"user" mapstructure:"user"` // SMTP password. 
Leaving both username and password empty the SMTP authentication // will be disabled Password string `json:"password" mapstructure:"password"` // 0 Plain // 1 Login // 2 CRAM-MD5 // 3 OAuth2 AuthType int `json:"auth_type" mapstructure:"auth_type"` // 0 no encryption // 1 TLS // 2 start TLS Encryption int `json:"encryption" mapstructure:"encryption"` // Domain to use for HELO command, if empty localhost will be used Domain string `json:"domain" mapstructure:"domain"` // Path to the email templates. This can be an absolute path or a path relative to the config dir. // Templates are searched within a subdirectory named "email" in the specified path TemplatesPath string `json:"templates_path" mapstructure:"templates_path"` // Set to 1 to enable debug logs Debug int `json:"debug" mapstructure:"debug"` // OAuth2 related settings OAuth2 OAuth2Config `json:"oauth2" mapstructure:"oauth2"` } func (c *Config) isEqual(other *Config) bool { if c.Host != other.Host { return false } if c.Port != other.Port { return false } if c.From != other.From { return false } if c.User != other.User { return false } if c.Password != other.Password { return false } if c.AuthType != other.AuthType { return false } if c.Encryption != other.Encryption { return false } if c.Domain != other.Domain { return false } if c.Debug != other.Debug { return false } return c.OAuth2.isEqual(&other.OAuth2) } func (c *Config) validate() error { if c.Port <= 0 || c.Port > 65535 { return fmt.Errorf("smtp: invalid port %d", c.Port) } if c.AuthType < 0 || c.AuthType > 3 { return fmt.Errorf("smtp: invalid auth type %d", c.AuthType) } if c.Encryption < 0 || c.Encryption > 2 { return fmt.Errorf("smtp: invalid encryption %d", c.Encryption) } if c.From == "" && c.User == "" { return errors.New(`smtp: from address and user cannot both be empty`) } if c.AuthType == 3 { return c.OAuth2.Validate() } return nil } func (c *Config) loadTemplates(configDir string) error { if c.TemplatesPath == "" { logger.Debug(logSender, "", 
"templates path empty, using default") c.TemplatesPath = "templates" } templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir) if templatesPath == "" { return fmt.Errorf("smtp: invalid templates path %q", templatesPath) } loadTemplates(filepath.Join(templatesPath, templateEmailDir)) return nil } // Initialize initializes and validates the SMTP configuration func (c *Config) Initialize(configDir string, isService bool) error { if !isService && c.Host == "" { if err := loadConfigFromProvider(); err != nil { return err } if !config.isEnabled() { return nil } // If not running as a service, templates will only be loaded if required. return c.loadTemplates(configDir) } // In service mode SMTP can be enabled from the WebAdmin at runtime so we // always load templates. if err := c.loadTemplates(configDir); err != nil { return err } if c.Host == "" { return loadConfigFromProvider() } if err := c.validate(); err != nil { return err } initialConfig = c config.Set(nil) logger.Debug(logSender, "", "configuration successfully initialized, host: %q, port: %d, username: %q, auth: %d, encryption: %d, helo: %q", c.Host, c.Port, c.User, c.AuthType, c.Encryption, c.Domain) return loadConfigFromProvider() } func (c *Config) getMailClientOptions() []mail.Option { options := []mail.Option{mail.WithPort(c.Port), mail.WithoutNoop()} switch c.Encryption { case 1: options = append(options, mail.WithSSL()) case 2: options = append(options, mail.WithTLSPolicy(mail.TLSMandatory)) default: options = append(options, mail.WithTLSPolicy(mail.NoTLS)) } if c.User != "" { options = append(options, mail.WithUsername(c.User)) } if c.Password != "" { options = append(options, mail.WithPassword(c.Password)) } if c.User != "" || c.Password != "" { switch c.AuthType { case 1: options = append(options, mail.WithSMTPAuth(mail.SMTPAuthLogin)) case 2: options = append(options, mail.WithSMTPAuth(mail.SMTPAuthCramMD5)) case 3: options = append(options, mail.WithSMTPAuth(mail.SMTPAuthXOAUTH2)) 
default: options = append(options, mail.WithSMTPAuth(mail.SMTPAuthPlain)) } } if c.Domain != "" { options = append(options, mail.WithHELO(c.Domain)) } if c.Debug > 0 { options = append(options, mail.WithLogger(&logger.MailAdapter{ ConnectionID: xid.New().String(), }), mail.WithDebugLog()) } return options } func (c *Config) getSMTPClientAndMsg(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) (*mail.Client, *mail.Msg, error) { msg := mail.NewMsg() msg.SetUserAgent(version.GetServerVersion(" ", false)) var from string if c.From != "" { from = c.From } else { from = c.User } if err := msg.From(from); err != nil { return nil, nil, fmt.Errorf("invalid from address: %w", err) } if err := msg.To(to...); err != nil { return nil, nil, err } if len(bcc) > 0 { if err := msg.Bcc(bcc...); err != nil { return nil, nil, err } } msg.Subject(subject) msg.SetDate() msg.SetMessageID() msg.SetAttachments(attachments) switch contentType { case EmailContentTypeTextPlain: msg.SetBodyString(mail.TypeTextPlain, body) case EmailContentTypeTextHTML: msg.SetBodyString(mail.TypeTextHTML, body) default: return nil, nil, fmt.Errorf("smtp: unsupported body content type %v", contentType) } client, err := mail.NewClient(c.Host, c.getMailClientOptions()...) if err != nil { return nil, nil, fmt.Errorf("unable to create mail client: %w", err) } if c.AuthType == 3 { token, err := c.OAuth2.getAccessToken() if err != nil { return nil, nil, fmt.Errorf("unable to get oauth2 access token: %w", err) } client.SetPassword(token) } return client, msg, nil } // SendEmail tries to send an email using the specified parameters func (c *Config) SendEmail(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error { client, msg, err := c.getSMTPClientAndMsg(to, bcc, subject, body, contentType, attachments...) 
if err != nil { return err } ctx, cancelFn := context.WithTimeout(context.Background(), dialTimeout) defer cancelFn() return client.DialAndSendWithContext(ctx, msg) } func loadTemplates(templatesPath string) { logger.Debug(logSender, "", "loading templates from %q", templatesPath) passwordResetPath := filepath.Join(templatesPath, templatePasswordReset) pwdResetTmpl := util.LoadTemplate(nil, passwordResetPath) passwordExpirationPath := filepath.Join(templatesPath, templatePasswordExpiration) pwdExpirationTmpl := util.LoadTemplate(nil, passwordExpirationPath) emailTemplates[templatePasswordReset] = pwdResetTmpl emailTemplates[templatePasswordExpiration] = pwdExpirationTmpl } // RenderPasswordResetTemplate executes the password reset template func RenderPasswordResetTemplate(buf *bytes.Buffer, data any) error { if !IsEnabled() { return errors.New("smtp: not configured") } return emailTemplates[templatePasswordReset].Execute(buf, data) } // RenderPasswordExpirationTemplate executes the password expiration template func RenderPasswordExpirationTemplate(buf *bytes.Buffer, data any) error { if !IsEnabled() { return errors.New("smtp: not configured") } return emailTemplates[templatePasswordExpiration].Execute(buf, data) } // SendEmail tries to send an email using the specified parameters. func SendEmail(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error { return config.sendEmail(to, bcc, subject, body, contentType, attachments...) 
} func loadConfigFromProvider() error { configs, err := dataprovider.GetConfigs() if err != nil { logger.Error(logSender, "", "unable to load config from provider: %v", err) return fmt.Errorf("smtp: unable to load config from provider: %w", err) } configs.SetNilsToEmpty() if err := configs.SMTP.TryDecrypt(); err != nil { logger.Error(logSender, "", "unable to decrypt smtp config: %v", err) return fmt.Errorf("smtp: unable to decrypt smtp config: %w", err) } config.Set(configs.SMTP) return nil } func updateRefreshToken(token string) { configs, err := dataprovider.GetConfigs() if err != nil { logger.Error(logSender, "", "unable to load config from provider, updating refresh token not possible: %v", err) return } configs.SetNilsToEmpty() if configs.SMTP.IsEmpty() { logger.Warn(logSender, "", "unable to update refresh token, smtp not configured in the data provider") return } configs.SMTP.OAuth2.RefreshToken = kms.NewPlainSecret(token) if err := dataprovider.UpdateConfigs(&configs, dataprovider.ActionExecutorSystem, "", ""); err != nil { logger.Error(logSender, "", "unable to save new refresh token: %v", err) return } logger.Info(logSender, "", "refresh token updated") } ================================================ FILE: internal/telemetry/router.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package telemetry import ( "net/http" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" ) func initializeRouter(enableProfiler bool) { router = chi.NewRouter() router.Use(middleware.GetHead) router.Use(logger.NewStructuredLogger(logger.GetLogger())) router.Use(middleware.Recoverer) router.Group(func(r chi.Router) { r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) { render.PlainText(w, r, "ok") }) }) router.Group(func(router chi.Router) { router.Use(checkAuth) metric.AddMetricsEndpoint(metricsPath, router) if enableProfiler { logger.InfoToConsole("enabling the built-in profiler") logger.Info(logSender, "", "enabling the built-in profiler") router.Mount(pprofBasePath, middleware.Profiler()) } }) } func checkAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !validateCredentials(r) { w.Header().Set(common.HTTPAuthenticationHeader, "Basic realm=\"SFTPGo telemetry\"") http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } next.ServeHTTP(w, r) }) } func validateCredentials(r *http.Request) bool { if !httpAuth.IsEnabled() { return true } username, password, ok := r.BasicAuth() if !ok { return false } return httpAuth.ValidateCredentials(username, password) } ================================================ FILE: internal/telemetry/telemetry.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package telemetry provides telemetry information for SFTPGo, such as: // - health information (for health checks) // - metrics // - profiling information package telemetry import ( "crypto/tls" "log" "net/http" "path/filepath" "runtime" "time" "github.com/go-chi/chi/v5" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) const ( logSender = "telemetry" metricsPath = "/metrics" pprofBasePath = "/debug" ) var ( router *chi.Mux httpAuth common.HTTPAuthProvider certMgr *common.CertManager ) // Conf telemetry server configuration. type Conf struct { // The port used for serving HTTP requests. 0 disable the HTTP server. Default: 0 BindPort int `json:"bind_port" mapstructure:"bind_port"` // The address to listen on. A blank value means listen on all available network interfaces. Default: "127.0.0.1" BindAddress string `json:"bind_address" mapstructure:"bind_address"` // Enable the built-in profiler. // The profiler will be accessible via HTTP/HTTPS using the base URL "/debug/pprof/" EnableProfiler bool `json:"enable_profiler" mapstructure:"enable_profiler"` // Path to a file used to store usernames and password for basic authentication. // This can be an absolute path or a path relative to the config dir. // We support HTTP basic authentication and the file format must conform to the one generated using the Apache // htpasswd tool. The supported password formats are bcrypt ($2y$ prefix) and md5 crypt ($apr1$ prefix). 
// If empty HTTP authentication is disabled AuthUserFile string `json:"auth_user_file" mapstructure:"auth_user_file"` // If files containing a certificate and matching private key for the server are provided the server will expect // HTTPS connections. // Certificate and key files can be reloaded on demand sending a "SIGHUP" signal on Unix based systems and a // "paramchange" request to the running service on Windows. CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. // If CipherSuites is nil/empty, a default list of secure cipher suites // is used, with a preference order based on hardware performance. // Note that TLS 1.3 ciphersuites are not configurable. // The supported ciphersuites names are defined here: // // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 // // any invalid name will be silently ignored. // The order matters, the ciphers listed first will be the preferred ones. TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` // Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2 MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"` // HTTP protocols to enable in preference order. Supported values: http/1.1, h2 Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"` } // ShouldBind returns true if there service must be started func (c Conf) ShouldBind() bool { if c.BindPort > 0 { return true } if filepath.IsAbs(c.BindAddress) && runtime.GOOS != "windows" { return true } return false } // Initialize configures and starts the telemetry server. 
func (c Conf) Initialize(configDir string) error { var err error logger.Info(logSender, "", "initializing telemetry server with config %+v", c) authUserFile := getConfigPath(c.AuthUserFile, configDir) httpAuth, err = common.NewBasicAuthProvider(authUserFile) if err != nil { return err } certificateFile := getConfigPath(c.CertificateFile, configDir) certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir) initializeRouter(c.EnableProfiler) httpServer := &http.Server{ Handler: router, ReadHeaderTimeout: 30 * time.Second, ReadTimeout: 60 * time.Second, WriteTimeout: 60 * time.Second, IdleTimeout: 60 * time.Second, MaxHeaderBytes: 1 << 14, // 16KB ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), } if certificateFile != "" && certificateKeyFile != "" { keyPairs := []common.TLSKeyPair{ { Cert: certificateFile, Key: certificateKeyFile, ID: common.DefaultTLSKeyPaidID, }, } certMgr, err = common.NewCertManager(keyPairs, configDir, logSender) if err != nil { return err } config := &tls.Config{ GetCertificate: certMgr.GetCertificateFunc(common.DefaultTLSKeyPaidID), MinVersion: util.GetTLSVersion(c.MinTLSVersion), NextProtos: util.GetALPNProtocols(c.Protocols), CipherSuites: util.GetTLSCiphersFromNames(c.TLSCipherSuites), } logger.Debug(logSender, "", "configured TLS cipher suites: %v", config.CipherSuites) httpServer.TLSConfig = config return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true, nil, logSender) } return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false, nil, logSender) } // ReloadCertificateMgr reloads the certificate manager func ReloadCertificateMgr() error { if certMgr != nil { return certMgr.Reload() } return nil } func getConfigPath(name, configDir string) string { if !util.IsFileInputValid(name) { return "" } if name != "" && !filepath.IsAbs(name) { return filepath.Join(configDir, name) } return name } ================================================ FILE: 
internal/telemetry/telemetry_test.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package telemetry

import (
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
)

// Test-only certificate and matching EC private key, written to temporary
// files by TestInitialization to exercise the TLS branch of Initialize.
const (
	httpsCert = `-----BEGIN CERTIFICATE-----
MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw
RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu
dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw
OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD
VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA
NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM
3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME
GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY
/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI
dV4vKmHUzwK/eIx+8Ay3neE=
-----END CERTIFICATE-----`
	httpsKey = `-----BEGIN EC PARAMETERS-----
BgUrgQQAIg==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3
UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq
WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV
CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI=
-----END EC PRIVATE KEY-----`
)

// TestInitialization exercises the error paths of Conf.Initialize: every
// call is expected to fail because BindAddress is never a valid address.
func TestInitialization(t *testing.T) {
	configDir := filepath.Join(".", "..", "..")
	providerConf := dataprovider.Config{
		Driver:      dataprovider.MemoryDataProviderName,
		BackupsPath: "backups",
	}
	err := dataprovider.Initialize(providerConf, configDir, false)
	require.NoError(t, err)
	commonConfig := common.Configuration{}
	err = common.Initialize(commonConfig, 0)
	require.NoError(t, err)
	c := Conf{
		BindPort:       10000,
		BindAddress:    "invalid address",
		EnableProfiler: false,
	}
	err = c.Initialize(configDir)
	require.Error(t, err)

	// The auth user file does not exist, so creating the auth provider
	// presumably fails before the server is started.
	c.AuthUserFile = "missing"
	err = c.Initialize(".")
	require.Error(t, err)

	err = ReloadCertificateMgr()
	require.NoError(t, err)

	// Certificate/key paths that do not exist (relative to ".").
	c.AuthUserFile = ""
	c.CertificateFile = "crt"
	c.CertificateKeyFile = "key"

	err = c.Initialize(".")
	require.Error(t, err)

	certPath := filepath.Join(os.TempDir(), "test.crt")
	keyPath := filepath.Join(os.TempDir(), "test.key")
	err = os.WriteFile(certPath, []byte(httpsCert), os.ModePerm)
	require.NoError(t, err)
	err = os.WriteFile(keyPath, []byte(httpsKey), os.ModePerm)
	require.NoError(t, err)

	// Valid TLS material: the error now comes from the invalid bind address.
	c.CertificateFile = certPath
	c.CertificateKeyFile = keyPath

	err = c.Initialize(".")
	require.Error(t, err)

	err = ReloadCertificateMgr()
	require.NoError(t, err)

	err = os.Remove(certPath)
	require.NoError(t, err)
	err = os.Remove(keyPath)
	require.NoError(t, err)
}

// TestShouldBind checks that a positive port, or (outside Windows) an
// absolute bind path, enables the server.
func TestShouldBind(t *testing.T) {
	c := Conf{
		BindPort:       10000,
		EnableProfiler: false,
	}
	require.True(t, c.ShouldBind())

	c.BindPort = 0
	require.False(t, c.ShouldBind())

	if runtime.GOOS != "windows" {
		c.BindAddress = "/absolute/path"
		require.True(t, c.ShouldBind())
	}
}

// TestRouter checks that /healthz is public while /metrics and the profiler
// require basic auth, and that they become public once auth is disabled.
// It replaces the package-level httpAuth and router.
func TestRouter(t *testing.T) {
	authUserFile := filepath.Join(os.TempDir(), "http_users.txt")
	// htpasswd entry for user "test1" (bcrypt hash); the requests below
	// authenticate with password "password1".
	authUserData := []byte("test1:$2y$05$bcHSED7aO1cfLto6ZdDBOOKzlwftslVhtpIkRhAtSa4GuLmk5mola\n")
	err := os.WriteFile(authUserFile, authUserData, os.ModePerm)
	require.NoError(t, err)

	httpAuth, err = common.NewBasicAuthProvider(authUserFile)
	require.NoError(t, err)

	initializeRouter(true)
	testServer := httptest.NewServer(router)
	defer testServer.Close()

	req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)
	require.Equal(t, "ok", rr.Body.String())

	req, err = http.NewRequest(http.MethodGet, "/metrics", nil)
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusUnauthorized, rr.Code)

	req.SetBasicAuth("test1", "password1")
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	req, err = http.NewRequest(http.MethodGet, pprofBasePath+"/pprof/", nil)
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusUnauthorized, rr.Code)

	req.SetBasicAuth("test1", "password1")
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	// An empty user file disables authentication: /metrics is now public.
	httpAuth, err = common.NewBasicAuthProvider("")
	require.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, "/metrics", nil)
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	err = os.Remove(authUserFile)
	require.NoError(t, err)
}

================================================
FILE: internal/util/errors.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package util import ( "errors" "fmt" ) const ( templateLoadErrorHints = "Try setting the absolute templates path in your configuration file " + "or specifying the config directory adding the `-c` flag to the serve options. For example: " + "sftpgo serve -c \"\"" ) // MaxRecursion defines the maximum number of allowed recursions const MaxRecursion = 1000 // errors definitions var ( ErrValidation = NewValidationError("") ErrNotFound = NewRecordNotFoundError("") ErrMethodDisabled = NewMethodDisabledError("") ErrGeneric = NewGenericError("") ErrRecursionTooDeep = errors.New("recursion too deep") ) // ValidationError raised if input data is not valid type ValidationError struct { err string } // Validation error details func (e *ValidationError) Error() string { return fmt.Sprintf("Validation error: %s", e.err) } // GetErrorString returns the unmodified error string func (e *ValidationError) GetErrorString() string { return e.err } // Is reports if target matches func (e *ValidationError) Is(target error) bool { _, ok := target.(*ValidationError) return ok } // NewValidationError returns a validation errors func NewValidationError(errorString string) *ValidationError { return &ValidationError{ err: errorString, } } // RecordNotFoundError raised if a requested object is not found type RecordNotFoundError struct { err string } func (e *RecordNotFoundError) Error() string { return fmt.Sprintf("not found: %s", e.err) } // Is reports if target matches func (e *RecordNotFoundError) Is(target error) bool { _, ok := target.(*RecordNotFoundError) return ok } // NewRecordNotFoundError returns a not found error 
func NewRecordNotFoundError(errorString string) *RecordNotFoundError { return &RecordNotFoundError{ err: errorString, } } // MethodDisabledError raised if a method is disabled in config file. // For example, if user management is disabled, this error is raised // every time a user operation is done using the REST API type MethodDisabledError struct { err string } // Method disabled error details func (e *MethodDisabledError) Error() string { return fmt.Sprintf("Method disabled error: %s", e.err) } // Is reports if target matches func (e *MethodDisabledError) Is(target error) bool { _, ok := target.(*MethodDisabledError) return ok } // NewMethodDisabledError returns a method disabled error func NewMethodDisabledError(errorString string) *MethodDisabledError { return &MethodDisabledError{ err: errorString, } } // GenericError raised for not well categorized error type GenericError struct { err string } func (e *GenericError) Error() string { return e.err } // Is reports if target matches func (e *GenericError) Is(target error) bool { _, ok := target.(*GenericError) return ok } // NewGenericError returns a generic error func NewGenericError(errorString string) *GenericError { return &GenericError{ err: errorString, } } ================================================ FILE: internal/util/i18n.go ================================================ // Copyright (C) 2023 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package util

import (
	"encoding/json"
	"errors"
)

// Localization message IDs for the Web frontend. Each value is a translation
// key (for example "title.setup") rather than user-facing text.
const (
	I18nSetupTitle = "title.setup"
	I18nLoginTitle = "title.login"
	I18nShareLoginTitle = "title.share_login"
	I18nFilesTitle = "title.files"
	I18nSharesTitle = "title.shares"
	I18nShareAddTitle = "title.add_share"
	I18nShareUpdateTitle = "title.update_share"
	I18nProfileTitle = "title.profile"
	I18nUsersTitle = "title.users"
	I18nGroupsTitle = "title.groups"
	I18nFoldersTitle = "title.folders"
	I18nChangePwdTitle = "title.change_password"
	I18n2FATitle = "title.two_factor_auth"
	I18nEditFileTitle = "title.edit_file"
	I18nViewFileTitle = "title.view_file"
	I18nForgotPwdTitle = "title.recovery_password"
	I18nResetPwdTitle = "title.reset_password"
	I18nSharedFilesTitle = "title.shared_files"
	I18nShareUploadTitle = "title.upload_to_share"
	I18nShareDownloadTitle = "title.download_shared_file"
	I18nShareAccessErrorTitle = "title.share_access_error"
	I18nInvalidAuthReqTitle = "title.invalid_auth_request"
	I18nError403Title = "title.error403"
	I18nError400Title = "title.error400"
	I18nError404Title = "title.error404"
	I18nError416Title = "title.error416"
	I18nError429Title = "title.error429"
	I18nError500Title = "title.error500"
	I18nErrorPDFTitle = "title.errorPDF"
	I18nErrorEditorTitle = "title.error_editor"
	I18nAddUserTitle = "title.add_user"
	I18nUpdateUserTitle = "title.update_user"
	I18nAddAdminTitle = "title.add_admin"
	I18nUpdateAdminTitle = "title.update_admin"
	I18nTemplateUserTitle = "title.template_user"
	I18nMaintenanceTitle = "title.maintenance"
	I18nConfigsTitle = "title.configs"
	I18nOAuth2Title = "title.oauth2_success"
	I18nOAuth2ErrorTitle = "title.oauth2_error"
	I18nSessionsTitle = "title.connections"
	I18nRolesTitle = "title.roles"
	I18nAdminsTitle = "title.admins"
	I18nIPListsTitle = "title.ip_lists"
	I18nAddIPListTitle = "title.add_ip_list"
	I18nUpdateIPListTitle = "title.update_ip_list"
	I18nDefenderTitle = "title.defender"
	I18nEventsTitle = "title.logs"
	I18nActionsTitle = "title.event_actions"
	I18nRulesTitle = "title.event_rules"
	I18nAddActionTitle = "title.add_action"
	I18nUpdateActionTitle = "title.update_action"
	I18nAddRuleTitle = "title.add_rule"
	I18nUpdateRuleTitle = "title.update_rule"
	I18nStatusTitle = "status.desc"
	I18nErrorSetupInstallCode = "setup.install_code_mismatch"
	I18nInvalidAuth = "general.invalid_auth_request"
	I18nError429Message = "general.error429"
	I18nError400Message = "general.error400"
	I18nError403Message = "general.error403"
	I18nError404Message = "general.error404"
	I18nError416Message = "general.error416"
	I18nError500Message = "general.error500"
	I18nErrorPDFMessage = "general.errorPDF"
	I18nErrorInvalidToken = "general.invalid_token"
	I18nErrorInvalidForm = "general.invalid_form"
	I18nErrorInvalidCredentials = "general.invalid_credentials"
	I18nErrorInvalidCSRF = "general.invalid_csrf"
	I18nErrorFsGeneric = "fs.err_generic"
	I18nErrorDirListGeneric = "fs.dir_list.err_generic"
	I18nErrorDirList403 = "fs.dir_list.err_403"
	I18nErrorDirList429 = "fs.dir_list.err_429"
	I18nErrorDirListUser = "fs.dir_list.err_user"
	I18nErrorFsValidation = "fs.err_validation"
	I18nErrorChangePwdRequiredFields = "change_pwd.required_fields"
	I18nErrorChangePwdNoMatch = "change_pwd.no_match"
	I18nErrorChangePwdGeneric = "change_pwd.generic"
	I18nErrorChangePwdNoDifferent = "change_pwd.no_different"
	I18nErrorChangePwdCurrentNoMatch = "change_pwd.current_no_match"
	I18nErrorChangePwdRequired = "change_pwd.required"
	I18nErrorUsernameRequired = "general.username_required"
	I18nErrorPasswordRequired = "general.password_required"
	I18nErrorPermissionsRequired = "general.permissions_required"
	I18nErrorGetUser = "general.err_user"
	I18nErrorPwdResetForbidded = "login.reset_pwd_forbidden"
	I18nErrorPwdResetNoEmail = "login.reset_pwd_no_email"
	I18nErrorPwdResetSendEmail = "login.reset_pwd_send_email_err"
	I18nErrorPwdResetGeneric = "login.reset_pwd_err_generic"
	I18nErrorProtocolForbidden = "general.err_protocol_forbidden"
	I18nErrorPwdLoginForbidden = "general.pwd_login_forbidden"
	I18nErrorIPForbidden = "general.ip_forbidden"
	I18nErrorConnectionForbidden = "general.connection_forbidden"
	I18nErrorReservedUsername = "user.username_reserved"
	I18nErrorInvalidEmail = "general.email_invalid"
	I18nErrorInvalidInput = "general.invalid_input"
	I18nErrorInvalidUser = "user.username_invalid"
	I18nErrorInvalidName = "general.name_invalid"
	I18nErrorHomeRequired = "user.home_required"
	I18nErrorHomeInvalid = "user.home_invalid"
	I18nErrorPubKeyInvalid = "user.pub_key_invalid"
	I18nErrorPrivKeyInvalid = "user.priv_key_invalid"
	I18nErrorKeySizeInvalid = "user.key_invalid_size"
	I18nErrorKeyInsecure = "user.key_insecure"
	I18nErrorPrimaryGroup = "user.err_primary_group"
	I18nErrorDuplicateGroup = "user.err_duplicate_group"
	I18nErrorNoPermission = "user.no_permissions"
	I18nErrorNoRootPermission = "user.no_root_permissions"
	I18nErrorGenericPermission = "user.err_permissions_generic"
	I18nError2FAInvalid = "user.2fa_invalid"
	I18nErrorRecoveryCodesInvalid = "user.recovery_codes_invalid"
	I18nErrorFolderNameRequired = "general.foldername_required"
	I18nErrorFolderMountPathRequired = "user.folder_path_required"
	I18nErrorDuplicatedFolders = "user.folder_duplicated"
	I18nErrorOverlappedFolders = "user.folder_overlapped"
	I18nErrorFolderQuotaSizeInvalid = "user.folder_quota_size_invalid"
	I18nErrorFolderQuotaFileInvalid = "user.folder_quota_file_invalid"
	I18nErrorFolderQuotaInvalid = "user.folder_quota_invalid"
	I18nErrorPasswordComplexity = "general.err_password_complexity"
	I18nErrorIPFiltersInvalid = "user.ip_filters_invalid"
	I18nErrorSourceBWLimitInvalid = "user.src_bw_limits_invalid"
	I18nErrorShareExpirationInvalid = "user.share_expiration_invalid"
	I18nErrorFilePatternPathInvalid = "user.file_pattern_path_invalid"
	I18nErrorFilePatternDuplicated = "user.file_pattern_duplicated"
	I18nErrorFilePatternInvalid = "user.file_pattern_invalid"
	I18nErrorDisableActive2FA = "user.disable_active_2fa"
	I18nErrorPwdChangeConflict = "user.pwd_change_conflict"
	I18nError2FAConflict = "user.two_factor_conflict"
	I18nErrorLoginAfterReset = "login.reset_ok_login_error"
	I18nErrorShareScope = "share.scope_invalid"
	I18nErrorShareMaxTokens = "share.max_tokens_invalid"
	I18nErrorShareExpiration = "share.expiration_invalid"
	I18nErrorShareNoPwd = "share.err_no_password"
	I18nErrorShareExpirationOutOfRange = "share.expiration_out_of_range"
	I18nErrorShareGeneric = "share.generic"
	I18nErrorNameRequired = "general.name_required"
	I18nErrorSharePathRequired = "share.path_required"
	I18nErrorShareWriteScope = "share.path_write_scope"
	I18nErrorShareNestedPaths = "share.nested_paths"
	I18nErrorShareExpirationPast = "share.expiration_past"
	I18nErrorInvalidIPMask = "general.allowed_ip_mask_invalid"
	I18nErrorShareUsage = "share.usage_exceed"
	I18nErrorShareExpired = "share.expired"
	I18nErrorLoginFromIPDenied = "login.ip_not_allowed"
	I18nError2FARequired = "login.two_factor_required"
	I18nError2FARequiredGeneric = "login.two_factor_required_generic"
	I18nErrorNoOIDCFeature = "general.no_oidc_feature"
	I18nErrorNoPermissions = "general.no_permissions"
	I18nErrorShareBrowsePaths = "share.browsable_multiple_paths"
	I18nErrorShareBrowseNoDir = "share.browsable_non_dir"
	I18nErrorShareInvalidPath = "share.invalid_path"
	I18nErrorPathInvalid = "general.path_invalid"
	I18nErrorQuotaRead = "general.err_quota_read"
	I18nErrorEditDir = "general.error_edit_dir"
	I18nErrorEditSize = "general.error_edit_size"
	I18nProfileUpdated = "general.profile_updated"
	I18nShareLoginOK = "general.share_ok"
	I18n2FADisabled = "2fa.disabled"
	I18nOIDCTokenExpired = "oidc.token_expired"
	I18nOIDCTokenInvalidAdmin = "oidc.token_invalid_webadmin"
	I18nOIDCTokenInvalidUser = "oidc.token_invalid_webclient"
	I18nOIDCErrTokenExchange = "oidc.token_exchange_err"
	I18nOIDCTokenInvalid = "oidc.token_invalid"
	I18nOIDCTokenInvalidRoleAdmin = "oidc.role_admin_err"
	I18nOIDCTokenInvalidRoleUser = "oidc.role_user_err"
	I18nOIDCErrGetUser = "oidc.get_user_err"
	I18nErrorInvalidQuotaSize = "user.invalid_quota_size"
	I18nErrorTimeOfDayInvalid = "user.time_of_day_invalid"
	I18nErrorTimeOfDayConflict = "user.time_of_day_conflict"
	I18nErrorInvalidMaxFilesize = "filters.max_upload_size_invalid"
	I18nErrorInvalidHomeDir = "storage.home_dir_invalid"
	I18nErrorBucketRequired = "storage.bucket_required"
	I18nErrorRegionRequired = "storage.region_required"
	I18nErrorKeyPrefixInvalid = "storage.key_prefix_invalid"
	I18nErrorULPartSizeInvalid = "storage.ul_part_size_invalid"
	I18nErrorDLPartSizeInvalid = "storage.dl_part_size_invalid"
	I18nErrorULConcurrencyInvalid = "storage.ul_concurrency_invalid"
	I18nErrorDLConcurrencyInvalid = "storage.dl_concurrency_invalid"
	I18nErrorAccessKeyRequired = "storage.access_key_required"
	I18nErrorAccessSecretRequired = "storage.access_secret_required"
	I18nErrorFsCredentialsRequired = "storage.credentials_required"
	I18nErrorContainerRequired = "storage.container_required"
	I18nErrorAccountNameRequired = "storage.account_name_required"
	I18nErrorSASURLInvalid = "storage.sas_url_invalid"
	I18nErrorPassphraseRequired = "storage.passphrase_required"
	I18nErrorEndpointInvalid = "storage.endpoint_invalid"
	I18nErrorEndpointRequired = "storage.endpoint_required"
	I18nErrorFsUsernameRequired = "storage.username_required"
	I18nAddGroupTitle = "title.add_group"
	I18nUpdateGroupTitle = "title.update_group"
	I18nRoleAddTitle = "title.add_role"
	I18nRoleUpdateTitle = "title.update_role"
	I18nErrorInvalidTLSCert = "user.tls_cert_invalid"
	I18nAddFolderTitle = "title.add_folder"
	I18nUpdateFolderTitle = "title.update_folder"
	I18nTemplateFolderTitle = "title.template_folder"
	I18nErrorDuplicatedUsername = "general.duplicated_username"
	I18nErrorDuplicatedName = "general.duplicated_name"
	I18nErrorDuplicatedIPNet = "ip_list.duplicated"
	I18nErrorRoleAdminPerms = "admin.role_permissions"
	I18nBackupOK = "maintenance.backup_ok"
	I18nErrorFolderTemplate = "virtual_folders.template_no_folder"
	I18nErrorUserTemplate = "user.template_no_user"
	I18nConfigsOK = "general.configs_saved"
	I18nOAuth2ErrorVerifyState = "oauth2.auth_verify_error"
	I18nOAuth2ErrorValidateState = "oauth2.auth_validation_error"
	I18nOAuth2InvalidState = "oauth2.auth_invalid"
	I18nOAuth2ErrTokenExchange = "oauth2.token_exchange_err"
	I18nOAuth2ErrNoRefreshToken = "oauth2.no_refresh_token"
	I18nOAuth2OK = "oauth2.success"
	I18nErrorAdminSelfPerms = "admin.self_permissions"
	I18nErrorAdminSelfDisable = "admin.self_disable"
	I18nErrorAdminSelfRole = "admin.self_role"
	I18nErrorIPInvalid = "ip_list.ip_invalid"
	I18nErrorNetInvalid = "ip_list.net_invalid"
	I18nFTPTLSDisabled = "status.tls_disabled"
	I18nFTPTLSExplicit = "status.tls_explicit"
	I18nFTPTLSImplicit = "status.tls_implicit"
	I18nFTPTLSMixed = "status.tls_mixed"
	I18nErrorBackupFile = "maintenance.backup_invalid_file"
	I18nErrorRestore = "maintenance.restore_error"
	I18nErrorACMEGeneric = "acme.generic_error"
	I18nErrorSMTPRequiredFields = "smtp.err_required_fields"
	I18nErrorClientIDRequired = "oauth2.client_id_required"
	I18nErrorClientSecretRequired = "oauth2.client_secret_required"
	I18nErrorRefreshTokenRequired = "oauth2.refresh_token_required"
	I18nErrorURLRequired = "actions.http_url_required"
	I18nErrorURLInvalid = "actions.http_url_invalid"
	I18nErrorHTTPPartNameRequired = "actions.http_part_name_required"
	I18nErrorHTTPPartBodyRequired = "actions.http_part_body_required"
	I18nErrorMultipartBody = "actions.http_multipart_body_error"
	I18nErrorMultipartCType = "actions.http_multipart_ctype_error"
	I18nErrorPathDuplicated = "actions.path_duplicated"
	I18nErrorCommandRequired = "actions.command_required"
	I18nErrorCommandInvalid = "actions.command_invalid"
	I18nErrorEmailRecipientRequired = "actions.email_recipient_required"
	I18nErrorEmailSubjectRequired = "actions.email_subject_required"
	I18nErrorEmailBodyRequired = "actions.email_body_required"
	I18nErrorRetentionDirRequired = "actions.retention_directory_required"
	I18nErrorPathRequired = "actions.path_required"
	I18nErrorSourceDestMatch = "actions.source_dest_different"
	I18nErrorRootNotAllowed = "actions.root_not_allowed"
	I18nErrorArchiveNameRequired = "actions.archive_name_required"
	I18nErrorIDPTemplateRequired = "actions.idp_template_required"
	I18nActionTypeHTTP = "actions.types.http"
	I18nActionTypeEmail = "actions.types.email"
	I18nActionTypeBackup = "actions.types.backup"
	I18nActionTypeUserQuotaReset = "actions.types.user_quota_reset"
	I18nActionTypeFolderQuotaReset = "actions.types.folder_quota_reset"
	I18nActionTypeTransferQuotaReset = "actions.types.transfer_quota_reset"
	I18nActionTypeDataRetentionCheck = "actions.types.data_retention_check"
	I18nActionTypeFilesystem = "actions.types.filesystem"
	I18nActionTypePwdExpirationCheck = "actions.types.password_expiration_check"
	I18nActionTypeUserExpirationCheck = "actions.types.user_expiration_check"
	I18nActionTypeUserInactivityCheck = "actions.types.user_inactivity_check"
	I18nActionTypeIDPCheck = "actions.types.idp_check"
	I18nActionTypeCommand = "actions.types.command"
	I18nActionTypeRotateLogs = "actions.types.rotate_logs"
	I18nActionFsTypeRename = "actions.fs_types.rename"
	I18nActionFsTypeDelete = "actions.fs_types.delete"
	I18nActionFsTypePathExists = "actions.fs_types.path_exists"
	I18nActionFsTypeCompress = "actions.fs_types.compress"
	I18nActionFsTypeCopy = "actions.fs_types.copy"
	I18nActionFsTypeCreateDirs = "actions.fs_types.create_dirs"
	I18nActionThresholdRequired = "actions.inactivity_threshold_required"
	I18nActionThresholdsInvalid = "actions.inactivity_thresholds_invalid"
	I18nTriggerFsEvent = "rules.triggers.fs_event"
	I18nTriggerProviderEvent = "rules.triggers.provider_event"
	I18nTriggerIPBlockedEvent = "rules.triggers.ip_blocked"
	I18nTriggerCertificateRenewEvent = "rules.triggers.certificate_renewal"
	I18nTriggerOnDemandEvent = "rules.triggers.on_demand"
	I18nTriggerIDPLoginEvent = "rules.triggers.idp_login"
	I18nTriggerScheduleEvent = "rules.triggers.schedule"
	I18nErrorInvalidMinSize = "rules.invalid_fs_min_size"
	I18nErrorInvalidMaxSize = "rules.invalid_fs_max_size"
	I18nErrorRuleActionRequired = "rules.action_required"
	I18nErrorRuleFsEventRequired = "rules.fs_event_required"
	I18nErrorRuleProviderEventRequired = "rules.provider_event_required"
	I18nErrorRuleScheduleRequired = "rules.schedule_required"
	I18nErrorRuleScheduleInvalid = "rules.schedule_invalid"
	I18nErrorRuleDuplicateActions = "rules.duplicate_actions"
	I18nErrorEvSyncFailureActions = "rules.sync_failure_actions"
	I18nErrorEvSyncUnsupported = "rules.sync_unsupported"
	I18nErrorEvSyncUnsupportedFs = "rules.sync_unsupported_fs_event"
	I18nErrorRuleFailureActionsOnly = "rules.only_failure_actions"
	I18nErrorRuleSyncActionRequired = "rules.sync_action_required"
	I18nErrorInvalidPNG = "branding.invalid_png"
	I18nErrorInvalidPNGSize = "branding.invalid_png_size"
	I18nErrorInvalidDisclaimerURL = "branding.invalid_disclaimer_url"
)

// NewI18nError returns an *I18nError wrapping err with the given localization
// message ID and applying the provided options.
// If err already is, or wraps, an *I18nError, that existing value is returned
// unchanged; in that case message and options are ignored.
func NewI18nError(err error, message string, options ...I18nErrorOption) *I18nError {
	var errI18n *I18nError
	// keep the innermost localization: never re-wrap an already localized error
	if errors.As(err, &errI18n) {
		return errI18n
	}
	errI18n = &I18nError{
		err:     err,
		Message: message,
		args:    nil,
	}
	for _, opt := range options {
		opt(errI18n)
	}
	return errI18n
}

// I18nErrorOption is a functional option that configures an I18nError.
type I18nErrorOption func(*I18nError)

// I18nErrorArgs is a functional option that sets the I18nError arguments.
// The map is stored by reference, not copied.
func I18nErrorArgs(args map[string]any) I18nErrorOption {
	return func(e *I18nError) {
		e.args = args
	}
}

// I18nError is an error wrapper that adds a message ID to use for localization.
type I18nError struct {
	// err is the wrapped error, returned by Unwrap.
	err error
	// Message is the localization message ID.
	Message string
	// args are optional arguments for the localized message; Args serializes them to JSON.
	args map[string]any
}

// Error returns the wrapped error string.
func (e *I18nError) Error() string { return e.err.Error() } // Unwrap returns the underlying error func (e *I18nError) Unwrap() error { return e.err } // Is reports if target matches func (e *I18nError) Is(target error) bool { if errors.Is(e.err, target) { return true } _, ok := target.(*I18nError) return ok } // HasArgs returns true if the error has i18n args. func (e *I18nError) HasArgs() bool { return len(e.args) > 0 } // Args returns the provided args in JSON format func (e *I18nError) Args() string { if len(e.args) > 0 { data, err := json.Marshal(e.args) if err == nil { return BytesToString(data) } } return "{}" } ================================================ FILE: internal/util/resources.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !bundle package util import ( "html/template" "os" "path/filepath" "runtime" "github.com/drakkan/sftpgo/v2/internal/logger" ) // FindSharedDataPath searches for the specified directory name in searchDir // and in system-wide shared data directories. // If name is an absolute path it is returned unmodified. 
func FindSharedDataPath(name, searchDir string) string { if !IsFileInputValid(name) { return "" } if name != "" && !filepath.IsAbs(name) { searchList := []string{searchDir} if additionalSharedDataSearchPath != "" { searchList = append(searchList, additionalSharedDataSearchPath) } if runtime.GOOS != osWindows { searchList = append(searchList, "/usr/share/sftpgo") searchList = append(searchList, "/usr/local/share/sftpgo") } searchList = RemoveDuplicates(searchList, false) for _, basePath := range searchList { res := filepath.Join(basePath, name) _, err := os.Stat(res) if err == nil { logger.Debug(logSender, "", "found share data path for name %q: %q", name, res) return res } } return filepath.Join(searchDir, name) } return name } // LoadTemplate parses the given template paths. // It behaves like template.Must but it writes a log before exiting. func LoadTemplate(base *template.Template, paths ...string) *template.Template { if base != nil { baseTmpl, err := base.Clone() if err != nil { showTemplateLoadingError(err) } t, err := baseTmpl.ParseFiles(paths...) if err != nil { showTemplateLoadingError(err) } return t } t, err := template.ParseFiles(paths...) if err != nil { showTemplateLoadingError(err) } return t } func showTemplateLoadingError(err error) { logger.ErrorToConsole("error loading required template: %v", err) logger.ErrorToConsole(templateLoadErrorHints) logger.Error(logSender, "", "error loading required template: %v", err) os.Exit(1) } ================================================ FILE: internal/util/resources_embedded.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build bundle package util import ( "html/template" "os" "github.com/drakkan/sftpgo/v2/internal/bundle" "github.com/drakkan/sftpgo/v2/internal/logger" ) // FindSharedDataPath searches for the specified directory name in searchDir // and in system-wide shared data directories. // If name is an absolute path it is returned unmodified. func FindSharedDataPath(name, _ string) string { return name } // LoadTemplate parses the given template paths. // It behaves like template.Must but it writes a log before exiting. // You can optionally provide a base template (e.g. to define some custom functions) func LoadTemplate(base *template.Template, paths ...string) *template.Template { var t *template.Template var err error templateFs := bundle.GetTemplatesFs() if base != nil { base, err = base.Clone() if err == nil { t, err = base.ParseFS(templateFs, paths...) } } else { t, err = template.ParseFS(templateFs, paths...) } if err != nil { logger.ErrorToConsole("error loading required template: %v", err) logger.ErrorToConsole(templateLoadErrorHints) logger.Error(logSender, "", "error loading required template: %v", err) os.Exit(1) } return t } ================================================ FILE: internal/util/util.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

// Package util provides some common utility methods
package util

import (
	"bytes"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/subtle"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"hash"
	"io"
	"io/fs"
	"math"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"time"
	"unicode"
	"unsafe"

	"github.com/google/uuid"
	"github.com/lithammer/shortuuid/v4"
	"golang.org/x/crypto/ssh"

	"github.com/drakkan/sftpgo/v2/internal/logger"
)

const (
	// logSender identifies this package in log entries
	logSender = "util"
	// osWindows is compared against runtime.GOOS
	osWindows = "windows"
	// pubKeySuffix is appended to a private key path to name its public key file
	pubKeySuffix = ".pub"
)

var (
	// emailRegex matches an e-mail address: a local part (plain or
	// double-quoted, \x22) followed by "@" and a dot-separated domain.
	// Besides ASCII, the Unicode ranges U+00A0-U+D7FF, U+F900-U+FDCF and
	// U+FDF0-U+FFEF are accepted in both parts.
	emailRegex = regexp.MustCompile("^(?:(?:(?:(?:[a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+(?:\\.([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+)*)|(?:(?:\\x22)(?:(?:(?:(?:\\x20|\\x09)*(?:\\x0d\\x0a))?(?:\\x20|\\x09)+)?(?:(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x7f]|\\x21|[\\x23-\\x5b]|[\\x5d-\\x7e]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[\\x01-\\x09\\x0b\\x0c\\x0d-\\x7f]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}]))))*(?:(?:(?:\\x20|\\x09)*(?:\\x0d\\x0a))?(\\x20|\\x09)+)?(?:\\x22))))@(?:(?:(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])(?:[a-zA-Z]|\\d|-|\\.|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.)+(?:(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])(?:[a-zA-Z]|\\d|-|\\.|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.?$")
	// this can be set at build time (presumably via -ldflags "-X ..." — TODO confirm);
	// when non-empty, FindSharedDataPath also searches this directory
	additionalSharedDataSearchPath = ""

	// CertsBasePath defines base path for certificates obtained using the built-in ACME protocol.
	// It is empty if ACME support is disabled
	CertsBasePath string

	// Defines the TLS ciphers used by default for TLS 1.0-1.2 if no preference is specified.
	defaultTLSCiphers = []uint16{
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	}
)

// IEC Sizes.
// kibis of bits const ( oneByte = 1 << (iota * 10) kiByte miByte giByte tiByte piByte eiByte ) // SI Sizes. const ( iByte = 1 kbByte = iByte * 1000 mByte = kbByte * 1000 gByte = mByte * 1000 tByte = gByte * 1000 pByte = tByte * 1000 eByte = pByte * 1000 ) var bytesSizeTable = map[string]uint64{ "b": oneByte, "kib": kiByte, "kb": kbByte, "mib": miByte, "mb": mByte, "gib": giByte, "gb": gByte, "tib": tiByte, "tb": tByte, "pib": piByte, "pb": pByte, "eib": eiByte, "eb": eByte, // Without suffix "": oneByte, "ki": kiByte, "k": kbByte, "mi": miByte, "m": mByte, "gi": giByte, "g": gByte, "ti": tiByte, "t": tByte, "pi": piByte, "p": pByte, "ei": eiByte, "e": eByte, } // IsStringPrefixInSlice searches a string prefix in a slice and returns true // if a matching prefix is found func IsStringPrefixInSlice(obj string, list []string) bool { for i := 0; i < len(list); i++ { if strings.HasPrefix(obj, list[i]) { return true } } return false } // RemoveDuplicates returns a new slice removing any duplicate element from the initial one func RemoveDuplicates(obj []string, trim bool) []string { if len(obj) == 0 { return obj } seen := make(map[string]bool) validIdx := 0 for _, item := range obj { if trim { item = strings.TrimSpace(item) } if !seen[item] { seen[item] = true obj[validIdx] = item validIdx++ } } return obj[:validIdx] } // IsNameValid validates that a name/username contains only safe characters. func IsNameValid(name string) bool { if name == "" { return false } if len(name) > 255 { return false } for _, r := range name { if unicode.IsControl(r) { return false } switch r { case '/', '\\': return false case ':', '*', '?', '"', '<', '>', '|': return false } } if name == "." || name == ".." 
{ return false } upperName := strings.ToUpper(name) baseName := strings.Split(upperName, ".")[0] switch baseName { case "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9": return false } if strings.HasSuffix(name, " ") || strings.HasSuffix(name, ".") { return false } return true } // GetTimeAsMsSinceEpoch returns unix timestamp as milliseconds from a time struct func GetTimeAsMsSinceEpoch(t time.Time) int64 { return t.UnixMilli() } // GetTimeFromMsecSinceEpoch return a time struct from a unix timestamp with millisecond precision func GetTimeFromMsecSinceEpoch(msec int64) time.Time { return time.Unix(0, msec*1000000) } // GetDurationAsString returns a string representation for a time.Duration func GetDurationAsString(d time.Duration) string { d = d.Round(time.Second) h := d / time.Hour d -= h * time.Hour m := d / time.Minute d -= m * time.Minute s := d / time.Second if h > 0 { return fmt.Sprintf("%02d:%02d:%02d", h, m, s) } return fmt.Sprintf("%02d:%02d", m, s) } // ByteCountSI returns humanized size in SI (decimal) format func ByteCountSI(b int64) string { return byteCount(b, 1000, true) } // ByteCountIEC returns humanized size in IEC (binary) format func ByteCountIEC(b int64) string { return byteCount(b, 1024, false) } func byteCount(b int64, unit int64, maxPrecision bool) string { if b <= 0 && maxPrecision { return strconv.FormatInt(b, 10) } if b < unit { return fmt.Sprintf("%d B", b) } div, exp := unit, 0 for n := b / unit; n >= unit; n /= unit { div *= unit exp++ } var val string if maxPrecision { val = strconv.FormatFloat(float64(b)/float64(div), 'f', -1, 64) } else { val = fmt.Sprintf("%.1f", float64(b)/float64(div)) } if unit == 1000 { return fmt.Sprintf("%s %cB", val, "KMGTPE"[exp]) } return fmt.Sprintf("%s %ciB", val, "KMGTPE"[exp]) } // ParseBytes parses a string representation of bytes into the number // of bytes it represents. 
// // ParseBytes("42 MB") -> 42000000, nil // ParseBytes("42 mib") -> 44040192, nil // // copied from here: // // https://github.com/dustin/go-humanize/blob/master/bytes.go // // with minor modifications func ParseBytes(s string) (int64, error) { s = strings.TrimSpace(s) lastDigit := 0 hasComma := false for _, r := range s { if !unicode.IsDigit(r) && r != '.' && r != ',' { break } if r == ',' { hasComma = true } lastDigit++ } num := s[:lastDigit] if hasComma { num = strings.ReplaceAll(num, ",", "") } f, err := strconv.ParseFloat(num, 64) if err != nil { return 0, err } extra := strings.ToLower(strings.TrimSpace(s[lastDigit:])) if m, ok := bytesSizeTable[extra]; ok { f *= float64(m) if f >= math.MaxInt64 { return 0, fmt.Errorf("value too large: %v", s) } if f < 0 { return 0, fmt.Errorf("negative value not allowed: %v", s) } return int64(f), nil } return 0, fmt.Errorf("unhandled size name: %v", extra) } // BytesToString converts []byte to string without allocations. // https://github.com/kubernetes/kubernetes/blob/e4b74dd12fa8cb63c174091d5536a10b8ec19d34/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go#L278 // Use only if strictly required, this method uses unsafe. func BytesToString(b []byte) string { // unsafe.SliceData relies on cap whereas we want to rely on len if len(b) == 0 { return "" } // https://github.com/golang/go/blob/4ed358b57efdad9ed710be7f4fc51495a7620ce2/src/strings/builder.go#L41 return unsafe.String(unsafe.SliceData(b), len(b)) } // StringToBytes convert string to []byte without allocations. // https://github.com/kubernetes/kubernetes/blob/e4b74dd12fa8cb63c174091d5536a10b8ec19d34/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go#L289 // Use only if strictly required, this method uses unsafe. 
func StringToBytes(s string) []byte { // unsafe.StringData is unspecified for the empty string, so we provide a strict interpretation if s == "" { return nil } // https://github.com/golang/go/blob/4ed358b57efdad9ed710be7f4fc51495a7620ce2/src/os/file.go#L300 return unsafe.Slice(unsafe.StringData(s), len(s)) } // GetIPFromRemoteAddress returns the IP from the remote address. // If the given remote address cannot be parsed it will be returned unchanged func GetIPFromRemoteAddress(remoteAddress string) string { ip, _, err := net.SplitHostPort(remoteAddress) if err == nil { return ip } return remoteAddress } // GetIPFromNetAddr returns the IP from the network address func GetIPFromNetAddr(upstream net.Addr) (net.IP, error) { if upstream == nil { return nil, errors.New("invalid address") } upstreamString, _, err := net.SplitHostPort(upstream.String()) if err != nil { return nil, err } upstreamIP := net.ParseIP(upstreamString) if upstreamIP == nil { return nil, fmt.Errorf("invalid IP address: %q", upstreamString) } return upstreamIP, nil } // NilIfEmpty returns nil if the input string is empty func NilIfEmpty(s string) *string { if s == "" { return nil } return &s } // GetStringFromPointer returns the string value or empty if nil func GetStringFromPointer(val *string) string { if val == nil { return "" } return *val } // GetIntFromPointer returns the int value or zero func GetIntFromPointer(val *int64) int64 { if val == nil { return 0 } return *val } // GetTimeFromPointer returns the time value or now func GetTimeFromPointer(val *time.Time) time.Time { if val == nil { return time.Unix(0, 0) } return *val } // GenerateRSAKeys generate rsa private and public keys and write the // private key to specified file and the public key to the specified // file adding the .pub suffix func GenerateRSAKeys(file string) error { if err := createDirPathIfMissing(file, 0700); err != nil { return err } key, err := rsa.GenerateKey(rand.Reader, 3072) if err != nil { return err } o, err := 
os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } defer o.Close() priv := &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), } if err := pem.Encode(o, priv); err != nil { return err } pub, err := ssh.NewPublicKey(&key.PublicKey) if err != nil { return err } return os.WriteFile(file+pubKeySuffix, ssh.MarshalAuthorizedKey(pub), 0600) } // GenerateECDSAKeys generate ecdsa private and public keys and write the // private key to specified file and the public key to the specified // file adding the .pub suffix func GenerateECDSAKeys(file string) error { if err := createDirPathIfMissing(file, 0700); err != nil { return err } key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return err } keyBytes, err := x509.MarshalECPrivateKey(key) if err != nil { return err } priv := &pem.Block{ Type: "EC PRIVATE KEY", Bytes: keyBytes, } o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } defer o.Close() if err := pem.Encode(o, priv); err != nil { return err } pub, err := ssh.NewPublicKey(&key.PublicKey) if err != nil { return err } return os.WriteFile(file+pubKeySuffix, ssh.MarshalAuthorizedKey(pub), 0600) } // GenerateEd25519Keys generate ed25519 private and public keys and write the // private key to specified file and the public key to the specified // file adding the .pub suffix func GenerateEd25519Keys(file string) error { pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) if err != nil { return err } keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey) if err != nil { return err } priv := &pem.Block{ Type: "PRIVATE KEY", Bytes: keyBytes, } o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } defer o.Close() if err := pem.Encode(o, priv); err != nil { return err } pub, err := ssh.NewPublicKey(pubKey) if err != nil { return err } return os.WriteFile(file+pubKeySuffix, ssh.MarshalAuthorizedKey(pub), 
0600) } // IsDirOverlapped returns true if dir1 and dir2 overlap func IsDirOverlapped(dir1, dir2 string, fullCheck bool, separator string) bool { if dir1 == dir2 { return true } if fullCheck { if len(dir1) > len(dir2) { if strings.HasPrefix(dir1, dir2+separator) { return true } } if len(dir2) > len(dir1) { if strings.HasPrefix(dir2, dir1+separator) { return true } } } return false } // GetDirsForVirtualPath returns all the directory for the given path in reverse order // for example if the path is: /1/2/3/4 it returns: // [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] func GetDirsForVirtualPath(virtualPath string) []string { if virtualPath == "" || virtualPath == "." { virtualPath = "/" } else { if !path.IsAbs(virtualPath) { virtualPath = CleanPath(virtualPath) } } dirsForPath := []string{virtualPath} for virtualPath != "/" { virtualPath = path.Dir(virtualPath) dirsForPath = append(dirsForPath, virtualPath) } return dirsForPath } // CleanPath returns a clean POSIX (/) absolute path to work with func CleanPath(p string) string { return CleanPathWithBase("/", p) } // CleanPathWithBase returns a clean POSIX (/) absolute path to work with. // The specified base will be used if the provided path is not absolute func CleanPathWithBase(base, p string) string { p = strings.ReplaceAll(p, "\\", "/") if !path.IsAbs(p) { p = path.Join(base, p) } return path.Clean(p) } // IsFileInputValid returns true this is a valid file name. // This method must be used before joining a file name, generally provided as // user input, with a directory func IsFileInputValid(fileInput string) bool { cleanInput := filepath.Clean(fileInput) if cleanInput == "." || cleanInput == ".." { return false } return true } // CleanDirInput sanitizes user input for directories. // On Windows it removes any trailing `"`. // We try to help windows users that set an invalid path such as "C:\ProgramData\SFTPGO\". 
// This will only help if the invalid path is the last argument, for example in this command: // sftpgo.exe serve -c "C:\ProgramData\SFTPGO\" -l "sftpgo.log" // the -l flag will be ignored and the -c flag will get the value `C:\ProgramData\SFTPGO" -l sftpgo.log` // since the backslash after SFTPGO escape the double quote. This is definitely a bad user input func CleanDirInput(dirInput string) string { if runtime.GOOS == osWindows { for strings.HasSuffix(dirInput, "\"") { dirInput = strings.TrimSuffix(dirInput, "\"") } } return filepath.Clean(dirInput) } func createDirPathIfMissing(file string, perm os.FileMode) error { dirPath := filepath.Dir(file) if _, err := os.Stat(dirPath); errors.Is(err, fs.ErrNotExist) { err = os.MkdirAll(dirPath, perm) if err != nil { return err } } return nil } // GenerateRandomBytes generates random bytes with the specified length func GenerateRandomBytes(length int) []byte { b := make([]byte, length) _, err := io.ReadFull(rand.Reader, b) if err != nil { PanicOnError(fmt.Errorf("failed to read random data (see https://go.dev/issue/66821): %w", err)) } return b } // GenerateOpaqueString generates a cryptographically secure opaque string func GenerateOpaqueString() string { randomBytes := sha256.Sum256(GenerateRandomBytes(32)) return hex.EncodeToString(randomBytes[:]) } // GenerateUniqueID returns an unique ID func GenerateUniqueID() string { u, err := uuid.NewRandom() if err != nil { PanicOnError(fmt.Errorf("failed to read random data (see https://go.dev/issue/66821): %w", err)) } return shortuuid.DefaultEncoder.Encode(u) } // HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp // and Unix-domain sockets func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, listenerWrapper func(net.Listener) (net.Listener, error), logSender string, ) error { var listener net.Listener var err error if filepath.IsAbs(address) && runtime.GOOS != osWindows { if !IsFileInputValid(address) { return 
fmt.Errorf("invalid socket address %q", address) } err = createDirPathIfMissing(address, 0770) if err != nil { logger.ErrorToConsole("error creating Unix-domain socket parent dir: %v", err) logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err) } os.Remove(address) listener, err = net.Listen("unix", address) if err == nil { // should a chmod err be fatal? if errChmod := os.Chmod(address, 0770); errChmod != nil { logger.Warn(logSender, "", "unable to set the Unix-domain socket group writable: %v", errChmod) } } } else { CheckTCP4Port(port) listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", address, port)) } if err != nil { return err } if listenerWrapper != nil { listener, err = listenerWrapper(listener) if err != nil { return err } } logger.Info(logSender, "", "server listener registered, address: %s TLS enabled: %t", listener.Addr().String(), isTLS) defer listener.Close() if isTLS { return srv.ServeTLS(listener, "", "") } return srv.Serve(listener) } // GetTLSCiphersFromNames returns the TLS ciphers from the specified names func GetTLSCiphersFromNames(cipherNames []string) []uint16 { var ciphers []uint16 for _, name := range RemoveDuplicates(cipherNames, false) { for _, c := range tls.CipherSuites() { if c.Name == strings.TrimSpace(name) { ciphers = append(ciphers, c.ID) } } for _, c := range tls.InsecureCipherSuites() { if c.Name == strings.TrimSpace(name) { ciphers = append(ciphers, c.ID) } } } if len(ciphers) == 0 { // return a secure default return defaultTLSCiphers } return ciphers } // GetALPNProtocols returns the ALPN protocols, any invalid protocol will be // silently ignored. 
If no protocol or no valid protocol is provided the default // is http/1.1, h2 func GetALPNProtocols(protocols []string) []string { var result []string for _, p := range protocols { switch p { case "http/1.1", "h2": result = append(result, p) } } if len(result) == 0 { return []string{"http/1.1", "h2"} } return result } // EncodeTLSCertToPem returns the specified certificate PEM encoded. // This can be verified using openssl x509 -in cert.crt -text -noout func EncodeTLSCertToPem(tlsCert *x509.Certificate) (string, error) { if len(tlsCert.Raw) == 0 { return "", errors.New("invalid x509 certificate, no der contents") } publicKeyBlock := pem.Block{ Type: "CERTIFICATE", Bytes: tlsCert.Raw, } return BytesToString(pem.EncodeToMemory(&publicKeyBlock)), nil } // CheckTCP4Port quits the app if bind on the given IPv4 port fails. // This is a ugly hack to avoid to bind on an already used port. // It is required on Windows only. Upstream does not consider this // behaviour a bug: // https://github.com/golang/go/issues/45150 func CheckTCP4Port(port int) { if runtime.GOOS != osWindows { return } listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) if err != nil { logger.ErrorToConsole("unable to bind on tcp4 address: %v", err) logger.Error(logSender, "", "unable to bind on tcp4 address: %v", err) os.Exit(1) } listener.Close() } // IsByteArrayEmpty return true if the byte array is empty or a new line func IsByteArrayEmpty(b []byte) bool { if len(b) == 0 { return true } if bytes.Equal(b, []byte("\n")) { return true } if bytes.Equal(b, []byte("\r\n")) { return true } return false } // GetSSHPublicKeyAsString returns an SSH public key serialized as string func GetSSHPublicKeyAsString(pubKey []byte) (string, error) { if len(pubKey) == 0 { return "", nil } k, err := ssh.ParsePublicKey(pubKey) if err != nil { return "", err } return BytesToString(ssh.MarshalAuthorizedKey(k)), nil } // GetRealIP returns the ip address as result of parsing the specified // header and using the 
specified depth func GetRealIP(r *http.Request, header string, depth int) string { if header == "" { return "" } var ipAddresses []string for _, h := range r.Header.Values(header) { for ipStr := range strings.SplitSeq(h, ",") { ipStr = strings.TrimSpace(ipStr) ipAddresses = append(ipAddresses, ipStr) } } idx := len(ipAddresses) - 1 - depth if idx >= 0 { ip := strings.TrimSpace(ipAddresses[idx]) if ip == "" || net.ParseIP(ip) == nil { return "" } return ip } return "" } // GetHTTPLocalAddress returns the local address for an http.Request // or empty if it cannot be determined func GetHTTPLocalAddress(r *http.Request) string { if r == nil { return "" } localAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr) if ok { return localAddr.String() } return "" } // ParseAllowedIPAndRanges returns a list of functions that allow to find if an // IP is equal or is contained within the allowed list func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) { res := make([]func(net.IP) bool, len(allowed)) for i, allowFrom := range allowed { if strings.LastIndex(allowFrom, "/") > 0 { _, ipRange, err := net.ParseCIDR(allowFrom) if err != nil { return nil, fmt.Errorf("given string %q is not a valid IP range: %v", allowFrom, err) } res[i] = ipRange.Contains } else { allowed := net.ParseIP(allowFrom) if allowed == nil { return nil, fmt.Errorf("given string %q is not a valid IP address", allowFrom) } res[i] = allowed.Equal } } return res, nil } // GetRedactedURL returns the url redacting the password if any func GetRedactedURL(rawurl string) string { if !strings.HasPrefix(rawurl, "http") { return rawurl } u, err := url.Parse(rawurl) if err != nil { return rawurl } return u.Redacted() } // GetTLSVersion returns the TLS version from an integer value: // - 10 means TLS 1.0 // - 11 means TLS 1.1 // - 12 means TLS 1.2 // - 13 means TLS 1.3 // default is TLS 1.2 func GetTLSVersion(val int) uint16 { switch val { case 13: return tls.VersionTLS13 case 11: return 
tls.VersionTLS11 case 10: return tls.VersionTLS10 default: return tls.VersionTLS12 } } // IsEmailValid returns true if the specified email address is valid func IsEmailValid(email string) bool { return emailRegex.MatchString(email) } // SanitizeDomain return the specified domain name in a form suitable to save as file func SanitizeDomain(domain string) string { return strings.NewReplacer(":", "_", "*", "_", ",", "_", " ", "_").Replace(domain) } // PanicOnError calls panic if err is not nil func PanicOnError(err error) { if err != nil { panic(fmt.Errorf("unexpected error: %w", err)) } } // GetAbsolutePath returns an absolute path using the current dir as base // if name defines a relative path func GetAbsolutePath(name string) (string, error) { if name == "" { return name, errors.New("input path cannot be empty") } if filepath.IsAbs(name) { return name, nil } curDir, err := os.Getwd() if err != nil { return name, err } return filepath.Join(curDir, name), nil } // GetACMECertificateKeyPair returns the path to the ACME TLS crt and key for the specified domain func GetACMECertificateKeyPair(domain string) (string, string) { if CertsBasePath == "" { return "", "" } domain = SanitizeDomain(domain) return filepath.Join(CertsBasePath, domain+".crt"), filepath.Join(CertsBasePath, domain+".key") } // GetLastIPForPrefix returns the last IP for the given prefix // https://github.com/go4org/netipx/blob/8449b0a6169f5140fb0340cb4fc0de4c9b281ef6/netipx.go#L173 func GetLastIPForPrefix(p netip.Prefix) netip.Addr { if !p.IsValid() { return netip.Addr{} } a16 := p.Addr().As16() var off uint8 var bits uint8 = 128 if p.Addr().Is4() { off = 12 bits = 32 } for b := uint8(p.Bits()); b < bits; b++ { byteNum, bitInByte := b/8, 7-(b%8) a16[off+byteNum] |= 1 << uint(bitInByte) } if p.Addr().Is4() { return netip.AddrFrom16(a16).Unmap() } return netip.AddrFrom16(a16) // doesn't unmap } // JSONEscape returns the JSON escaped format for the input string func JSONEscape(val string) string { if val 
== "" { return val } b, err := json.Marshal(val) if err != nil { return "" } return BytesToString(b[1 : len(b)-1]) } // ReadConfigFromFile reads a configuration parameter from the specified file func ReadConfigFromFile(name, configDir string) (string, error) { if !IsFileInputValid(name) { return "", fmt.Errorf("invalid file input: %q", name) } if configDir == "" { if !filepath.IsAbs(name) { return "", fmt.Errorf("%q must be an absolute file path", name) } } else { if name != "" && !filepath.IsAbs(name) { name = filepath.Join(configDir, name) } } val, err := os.ReadFile(name) if err != nil { return "", err } return strings.TrimSpace(BytesToString(val)), nil } // SlicesEqual checks if the provided slices contain the same elements, // also in different order. func SlicesEqual(s1, s2 []string) bool { if len(s1) != len(s2) { return false } for _, v := range s1 { if !slices.Contains(s2, v) { return false } } return true } // VerifyFileChecksum computes the hash of the given file using the provided // hash algorithm and compares it against the expected checksum (in hex format). // It returns an error if the checksum does not match or if the operation fails. 
func VerifyFileChecksum(filePath string, h hash.Hash, expectedHex string, maxSize int64) error { expected, err := hex.DecodeString(expectedHex) if err != nil { return fmt.Errorf("invalid checksum %q: %w", expectedHex, err) } f, err := os.Open(filePath) if err != nil { return err } defer f.Close() if maxSize > 0 { fi, err := f.Stat() if err != nil { return err } if fi.Size() > maxSize { return fmt.Errorf("file too large: %s", ByteCountIEC(fi.Size())) } } if _, err := io.Copy(h, f); err != nil { return err } actual := h.Sum(nil) if subtle.ConstantTimeCompare(actual, expected) != 1 { return errors.New("checksum mismatch") } return nil } ================================================ FILE: internal/util/util_fallback.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !unix package util import ( "runtime" "github.com/drakkan/sftpgo/v2/internal/logger" ) // SetUmask sets the specified umask func SetUmask(val string) { if val == "" { return } logger.Debug(logSender, "", "umask not supported on OS %q", runtime.GOOS) } ================================================ FILE: internal/util/util_unix.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build unix package util import ( "strconv" "syscall" "github.com/drakkan/sftpgo/v2/internal/logger" ) // SetUmask sets the specified umask func SetUmask(val string) { if val == "" { return } umask, err := strconv.ParseUint(val, 8, 31) if err != nil { logger.Error(logSender, "", "invalid umask %q: %v", val, err) return } logger.Debug(logSender, "", "set umask to: %d, configured value: %q", umask, val) syscall.Umask(int(umask)) } ================================================ FILE: internal/version/version.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Package version defines SFTPGo version details package version import "strings" const ( version = "2.7.99-dev" appName = "SFTPGo" ) var ( commit = "" date = "" info Info ) var ( config string ) // Info defines version details type Info struct { Version string `json:"version"` BuildDate string `json:"build_date"` CommitHash string `json:"commit_hash"` Features []string `json:"features"` } // GetAsString returns the string representation of the version func GetAsString() string { var sb strings.Builder sb.WriteString(info.Version) if info.CommitHash != "" { sb.WriteString("-") sb.WriteString(info.CommitHash) } if info.BuildDate != "" { sb.WriteString("-") sb.WriteString(info.BuildDate) } if len(info.Features) > 0 { sb.WriteString(" ") sb.WriteString(strings.Join(info.Features, " ")) } return sb.String() } func init() { info = Info{ Version: version, CommitHash: commit, BuildDate: date, } } // AddFeature adds a feature description func AddFeature(feature string) { info.Features = append(info.Features, feature) } // Get returns the Info struct func Get() Info { return info } // SetConfig sets the version configuration func SetConfig(val string) { config = val } // GetServerVersion returns the server version according to the configuration // and the provided parameters. func GetServerVersion(separator string, addHash bool) string { var sb strings.Builder sb.WriteString(appName) if config != "short" { sb.WriteString(separator) sb.WriteString(info.Version) } if addHash { sb.WriteString(separator) sb.WriteString(info.CommitHash) } return sb.String() } // GetVersionHash returns the server identification string with the commit hash. 
func GetVersionHash() string { var sb strings.Builder sb.WriteString(appName) sb.WriteString("-") sb.WriteString(info.CommitHash) return sb.String() } ================================================ FILE: internal/vfs/azblobfs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !noazblob package vfs import ( "bytes" "context" "encoding/base64" "errors" "fmt" "io" "mime" "net/http" "os" "path" "path/filepath" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" "github.com/google/uuid" "github.com/pkg/sftp" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( azureDefaultEndpoint = "blob.core.windows.net" azFolderKey = "hdi_isfolder" ) var ( azureBlobDefaultPageSize = int32(5000) ) // AzureBlobFs is a Fs implementation for Azure Blob storage. 
type AzureBlobFs struct {
	connectionID string
	// local directory used for temporary files
	localTempDir string
	// if not empty this fs is mounted as virtual folder in the specified path
	mountPath       string
	config          *AzBlobFsConfig
	containerClient *container.Client
	// NewAzBlobFs sets ctxTimeout to 30s and ctxLongTimeout to 90s;
	// ctxLongTimeout is presumably meant for slower operations (confirm at its call sites)
	ctxTimeout     time.Duration
	ctxLongTimeout time.Duration
}

func init() {
	version.AddFeature("+azblob")
}

// NewAzBlobFs returns an AzBlobFs object that allows to interact with Azure Blob storage.
//
// Credentials are selected in this order: the SAS URL if set, then the shared
// account key if set, otherwise the default Azure credential chain. On error a
// non-nil, possibly partially initialized, fs is returned together with the error.
func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsConfig) (Fs, error) {
	if localTempDir == "" {
		localTempDir = getLocalTempDir()
	}

	fs := &AzureBlobFs{
		connectionID:   connectionID,
		localTempDir:   localTempDir,
		mountPath:      getMountPath(mountPath),
		config:         &config,
		ctxTimeout:     30 * time.Second,
		ctxLongTimeout: 90 * time.Second,
	}
	if err := fs.config.validate(); err != nil {
		return fs, err
	}

	if err := fs.config.tryDecrypt(); err != nil {
		return fs, err
	}

	fs.setConfigDefaults()

	if fs.config.SASURL.GetPayload() != "" {
		return fs.initFromSASURL()
	}

	// the emulator endpoint is used as-is followed by the account name,
	// otherwise the account name is used as a subdomain of the endpoint over https
	var endpoint string
	if fs.config.UseEmulator {
		endpoint = fmt.Sprintf("%s/%s", fs.config.Endpoint, fs.config.AccountName)
	} else {
		endpoint = fmt.Sprintf("https://%s.%s/", fs.config.AccountName, fs.config.Endpoint)
	}
	containerURL := runtime.JoinPaths(endpoint, fs.config.Container)
	if fs.config.AccountKey.GetPayload() != "" {
		credential, err := blob.NewSharedKeyCredential(fs.config.AccountName, fs.config.AccountKey.GetPayload())
		if err != nil {
			return fs, fmt.Errorf("invalid credentials: %v", err)
		}
		svc, err := container.NewClientWithSharedKeyCredential(containerURL, credential, getAzContainerClientOptions())
		if err != nil {
			return fs, fmt.Errorf("unable to create the storage client using shared key credentials: %v", err)
		}
		fs.containerClient = svc
		// err is nil here
		return fs, err
	}
	credential, err := azidentity.NewDefaultAzureCredential(nil)
	if err != nil {
		return fs, fmt.Errorf("invalid default azure credentials: %v", err)
	}
	svc, err := container.NewClient(containerURL, credential, getAzContainerClientOptions())
	if err != nil {
		return fs, fmt.Errorf("unable to create the storage client using azure credentials: %v", err)
	}
	fs.containerClient = svc
	// err is nil here
	return fs, err
}

// initFromSASURL creates the container client from the configured SAS URL.
// The URL must not reference a blob. If it already contains a container name,
// that name must match the configured container (if any) and replaces it.
// Otherwise the configured container is required and is appended to the URL.
func (fs *AzureBlobFs) initFromSASURL() (Fs, error) {
	parts, err := blob.ParseURL(fs.config.SASURL.GetPayload())
	if err != nil {
		return fs, fmt.Errorf("invalid SAS URL: %w", err)
	}
	if parts.BlobName != "" {
		return fs, fmt.Errorf("SAS URL with blob name not supported")
	}
	if parts.ContainerName != "" {
		if fs.config.Container != "" && fs.config.Container != parts.ContainerName {
			return fs, fmt.Errorf("container name in SAS URL %q and container provided %q do not match",
				parts.ContainerName, fs.config.Container)
		}
		svc, err := container.NewClientWithNoCredential(fs.config.SASURL.GetPayload(), getAzContainerClientOptions())
		if err != nil {
			return fs, fmt.Errorf("invalid credentials: %v", err)
		}
		fs.config.Container = parts.ContainerName
		fs.containerClient = svc
		return fs, nil
	}
	if fs.config.Container == "" {
		return fs, errors.New("container is required with this SAS URL")
	}
	sasURL := runtime.JoinPaths(fs.config.SASURL.GetPayload(), fs.config.Container)
	svc, err := container.NewClientWithNoCredential(sasURL, getAzContainerClientOptions())
	if err != nil {
		return fs, fmt.Errorf("invalid credentials: %v", err)
	}
	fs.containerClient = svc
	return fs, nil
}

// Name returns the name for the Fs implementation
func (fs *AzureBlobFs) Name() string {
	if !fs.config.SASURL.IsEmpty() {
		return fmt.Sprintf("%s with SAS URL, container %q", azBlobFsName, fs.config.Container)
	}
	return fmt.Sprintf("%s container %q", azBlobFsName, fs.config.Container)
}

// ConnectionID returns the connection ID associated to this Fs implementation
func (fs *AzureBlobFs) ConnectionID() string {
	return fs.connectionID
}

// Stat returns a FileInfo describing the named file.
//
// The root ("", "/", ".") and the key prefix itself are reported as
// directories without contacting the service. Otherwise the blob properties
// are read. If the blob does not exist, the name is still reported as a
// (virtual) directory when blobs exist under it; if not, os.ErrNotExist is
// returned.
func (fs *AzureBlobFs) Stat(name string) (os.FileInfo, error) {
	if name == "" || name == "/" || name == "." {
		return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil
	}
	if fs.config.KeyPrefix == name+"/" {
		return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil
	}

	attrs, err := fs.headObject(name)
	if err == nil {
		contentType := util.GetStringFromPointer(attrs.ContentType)
		isDir := checkDirectoryMarkers(contentType, attrs.Metadata)
		lastModified := util.GetTimeFromPointer(attrs.LastModified)
		// a modification time stored in the blob metadata takes precedence
		if val := getAzureLastModified(attrs.Metadata); val > 0 {
			lastModified = util.GetTimeFromMsecSinceEpoch(val)
		}
		info := NewFileInfo(name, isDir, util.GetIntFromPointer(attrs.ContentLength), lastModified, false)
		if !isDir {
			info.setMetadataFromPointerVal(attrs.Metadata)
		}
		return info, nil
	}
	if !fs.IsNotExist(err) {
		return nil, err
	}
	// now check if this is a prefix (virtual directory)
	hasContents, err := fs.hasContents(name)
	if err != nil {
		return nil, err
	}
	if hasContents {
		return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil
	}
	return nil, os.ErrNotExist
}

// Lstat returns a FileInfo describing the named file.
// Symlinks are not supported, so this is the same as Stat.
func (fs *AzureBlobFs) Lstat(name string) (os.FileInfo, error) {
	return fs.Stat(name)
}

// Open opens the named file for reading.
//
// The download runs in a background goroutine that writes into a pipe
// starting at offset. The returned PipeReader reads from that pipe, and the
// returned function cancels the download. No File is returned.
func (fs *AzureBlobFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
	// pipe sized for one full round of concurrent part downloads
	r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1)
	if err != nil {
		return nil, nil, nil, err
	}
	p := NewPipeReader(r)
	ctx, cancelFn := context.WithCancel(context.Background())

	go func() {
		defer cancelFn()

		blockBlob := fs.containerClient.NewBlockBlobClient(name)
		err := fs.handleMultipartDownload(ctx, blockBlob, offset, w, p)
		// propagate the download result (nil on success) to the reader side
		w.CloseWithError(err) //nolint:errcheck
		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, w.GetWrittenBytes(), err)
		metric.AZTransferCompleted(w.GetWrittenBytes(), 1, err)
	}()

	return nil, p, cancelFn, nil
}

// Create creates or opens the named file for writing.
//
// The upload runs in a background goroutine reading from a pipe. The caller
// writes into the returned PipeWriter. A flag of -1 creates a directory
// marker blob (directory content type plus hdi_isfolder metadata); otherwise
// the content type is guessed from the file extension.
//
// With CheckResume, the existing blob contents are first written into the
// pipe through downloadToWriter, and the writer offsets are adjusted by the
// number of bytes copied.
func (fs *AzureBlobFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) {
	if checks&CheckParentDir != 0 {
		_, err := fs.Stat(path.Dir(name))
		if err != nil {
			return nil, nil, nil, err
		}
	}
	r, w, err := createPipeFn(fs.localTempDir, fs.config.UploadPartSize+1024*1024)
	if err != nil {
		return nil, nil, nil, err
	}
	ctx, cancelFn := context.WithCancel(context.Background())

	var p PipeWriter
	if checks&CheckResume != 0 {
		p = newPipeWriterAtOffset(w, 0)
	} else {
		p = NewPipeWriter(w)
	}
	headers := blob.HTTPHeaders{}
	var contentType string
	var metadata map[string]*string
	if flag == -1 {
		contentType = dirMimeType
		metadata = map[string]*string{
			azFolderKey: util.NilIfEmpty("true"),
		}
	} else {
		contentType = mime.TypeByExtension(path.Ext(name))
	}
	if contentType != "" {
		headers.BlobContentType = &contentType
	}

	go func() {
		defer cancelFn()

		blockBlob := fs.containerClient.NewBlockBlobClient(name)
		err := fs.handleMultipartUpload(ctx, r, blockBlob, &headers, metadata)
		r.CloseWithError(err) //nolint:errcheck
		p.Done(err)
		fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %+v", name, r.GetReadedBytes(), err)
		metric.AZTransferCompleted(r.GetReadedBytes(), 0, err)
	}()

	if checks&CheckResume != 0 {
		readCh := make(chan error, 1)

		go func() {
			n, err := fs.downloadToWriter(name, p)
			// p was created by newPipeWriterAtOffset above when CheckResume is set
			pw := p.(*pipeWriterAtOffset)
			pw.offset = 0
			pw.writeOffset = n
			readCh <- err
		}()

		err = <-readCh
		if err != nil {
			// abort the pending upload and close the writer
			cancelFn()
			p.Close()
			fsLog(fs, logger.LevelDebug, "download before resume failed, writer closed and read cancelled")
			return nil, nil, nil, err
		}
	}

	// NOTE(review): when bit 16 of uploadMode is set no cancel function is
	// returned, presumably so that callers cannot abort the background upload. Confirm.
	if uploadMode&16 != 0 {
		return nil, p, nil, nil
	}

	return nil, p, cancelFn, nil
}

// Rename renames (moves) source to target.
func (fs *AzureBlobFs) Rename(source, target string, checks int) (int, int64, error) { if source == target { return -1, -1, nil } if checks&CheckParentDir != 0 { _, err := fs.Stat(path.Dir(target)) if err != nil { return -1, -1, err } } fi, err := fs.Stat(source) if err != nil { return -1, -1, err } return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0) } // Remove removes the named file or (empty) directory. func (fs *AzureBlobFs) Remove(name string, isDir bool) error { if isDir { hasContents, err := fs.hasContents(name) if err != nil { return err } if hasContents { return fmt.Errorf("cannot remove non empty directory: %q", name) } } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() blobBlock := fs.containerClient.NewBlockBlobClient(name) var deletSnapshots blob.DeleteSnapshotsOptionType if !isDir { deletSnapshots = blob.DeleteSnapshotsOptionTypeInclude } _, err := blobBlock.Delete(ctx, &blob.DeleteOptions{ DeleteSnapshots: &deletSnapshots, }) if err != nil && isDir { if fs.isBadRequestError(err) { deletSnapshots = blob.DeleteSnapshotsOptionTypeInclude _, err = blobBlock.Delete(ctx, &blob.DeleteOptions{ DeleteSnapshots: &deletSnapshots, }) } } metric.AZDeleteObjectCompleted(err) return err } // Mkdir creates a new directory with the specified name and default permissions func (fs *AzureBlobFs) Mkdir(name string) error { _, err := fs.Stat(name) if !fs.IsNotExist(err) { return err } return fs.mkdirInternal(name) } // Symlink creates source as a symbolic link to target. func (*AzureBlobFs) Symlink(_, _ string) error { return ErrVfsUnsupported } // Readlink returns the destination of the named symbolic link func (*AzureBlobFs) Readlink(_ string) (string, error) { return "", ErrVfsUnsupported } // Chown changes the numeric uid and gid of the named file. 
// Not supported on Azure Blob: ErrVfsUnsupported is always returned.
func (*AzureBlobFs) Chown(_ string, _ int, _ int) error {
	return ErrVfsUnsupported
}

// Chmod is not supported: it always returns ErrVfsUnsupported.
func (*AzureBlobFs) Chmod(_ string, _ os.FileMode) error {
	return ErrVfsUnsupported
}

// Chtimes stores mtime, as milliseconds since the epoch, in the blob metadata.
// The access time is ignored, and nothing is done while the file is being
// uploaded.
func (fs *AzureBlobFs) Chtimes(name string, _, mtime time.Time, isUploading bool) error {
	if isUploading {
		return nil
	}
	props, err := fs.headObject(name)
	if err != nil {
		return err
	}
	metadata := props.Metadata
	if metadata == nil {
		metadata = make(map[string]*string)
	}
	// reuse an existing key matching lastModifiedField case-insensitively,
	// otherwise add lastModifiedField itself
	key := lastModifiedField
	for k := range metadata {
		if strings.EqualFold(k, lastModifiedField) {
			key = k
			break
		}
	}
	metadata[key] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10))

	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	blobClient := fs.containerClient.NewBlockBlobClient(name)
	_, err = blobClient.SetMetadata(ctx, metadata, &blob.SetMetadataOptions{})
	return err
}

// Truncate changes the size of the named file.
// Truncate by path is not supported, while truncating an opened
// file is handled inside base transfer
func (*AzureBlobFs) Truncate(_ string, _ int64) error {
	return ErrVfsUnsupported
}

// ReadDir returns a lister for the entries directly inside dirname, using a
// "/" delimited listing that includes blob metadata.
func (fs *AzureBlobFs) ReadDir(dirname string) (DirLister, error) {
	// dirname must be already cleaned
	prefix := fs.getPrefix(dirname)
	opts := &container.ListBlobsHierarchyOptions{
		Include:    container.ListBlobsInclude{Metadata: true},
		Prefix:     &prefix,
		MaxResults: &azureBlobDefaultPageSize,
	}
	lister := &azureBlobDirLister{
		paginator: fs.containerClient.NewListBlobsHierarchyPager("/", opts),
		timeout:   fs.ctxTimeout,
		prefix:    prefix,
		prefixes:  make(map[string]bool),
	}
	return lister, nil
}

// IsUploadResumeSupported returns true if resuming uploads is supported.
// Resuming uploads is not supported on Azure Blob func (*AzureBlobFs) IsUploadResumeSupported() bool { return false } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (*AzureBlobFs) IsConditionalUploadResumeSupported(size int64) bool { return size <= resumeMaxSize } // IsAtomicUploadSupported returns true if atomic upload is supported. // Azure Blob uploads are already atomic, we don't need to upload to a temporary // file func (*AzureBlobFs) IsAtomicUploadSupported() bool { return false } // IsNotExist returns a boolean indicating whether the error is known to // report that a file or directory does not exist func (*AzureBlobFs) IsNotExist(err error) bool { if err == nil { return false } var respErr *azcore.ResponseError if errors.As(err, &respErr) { return respErr.StatusCode == http.StatusNotFound } // os.ErrNotExist can be returned internally by fs.Stat return errors.Is(err, os.ErrNotExist) } // IsPermission returns a boolean indicating whether the error is known to // report that permission is denied. 
func (*AzureBlobFs) IsPermission(err error) bool { if err == nil { return false } var respErr *azcore.ResponseError if errors.As(err, &respErr) { return respErr.StatusCode == http.StatusForbidden || respErr.StatusCode == http.StatusUnauthorized } return false } // IsNotSupported returns true if the error indicate an unsupported operation func (*AzureBlobFs) IsNotSupported(err error) bool { if err == nil { return false } return errors.Is(err, ErrVfsUnsupported) } func (*AzureBlobFs) isBadRequestError(err error) bool { if err == nil { return false } var respErr *azcore.ResponseError if errors.As(err, &respErr) { return respErr.StatusCode == http.StatusBadRequest } return false } // CheckRootPath creates the specified local root directory if it does not exists func (fs *AzureBlobFs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) return osFs.CheckRootPath(username, uid, gid) } // ScanRootDirContents returns the number of files contained in the bucket, // and their size func (fs *AzureBlobFs) ScanRootDirContents() (int, int64, error) { return fs.GetDirSize(fs.config.KeyPrefix) } // GetDirSize returns the number of files and the size for a folder // including any subfolders func (fs *AzureBlobFs) GetDirSize(dirname string) (int, int64, error) { numFiles := 0 size := int64(0) prefix := fs.getPrefix(dirname) pager := fs.containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ Include: container.ListBlobsInclude{ Metadata: true, }, Prefix: &prefix, MaxResults: &azureBlobDefaultPageSize, }) for pager.More() { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := pager.NextPage(ctx) if err != nil { metric.AZListObjectsCompleted(err) return numFiles, size, err } for _, blobItem := range resp.Segment.BlobItems { if blobItem.Properties != nil { contentType := 
util.GetStringFromPointer(blobItem.Properties.ContentType) isDir := checkDirectoryMarkers(contentType, blobItem.Metadata) blobSize := util.GetIntFromPointer(blobItem.Properties.ContentLength) if isDir && blobSize == 0 { continue } numFiles++ size += blobSize } } fsLog(fs, logger.LevelDebug, "scan in progress for %q, files: %d, size: %d", dirname, numFiles, size) } metric.AZListObjectsCompleted(nil) return numFiles, size, nil } // GetAtomicUploadPath returns the path to use for an atomic upload. // Azure Blob Storage uploads are already atomic, we never call this method func (*AzureBlobFs) GetAtomicUploadPath(_ string) string { return "" } // GetRelativePath returns the path for a file relative to the user's home dir. // This is the path as seen by SFTPGo users func (fs *AzureBlobFs) GetRelativePath(name string) string { rel := path.Clean(name) if rel == "." { rel = "" } if !path.IsAbs(rel) { rel = "/" + rel } if fs.config.KeyPrefix != "" { if !strings.HasPrefix(rel, "/"+fs.config.KeyPrefix) { rel = "/" } rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) } if fs.mountPath != "" { rel = path.Join(fs.mountPath, rel) } return rel } // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error { prefix := fs.getPrefix(root) pager := fs.containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ Include: container.ListBlobsInclude{ Metadata: true, }, Prefix: &prefix, MaxResults: &azureBlobDefaultPageSize, }) for pager.More() { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := pager.NextPage(ctx) if err != nil { metric.AZListObjectsCompleted(err) return err } for _, blobItem := range resp.Segment.BlobItems { name := util.GetStringFromPointer(blobItem.Name) if fs.isEqual(name, prefix) { continue } blobSize := int64(0) lastModified := 
time.Unix(0, 0) isDir := false if blobItem.Properties != nil { contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) isDir = checkDirectoryMarkers(contentType, blobItem.Metadata) blobSize = util.GetIntFromPointer(blobItem.Properties.ContentLength) lastModified = util.GetTimeFromPointer(blobItem.Properties.LastModified) if val := getAzureLastModified(blobItem.Metadata); val > 0 { lastModified = util.GetTimeFromMsecSinceEpoch(val) } } err := walkFn(name, NewFileInfo(name, isDir, blobSize, lastModified, false), nil) if err != nil { return err } } } metric.AZListObjectsCompleted(nil) return walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), nil) } // Join joins any number of path elements into a single path func (*AzureBlobFs) Join(elem ...string) string { return strings.TrimPrefix(path.Join(elem...), "/") } // HasVirtualFolders returns true if folders are emulated func (*AzureBlobFs) HasVirtualFolders() bool { return true } // ResolvePath returns the matching filesystem path for the specified sftp path func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) { if fs.mountPath != "" { if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { virtualPath = after } } virtualPath = path.Clean("/" + virtualPath) return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil } // CopyFile implements the FsFileCopier interface func (fs *AzureBlobFs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) { numFiles := 1 sizeDiff := srcInfo.Size() attrs, err := fs.headObject(target) if err == nil { sizeDiff -= util.GetIntFromPointer(attrs.ContentLength) numFiles = 0 } else { if !fs.IsNotExist(err) { return 0, 0, err } } if err := fs.copyFileInternal(source, target, srcInfo, true); err != nil { return 0, 0, err } return numFiles, sizeDiff, nil } func (fs *AzureBlobFs) headObject(name string) (blob.GetPropertiesResponse, error) { ctx, cancelFn := 
context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := fs.containerClient.NewBlockBlobClient(name).GetProperties(ctx, &blob.GetPropertiesOptions{}) metric.AZHeadObjectCompleted(err) return resp, err } // GetMimeType returns the content type func (fs *AzureBlobFs) GetMimeType(name string) (string, error) { response, err := fs.headObject(name) if err != nil { return "", err } return util.GetStringFromPointer(response.ContentType), nil } // Close closes the fs func (*AzureBlobFs) Close() error { return nil } // GetAvailableDiskSize returns the available size for the specified path func (*AzureBlobFs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) { return nil, ErrStorageSizeUnavailable } func (*AzureBlobFs) getPrefix(name string) string { prefix := "" if name != "" && name != "." { prefix = strings.TrimPrefix(name, "/") if !strings.HasSuffix(prefix, "/") { prefix += "/" } } return prefix } func (fs *AzureBlobFs) isEqual(key string, virtualName string) bool { if key == virtualName { return true } if key == virtualName+"/" { return true } if key+"/" == virtualName { return true } return false } func (fs *AzureBlobFs) setConfigDefaults() { if fs.config.Endpoint == "" { fs.config.Endpoint = azureDefaultEndpoint } if fs.config.UploadPartSize == 0 { fs.config.UploadPartSize = 5 } if fs.config.UploadPartSize < 1024*1024 { fs.config.UploadPartSize *= 1024 * 1024 } if fs.config.UploadConcurrency == 0 { fs.config.UploadConcurrency = 5 } if fs.config.DownloadPartSize == 0 { fs.config.DownloadPartSize = 5 } if fs.config.DownloadPartSize < 1024*1024 { fs.config.DownloadPartSize *= 1024 * 1024 } if fs.config.DownloadConcurrency == 0 { fs.config.DownloadConcurrency = 5 } } func (fs *AzureBlobFs) copyFileInternal(source, target string, srcInfo os.FileInfo, updateModTime bool) error { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout)) defer cancelFn() srcBlob := 
fs.containerClient.NewBlockBlobClient(source) dstBlob := fs.containerClient.NewBlockBlobClient(target) resp, err := dstBlob.StartCopyFromURL(ctx, srcBlob.URL(), fs.getCopyOptions(srcInfo, updateModTime)) if err != nil { metric.AZCopyObjectCompleted(err) return err } copyStatus := blob.CopyStatusType(util.GetStringFromPointer((*string)(resp.CopyStatus))) nErrors := 0 for copyStatus == blob.CopyStatusTypePending { // Poll until the copy is complete. time.Sleep(500 * time.Millisecond) resp, err := dstBlob.GetProperties(ctx, &blob.GetPropertiesOptions{}) if err != nil { // A GetProperties failure may be transient, so allow a couple // of them before giving up. nErrors++ if ctx.Err() != nil || nErrors == 3 { metric.AZCopyObjectCompleted(err) return err } } else { copyStatus = blob.CopyStatusType(util.GetStringFromPointer((*string)(resp.CopyStatus))) } } if copyStatus != blob.CopyStatusTypeSuccess { err := fmt.Errorf("copy failed with status: %s", copyStatus) metric.AZCopyObjectCompleted(err) return err } metric.AZCopyObjectCompleted(nil) return nil } func (fs *AzureBlobFs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int, updateModTime bool, ) (int, int64, error) { var numFiles int var filesSize int64 if srcInfo.IsDir() { if renameMode == 0 { hasContents, err := fs.hasContents(source) if err != nil { return numFiles, filesSize, err } if hasContents { return numFiles, filesSize, fmt.Errorf("%w: cannot rename non empty directory: %q", ErrVfsUnsupported, source) } } if err := fs.mkdirInternal(target); err != nil { return numFiles, filesSize, err } if renameMode == 1 { files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime) numFiles += files filesSize += size if err != nil { return numFiles, filesSize, err } } } else { if err := fs.copyFileInternal(source, target, srcInfo, updateModTime); err != nil { return numFiles, filesSize, err } numFiles++ filesSize += srcInfo.Size() } err := 
fs.skipNotExistErr(fs.Remove(source, srcInfo.IsDir())) return numFiles, filesSize, err } func (fs *AzureBlobFs) skipNotExistErr(err error) error { if fs.IsNotExist(err) { return nil } return err } func (fs *AzureBlobFs) mkdirInternal(name string) error { _, w, _, err := fs.Create(name, -1, 0) if err != nil { return err } return w.Close() } func (fs *AzureBlobFs) hasContents(name string) (bool, error) { result := false prefix := fs.getPrefix(name) maxResults := int32(1) pager := fs.containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ MaxResults: &maxResults, Prefix: &prefix, }) if pager.More() { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := pager.NextPage(ctx) if err != nil { metric.AZListObjectsCompleted(err) return result, err } result = len(resp.Segment.BlobItems) > 0 } metric.AZListObjectsCompleted(nil) return result, nil } func (fs *AzureBlobFs) downloadPart(ctx context.Context, blockBlob *blockblob.Client, buf []byte, w io.WriterAt, offset, count, writeOffset int64, ) error { if count == 0 { return nil } resp, err := blockBlob.DownloadStream(ctx, &blob.DownloadStreamOptions{ Range: blob.HTTPRange{ Offset: offset, Count: count, }, }) if err != nil { return err } defer resp.Body.Close() _, err = io.ReadAtLeast(resp.Body, buf, int(count)) if err != nil { return err } return writeAtFull(w, buf, writeOffset, int(count)) } func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *blockblob.Client, offset int64, writer io.WriterAt, pipeReader PipeReader, ) error { props, err := blockBlob.GetProperties(ctx, &blob.GetPropertiesOptions{}) metric.AZHeadObjectCompleted(err) if err != nil { fsLog(fs, logger.LevelError, "unable to get blob properties, download aborted: %+v", err) return err } if readMetadata > 0 && pipeReader != nil { pipeReader.setMetadataFromPointerVal(props.Metadata) } contentLength := util.GetIntFromPointer(props.ContentLength) 
sizeToDownload := contentLength - offset if sizeToDownload < 0 { fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %v, offset: %v, size to download: %v", contentLength, offset, sizeToDownload) return errors.New("the requested offset exceeds the file size") } if sizeToDownload == 0 { fsLog(fs, logger.LevelDebug, "nothing to download, offset %v, content length %v", offset, contentLength) return nil } partSize := fs.config.DownloadPartSize guard := make(chan struct{}, fs.config.DownloadConcurrency) blockCtxTimeout := time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute pool := newBufferAllocator(int(partSize)) defer pool.free() finished := false var wg sync.WaitGroup var errOnce sync.Once var hasError atomic.Bool var poolError error poolCtx, poolCancel := context.WithCancel(ctx) defer poolCancel() for part := 0; !finished; part++ { start := offset end := offset + partSize if end >= contentLength { end = contentLength finished = true } writeOffset := int64(part) * partSize offset = end guard <- struct{}{} if hasError.Load() { fsLog(fs, logger.LevelDebug, "pool error, download for part %v not started", part) break } buf := pool.getBuffer() wg.Add(1) go func(start, end, writeOffset int64, buf []byte) { defer func() { pool.releaseBuffer(buf) <-guard wg.Done() }() innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout)) defer cancelFn() count := end - start err := fs.downloadPart(innerCtx, blockBlob, buf, writer, start, count, writeOffset) if err != nil { errOnce.Do(func() { fsLog(fs, logger.LevelError, "multipart download error: %+v", err) hasError.Store(true) poolError = fmt.Errorf("multipart download error: %w", err) poolCancel() }) } }(start, end, writeOffset, buf) } wg.Wait() close(guard) return poolError } func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlob *blockblob.Client, httpHeaders *blob.HTTPHeaders, metadata map[string]*string, ) error { partSize 
:= fs.config.UploadPartSize guard := make(chan struct{}, fs.config.UploadConcurrency) blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute // sync.Pool seems to use a lot of memory so prefer our own, very simple, allocator // we only need to recycle few byte slices pool := newBufferAllocator(int(partSize)) defer pool.free() finished := false var blocks []string var wg sync.WaitGroup var errOnce sync.Once var hasError atomic.Bool var poolError error poolCtx, poolCancel := context.WithCancel(ctx) defer poolCancel() finalizeFailedUpload := func(err error) { fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err) hasError.Store(true) poolError = fmt.Errorf("multipart upload error: %w", err) poolCancel() } for part := 0; !finished; part++ { buf := pool.getBuffer() n, err := readFill(reader, buf) if err == io.EOF { // read finished, if n > 0 we need to process the last data chunck if n == 0 { pool.releaseBuffer(buf) break } finished = true } else if err != nil { pool.releaseBuffer(buf) errOnce.Do(func() { finalizeFailedUpload(err) }) break } // Block IDs are unique values to avoid issue if 2+ clients are uploading blocks // at the same time causing CommitBlockList to get a mix of blocks from all the clients. 
generatedUUID, err := uuid.NewRandom() if err != nil { pool.releaseBuffer(buf) errOnce.Do(func() { finalizeFailedUpload(err) }) break } blockID := base64.StdEncoding.EncodeToString([]byte(generatedUUID.String())) blocks = append(blocks, blockID) guard <- struct{}{} if hasError.Load() { fsLog(fs, logger.LevelError, "pool error, upload for part %d not started", part) pool.releaseBuffer(buf) break } wg.Add(1) go func(blockID string, buf []byte, bufSize int) { defer func() { pool.releaseBuffer(buf) <-guard wg.Done() }() bufferReader := &bytesReaderWrapper{ Reader: bytes.NewReader(buf[:bufSize]), } innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout)) defer cancelFn() _, err := blockBlob.StageBlock(innerCtx, blockID, bufferReader, &blockblob.StageBlockOptions{}) if err != nil { errOnce.Do(func() { fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err) finalizeFailedUpload(err) }) } }(blockID, buf, n) } wg.Wait() close(guard) if poolError != nil { return poolError } commitOptions := blockblob.CommitBlockListOptions{ HTTPHeaders: httpHeaders, Metadata: metadata, } if fs.config.AccessTier != "" { commitOptions.Tier = (*blob.AccessTier)(&fs.config.AccessTier) } _, err := blockBlob.CommitBlockList(ctx, blocks, &commitOptions) return err } func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions { copyOptions := &blob.StartCopyFromURLOptions{} if fs.config.AccessTier != "" { copyOptions.Tier = (*blob.AccessTier)(&fs.config.AccessTier) } if updateModTime { metadata := make(map[string]*string) for k, v := range getMetadata(srcInfo) { if v != "" { if strings.EqualFold(k, lastModifiedField) { metadata[k] = to.Ptr("0") } else { metadata[k] = to.Ptr(v) } } } if len(metadata) > 0 { copyOptions.Metadata = metadata } } return copyOptions } func (fs *AzureBlobFs) downloadToWriter(name string, w PipeWriter) (int64, error) { fsLog(fs, logger.LevelDebug, "starting download before resuming upload, 
path %q", name) ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout) defer cancelFn() blockBlob := fs.containerClient.NewBlockBlobClient(name) err := fs.handleMultipartDownload(ctx, blockBlob, 0, w, nil) n := w.GetWrittenBytes() fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v", name, n, err) metric.AZTransferCompleted(n, 1, err) return n, err } func checkDirectoryMarkers(contentType string, metadata map[string]*string) bool { if contentType == dirMimeType { return true } for k, v := range metadata { if strings.EqualFold(k, azFolderKey) { return strings.EqualFold(util.GetStringFromPointer(v), "true") } } return false } func getAzContainerClientOptions() *container.ClientOptions { return &container.ClientOptions{ ClientOptions: azcore.ClientOptions{ Telemetry: policy.TelemetryOptions{ ApplicationID: version.GetVersionHash(), }, }, } } type azureBlobDirLister struct { baseDirLister paginator *runtime.Pager[container.ListBlobsHierarchyResponse] timeout time.Duration prefix string prefixes map[string]bool metricUpdated bool } func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) { if limit <= 0 { return nil, errInvalidDirListerLimit } if len(l.cache) >= limit { return l.returnFromCache(limit), nil } if !l.paginator.More() { if !l.metricUpdated { l.metricUpdated = true metric.AZListObjectsCompleted(nil) } return l.returnFromCache(limit), io.EOF } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) defer cancelFn() page, err := l.paginator.NextPage(ctx) if err != nil { metric.AZListObjectsCompleted(err) return l.cache, err } for _, blobPrefix := range page.Segment.BlobPrefixes { name := util.GetStringFromPointer(blobPrefix.Name) // we don't support prefixes == "/" this will be sent if a key starts with "/" if name == "" || name == "/" { continue } // sometime we have duplicate prefixes, maybe an Azurite bug name = strings.TrimPrefix(name, l.prefix) 
if _, ok := l.prefixes[strings.TrimSuffix(name, "/")]; ok { continue } l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) l.prefixes[strings.TrimSuffix(name, "/")] = true } for _, blobItem := range page.Segment.BlobItems { name := util.GetStringFromPointer(blobItem.Name) name = strings.TrimPrefix(name, l.prefix) size := int64(0) isDir := false var metadata map[string]*string modTime := time.Unix(0, 0) if blobItem.Properties != nil { size = util.GetIntFromPointer(blobItem.Properties.ContentLength) modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified) contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) isDir = checkDirectoryMarkers(contentType, blobItem.Metadata) if isDir { // check if the dir is already included, it will be sent as blob prefix if it contains at least one item if _, ok := l.prefixes[name]; ok { continue } l.prefixes[name] = true } else { metadata = blobItem.Metadata } if val := getAzureLastModified(blobItem.Metadata); val > 0 { modTime = util.GetTimeFromMsecSinceEpoch(val) } } info := NewFileInfo(name, isDir, size, modTime, false) info.setMetadataFromPointerVal(metadata) l.cache = append(l.cache, info) } return l.returnFromCache(limit), nil } func (l *azureBlobDirLister) Close() error { clear(l.prefixes) return l.baseDirLister.Close() } ================================================ FILE: internal/vfs/azblobfs_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. 
// // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build noazblob package vfs import ( "errors" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-azblob") } // NewAzBlobFs returns an error, Azure Blob storage is disabled func NewAzBlobFs(_, _, _ string, _ AzBlobFsConfig) (Fs, error) { return nil, errors.New("Azure Blob Storage disabled at build time") } ================================================ FILE: internal/vfs/cryptfs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package vfs import ( "bufio" "bytes" "crypto/rand" "crypto/sha256" "fmt" "io" "net/http" "os" "github.com/minio/sio" "golang.org/x/crypto/hkdf" "github.com/drakkan/sftpgo/v2/internal/logger" ) const ( // cryptFsName is the name for the local Fs implementation with encryption support cryptFsName = "cryptfs" version10 byte = 0x10 nonceV10Size int = 32 headerV10Size int64 = 33 // 1 (version byte) + 32 (nonce size) ) // CryptFs is a Fs implementation that allows to encrypts/decrypts local files type CryptFs struct { *OsFs localTempDir string masterKey []byte } // NewCryptFs returns a CryptFs object func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) (Fs, error) { if err := config.validate(); err != nil { return nil, err } if err := config.Passphrase.TryDecrypt(); err != nil { return nil, err } fs := &CryptFs{ OsFs: &OsFs{ name: cryptFsName, connectionID: connectionID, rootDir: rootDir, mountPath: getMountPath(mountPath), readBufferSize: config.ReadBufferSize * 1024 * 1024, writeBufferSize: config.WriteBufferSize * 1024 * 1024, }, masterKey: []byte(config.Passphrase.GetPayload()), } if tempPath == "" { fs.localTempDir = rootDir } else { fs.localTempDir = tempPath } return fs, nil } // Name returns the name for the Fs implementation func (fs *CryptFs) Name() string { return fs.name } // Open opens the named file for reading func (fs *CryptFs) Open(name string, offset int64) (File, PipeReader, func(), error) { f, key, err := fs.getFileAndEncryptionKey(name) if err != nil { return nil, nil, nil, err } isZeroDownload, err := isZeroBytesDownload(f, offset) if err != nil { f.Close() return nil, nil, nil, err } r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { f.Close() return nil, nil, nil, err } p := NewPipeReader(r) go func() { if isZeroDownload { w.CloseWithError(err) //nolint:errcheck f.Close() fsLog(fs, logger.LevelDebug, "zero bytes download completed, path: %q", name) return } var n int64 var err error if offset == 0 { n, err = 
fs.decryptWrapper(w, f, fs.getSIOConfig(key)) } else { var readerAt io.ReaderAt var readed, written int buf := make([]byte, 65568) wrapper := &cryptedFileWrapper{ File: f, } readerAt, err = sio.DecryptReaderAt(wrapper, fs.getSIOConfig(key)) if err == nil { finished := false for !finished { readed, err = readerAt.ReadAt(buf, offset) offset += int64(readed) if err != nil && err != io.EOF { break } if err == io.EOF { finished = true err = nil } if readed > 0 { written, err = w.Write(buf[:readed]) n += int64(written) if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } break } if readed != written { err = io.ErrShortWrite break } } } } } w.CloseWithError(err) //nolint:errcheck f.Close() fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %v", name, n, err) }() return nil, p, nil, nil } // Create creates or opens the named file for writing func (fs *CryptFs) Create(name string, _, _ int) (File, PipeWriter, func(), error) { f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) if err != nil { return nil, nil, nil, err } header := encryptedFileHeader{ version: version10, nonce: make([]byte, 32), } _, err = io.ReadFull(rand.Reader, header.nonce) if err != nil { f.Close() return nil, nil, nil, err } var key [32]byte kdf := hkdf.New(sha256.New, fs.masterKey, header.nonce, nil) _, err = io.ReadFull(kdf, key[:]) if err != nil { f.Close() return nil, nil, nil, err } r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { f.Close() return nil, nil, nil, err } err = header.Store(f) if err != nil { r.Close() w.Close() f.Close() return nil, nil, nil, err } p := NewPipeWriter(w) go func() { var n int64 var err error if fs.writeBufferSize <= 0 { n, err = sio.Encrypt(f, r, fs.getSIOConfig(key)) } else { bw := bufio.NewWriterSize(f, fs.writeBufferSize) n, err = fs.encryptWrapper(bw, r, fs.getSIOConfig(key)) errFlush := bw.Flush() if err == nil && errFlush != nil { err = errFlush } } errClose := f.Close() if err == nil && errClose != 
nil { err = errClose } r.CloseWithError(err) //nolint:errcheck p.Done(err) fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %v", name, n, err) }() return nil, p, nil, nil } // Truncate changes the size of the named file func (*CryptFs) Truncate(_ string, _ int64) error { return ErrVfsUnsupported } // ReadDir reads the directory named by dirname and returns // a list of directory entries. func (fs *CryptFs) ReadDir(dirname string) (DirLister, error) { f, err := os.Open(dirname) if err != nil { if isInvalidNameError(err) { err = os.ErrNotExist } return nil, err } return &cryptFsDirLister{f}, nil } // IsUploadResumeSupported returns false sio does not support random access writes func (*CryptFs) IsUploadResumeSupported() bool { return false } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (*CryptFs) IsConditionalUploadResumeSupported(_ int64) bool { return false } // GetMimeType returns the content type func (fs *CryptFs) GetMimeType(name string) (string, error) { f, key, err := fs.getFileAndEncryptionKey(name) if err != nil { return "", err } defer f.Close() readSize, err := sio.DecryptedSize(512) if err != nil { return "", err } buf := make([]byte, readSize) n, err := io.ReadFull(f, buf) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { return "", err } decrypted := bytes.NewBuffer(nil) _, err = sio.Decrypt(decrypted, bytes.NewBuffer(buf[:n]), fs.getSIOConfig(key)) if err != nil { return "", err } ctype := http.DetectContentType(decrypted.Bytes()) // Rewind file. 
_, err = f.Seek(0, io.SeekStart) return ctype, err } func (fs *CryptFs) getSIOConfig(key [32]byte) sio.Config { return sio.Config{ MinVersion: sio.Version20, MaxVersion: sio.Version20, Key: key[:], } } // ConvertFileInfo returns a FileInfo with the decrypted size func (fs *CryptFs) ConvertFileInfo(info os.FileInfo) os.FileInfo { return convertCryptFsInfo(info) } func (fs *CryptFs) getFileAndEncryptionKey(name string) (*os.File, [32]byte, error) { var key [32]byte f, err := os.Open(name) if err != nil { return nil, key, err } header := encryptedFileHeader{} err = header.Load(f) if err != nil { f.Close() return nil, key, err } kdf := hkdf.New(sha256.New, fs.masterKey, header.nonce, nil) _, err = io.ReadFull(kdf, key[:]) if err != nil { f.Close() return nil, key, err } return f, key, err } func (*CryptFs) encryptWrapper(dst io.Writer, src io.Reader, config sio.Config) (int64, error) { encReader, err := sio.EncryptReader(src, config) if err != nil { return 0, err } return doCopy(dst, encReader, make([]byte, 65568)) } func (fs *CryptFs) decryptWrapper(dst io.Writer, src io.Reader, config sio.Config) (int64, error) { if fs.readBufferSize <= 0 { return sio.Decrypt(dst, src, config) } br := bufio.NewReaderSize(src, fs.readBufferSize) decReader, err := sio.DecryptReader(br, config) if err != nil { return 0, err } return doCopy(dst, decReader, make([]byte, 65568)) } func isZeroBytesDownload(f *os.File, offset int64) (bool, error) { info, err := f.Stat() if err != nil { return false, err } if info.Size() == headerV10Size { return true, nil } if info.Size() > headerV10Size { decSize, err := sio.DecryptedSize(uint64(info.Size() - headerV10Size)) if err != nil { return false, err } if int64(decSize) == offset { return true, nil } } return false, nil } func convertCryptFsInfo(info os.FileInfo) os.FileInfo { if !info.Mode().IsRegular() { return info } size := info.Size() if size >= headerV10Size { size -= headerV10Size decryptedSize, err := sio.DecryptedSize(uint64(size)) if err 
== nil { size = int64(decryptedSize) } } else { size = 0 } return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false) } type encryptedFileHeader struct { version byte nonce []byte } func (h *encryptedFileHeader) Store(f *os.File) error { buf := make([]byte, 0, headerV10Size) buf = append(buf, version10) buf = append(buf, h.nonce...) _, err := f.Write(buf) return err } func (h *encryptedFileHeader) Load(f *os.File) error { header := make([]byte, 1+nonceV10Size) _, err := io.ReadFull(f, header) if err != nil { return err } h.version = header[0] if h.version == version10 { h.nonce = header[1:] return nil } return fmt.Errorf("unsupported encryption version: %v", h.version) } type cryptedFileWrapper struct { *os.File } func (w *cryptedFileWrapper) ReadAt(p []byte, offset int64) (n int, err error) { return w.File.ReadAt(p, offset+headerV10Size) } type cryptFsDirLister struct { f *os.File } func (l *cryptFsDirLister) Next(limit int) ([]os.FileInfo, error) { if limit <= 0 { return nil, errInvalidDirListerLimit } files, err := l.f.Readdir(limit) for idx := range files { files[idx] = convertCryptFsInfo(files[idx]) } return files, err } func (l *cryptFsDirLister) Close() error { return l.f.Close() } ================================================ FILE: internal/vfs/fileinfo.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package vfs import ( "os" "path" "time" "github.com/drakkan/sftpgo/v2/internal/util" ) // FileInfo implements os.FileInfo for a Cloud Storage file. type FileInfo struct { name string sizeInBytes int64 modTime time.Time mode os.FileMode metadata map[string]string } // NewFileInfo creates file info. func NewFileInfo(name string, isDirectory bool, sizeInBytes int64, modTime time.Time, fullName bool) *FileInfo { mode := os.FileMode(0644) if isDirectory { mode = os.FileMode(0755) | os.ModeDir } if !fullName { // we have always Unix style paths here name = path.Base(name) } return &FileInfo{ name: name, sizeInBytes: sizeInBytes, modTime: modTime, mode: mode, } } // Name provides the base name of the file. func (fi *FileInfo) Name() string { return fi.name } // Size provides the length in bytes for a file. func (fi *FileInfo) Size() int64 { return fi.sizeInBytes } // Mode provides the file mode bits func (fi *FileInfo) Mode() os.FileMode { return fi.mode } // ModTime provides the last modification time. 
func (fi *FileInfo) ModTime() time.Time { return fi.modTime } // IsDir provides the abbreviation for Mode().IsDir() func (fi *FileInfo) IsDir() bool { return fi.mode&os.ModeDir != 0 } // SetMode sets the file mode func (fi *FileInfo) SetMode(mode os.FileMode) { fi.mode = mode } // Sys provides the underlying data source (can return nil) func (fi *FileInfo) Sys() any { return fi.metadata } func (fi *FileInfo) setMetadata(value map[string]string) { fi.metadata = value } func (fi *FileInfo) setMetadataFromPointerVal(value map[string]*string) { if len(value) == 0 { fi.metadata = nil return } fi.metadata = map[string]string{} for k, v := range value { val := util.GetStringFromPointer(v) if val != "" { fi.metadata[k] = val } } } func getMetadata(fi os.FileInfo) map[string]string { if fi.Sys() == nil { return nil } if val, ok := fi.Sys().(map[string]string); ok { if len(val) > 0 { return val } } return nil } ================================================ FILE: internal/vfs/filesystem.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package vfs import ( "os" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/util" ) // Filesystem defines filesystem details type Filesystem struct { RedactedSecret string `json:"-"` Provider sdk.FilesystemProvider `json:"provider"` OSConfig sdk.OSFsConfig `json:"osconfig,omitempty"` S3Config S3FsConfig `json:"s3config,omitempty"` GCSConfig GCSFsConfig `json:"gcsconfig,omitempty"` AzBlobConfig AzBlobFsConfig `json:"azblobconfig,omitempty"` CryptConfig CryptFsConfig `json:"cryptconfig,omitempty"` SFTPConfig SFTPFsConfig `json:"sftpconfig,omitempty"` HTTPConfig HTTPFsConfig `json:"httpconfig,omitempty"` } // SetEmptySecrets sets the secrets to empty func (f *Filesystem) SetEmptySecrets() { f.S3Config.AccessSecret = kms.NewEmptySecret() f.S3Config.SSECustomerKey = kms.NewEmptySecret() f.GCSConfig.Credentials = kms.NewEmptySecret() f.AzBlobConfig.AccountKey = kms.NewEmptySecret() f.AzBlobConfig.SASURL = kms.NewEmptySecret() f.CryptConfig.Passphrase = kms.NewEmptySecret() f.SFTPConfig.Password = kms.NewEmptySecret() f.SFTPConfig.PrivateKey = kms.NewEmptySecret() f.SFTPConfig.KeyPassphrase = kms.NewEmptySecret() f.HTTPConfig.Password = kms.NewEmptySecret() f.HTTPConfig.APIKey = kms.NewEmptySecret() } // SetEmptySecretsIfNil sets the secrets to empty if nil func (f *Filesystem) SetEmptySecretsIfNil() { if f.S3Config.AccessSecret == nil { f.S3Config.AccessSecret = kms.NewEmptySecret() } if f.S3Config.SSECustomerKey == nil { f.S3Config.SSECustomerKey = kms.NewEmptySecret() } if f.GCSConfig.Credentials == nil { f.GCSConfig.Credentials = kms.NewEmptySecret() } if f.AzBlobConfig.AccountKey == nil { f.AzBlobConfig.AccountKey = kms.NewEmptySecret() } if f.AzBlobConfig.SASURL == nil { f.AzBlobConfig.SASURL = kms.NewEmptySecret() } if f.CryptConfig.Passphrase == nil { f.CryptConfig.Passphrase = kms.NewEmptySecret() } if f.SFTPConfig.Password == nil { f.SFTPConfig.Password = kms.NewEmptySecret() } if f.SFTPConfig.PrivateKey 
== nil { f.SFTPConfig.PrivateKey = kms.NewEmptySecret() } if f.SFTPConfig.KeyPassphrase == nil { f.SFTPConfig.KeyPassphrase = kms.NewEmptySecret() } if f.HTTPConfig.Password == nil { f.HTTPConfig.Password = kms.NewEmptySecret() } if f.HTTPConfig.APIKey == nil { f.HTTPConfig.APIKey = kms.NewEmptySecret() } } // SetNilSecretsIfEmpty set the secrets to nil if empty. // This is useful before rendering as JSON so the empty fields // will not be serialized. func (f *Filesystem) SetNilSecretsIfEmpty() { if f.S3Config.AccessSecret != nil && f.S3Config.AccessSecret.IsEmpty() { f.S3Config.AccessSecret = nil } if f.S3Config.SSECustomerKey != nil && f.S3Config.SSECustomerKey.IsEmpty() { f.S3Config.SSECustomerKey = nil } if f.GCSConfig.Credentials != nil && f.GCSConfig.Credentials.IsEmpty() { f.GCSConfig.Credentials = nil } if f.AzBlobConfig.AccountKey != nil && f.AzBlobConfig.AccountKey.IsEmpty() { f.AzBlobConfig.AccountKey = nil } if f.AzBlobConfig.SASURL != nil && f.AzBlobConfig.SASURL.IsEmpty() { f.AzBlobConfig.SASURL = nil } if f.CryptConfig.Passphrase != nil && f.CryptConfig.Passphrase.IsEmpty() { f.CryptConfig.Passphrase = nil } f.SFTPConfig.setNilSecretsIfEmpty() f.HTTPConfig.setNilSecretsIfEmpty() } // IsEqual returns true if the fs is equal to other func (f *Filesystem) IsEqual(other Filesystem) bool { if f.Provider != other.Provider { return false } switch f.Provider { case sdk.S3FilesystemProvider: return f.S3Config.isEqual(other.S3Config) case sdk.GCSFilesystemProvider: return f.GCSConfig.isEqual(other.GCSConfig) case sdk.AzureBlobFilesystemProvider: return f.AzBlobConfig.isEqual(other.AzBlobConfig) case sdk.CryptedFilesystemProvider: return f.CryptConfig.isEqual(other.CryptConfig) case sdk.SFTPFilesystemProvider: return f.SFTPConfig.isEqual(other.SFTPConfig) case sdk.HTTPFilesystemProvider: return f.HTTPConfig.isEqual(other.HTTPConfig) default: return true } } // IsSameResource returns true if fs point to the same resource as other func (f *Filesystem) 
IsSameResource(other Filesystem) bool { if f.Provider != other.Provider { return false } switch f.Provider { case sdk.S3FilesystemProvider: return f.S3Config.isSameResource(other.S3Config) case sdk.GCSFilesystemProvider: return f.GCSConfig.isSameResource(other.GCSConfig) case sdk.AzureBlobFilesystemProvider: return f.AzBlobConfig.isSameResource(other.AzBlobConfig) case sdk.CryptedFilesystemProvider: return f.CryptConfig.isSameResource(other.CryptConfig) case sdk.SFTPFilesystemProvider: return f.SFTPConfig.isSameResource(other.SFTPConfig) case sdk.HTTPFilesystemProvider: return f.HTTPConfig.isSameResource(other.HTTPConfig) default: return true } } // GetPathSeparator returns the path separator func (f *Filesystem) GetPathSeparator() string { switch f.Provider { case sdk.LocalFilesystemProvider, sdk.CryptedFilesystemProvider: return string(os.PathSeparator) default: return "/" } } // Validate verifies the FsConfig matching the configured provider and sets all other // Filesystem.*Config to their zero value if successful func (f *Filesystem) Validate(additionalData string) error { switch f.Provider { case sdk.S3FilesystemProvider: if err := f.S3Config.ValidateAndEncryptCredentials(additionalData); err != nil { return err } f.OSConfig = sdk.OSFsConfig{} f.GCSConfig = GCSFsConfig{} f.AzBlobConfig = AzBlobFsConfig{} f.CryptConfig = CryptFsConfig{} f.SFTPConfig = SFTPFsConfig{} f.HTTPConfig = HTTPFsConfig{} return nil case sdk.GCSFilesystemProvider: if err := f.GCSConfig.ValidateAndEncryptCredentials(additionalData); err != nil { return err } f.OSConfig = sdk.OSFsConfig{} f.S3Config = S3FsConfig{} f.AzBlobConfig = AzBlobFsConfig{} f.CryptConfig = CryptFsConfig{} f.SFTPConfig = SFTPFsConfig{} f.HTTPConfig = HTTPFsConfig{} return nil case sdk.AzureBlobFilesystemProvider: if err := f.AzBlobConfig.ValidateAndEncryptCredentials(additionalData); err != nil { return err } f.OSConfig = sdk.OSFsConfig{} f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} f.CryptConfig = 
CryptFsConfig{} f.SFTPConfig = SFTPFsConfig{} f.HTTPConfig = HTTPFsConfig{} return nil case sdk.CryptedFilesystemProvider: if err := f.CryptConfig.ValidateAndEncryptCredentials(additionalData); err != nil { return err } f.OSConfig = sdk.OSFsConfig{} f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} f.AzBlobConfig = AzBlobFsConfig{} f.SFTPConfig = SFTPFsConfig{} f.HTTPConfig = HTTPFsConfig{} return validateOSFsConfig(&f.CryptConfig.OSFsConfig) case sdk.SFTPFilesystemProvider: if err := f.SFTPConfig.ValidateAndEncryptCredentials(additionalData); err != nil { return err } f.OSConfig = sdk.OSFsConfig{} f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} f.AzBlobConfig = AzBlobFsConfig{} f.CryptConfig = CryptFsConfig{} f.HTTPConfig = HTTPFsConfig{} return nil case sdk.HTTPFilesystemProvider: if err := f.HTTPConfig.ValidateAndEncryptCredentials(additionalData); err != nil { return err } f.OSConfig = sdk.OSFsConfig{} f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} f.AzBlobConfig = AzBlobFsConfig{} f.CryptConfig = CryptFsConfig{} f.SFTPConfig = SFTPFsConfig{} return nil case sdk.LocalFilesystemProvider: f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} f.AzBlobConfig = AzBlobFsConfig{} f.CryptConfig = CryptFsConfig{} f.SFTPConfig = SFTPFsConfig{} f.HTTPConfig = HTTPFsConfig{} return validateOSFsConfig(&f.OSConfig) default: return util.NewI18nError( util.NewValidationError("invalid filesystem provider"), util.I18nErrorFsValidation, ) } } // HasRedactedSecret returns true if configured the filesystem configuration has a redacted secret func (f *Filesystem) HasRedactedSecret() bool { // TODO move vfs specific code into each *FsConfig struct switch f.Provider { case sdk.S3FilesystemProvider: if f.S3Config.SSECustomerKey.IsRedacted() { return true } return f.S3Config.AccessSecret.IsRedacted() case sdk.GCSFilesystemProvider: return f.GCSConfig.Credentials.IsRedacted() case sdk.AzureBlobFilesystemProvider: if f.AzBlobConfig.AccountKey.IsRedacted() { return true } 
return f.AzBlobConfig.SASURL.IsRedacted() case sdk.CryptedFilesystemProvider: return f.CryptConfig.Passphrase.IsRedacted() case sdk.SFTPFilesystemProvider: if f.SFTPConfig.Password.IsRedacted() { return true } if f.SFTPConfig.PrivateKey.IsRedacted() { return true } return f.SFTPConfig.KeyPassphrase.IsRedacted() case sdk.HTTPFilesystemProvider: if f.HTTPConfig.Password.IsRedacted() { return true } return f.HTTPConfig.APIKey.IsRedacted() } return false } // HideConfidentialData hides filesystem confidential data func (f *Filesystem) HideConfidentialData() { switch f.Provider { case sdk.S3FilesystemProvider: f.S3Config.HideConfidentialData() case sdk.GCSFilesystemProvider: f.GCSConfig.HideConfidentialData() case sdk.AzureBlobFilesystemProvider: f.AzBlobConfig.HideConfidentialData() case sdk.CryptedFilesystemProvider: f.CryptConfig.HideConfidentialData() case sdk.SFTPFilesystemProvider: f.SFTPConfig.HideConfidentialData() case sdk.HTTPFilesystemProvider: f.HTTPConfig.HideConfidentialData() } } // GetACopy returns a filesystem copy func (f *Filesystem) GetACopy() Filesystem { f.SetEmptySecretsIfNil() fs := Filesystem{ Provider: f.Provider, OSConfig: sdk.OSFsConfig{ ReadBufferSize: f.OSConfig.ReadBufferSize, WriteBufferSize: f.OSConfig.WriteBufferSize, }, S3Config: S3FsConfig{ BaseS3FsConfig: sdk.BaseS3FsConfig{ Bucket: f.S3Config.Bucket, Region: f.S3Config.Region, AccessKey: f.S3Config.AccessKey, RoleARN: f.S3Config.RoleARN, Endpoint: f.S3Config.Endpoint, StorageClass: f.S3Config.StorageClass, ACL: f.S3Config.ACL, KeyPrefix: f.S3Config.KeyPrefix, UploadPartSize: f.S3Config.UploadPartSize, UploadConcurrency: f.S3Config.UploadConcurrency, DownloadPartSize: f.S3Config.DownloadPartSize, DownloadConcurrency: f.S3Config.DownloadConcurrency, DownloadPartMaxTime: f.S3Config.DownloadPartMaxTime, UploadPartMaxTime: f.S3Config.UploadPartMaxTime, ForcePathStyle: f.S3Config.ForcePathStyle, SkipTLSVerify: f.S3Config.SkipTLSVerify, }, AccessSecret: f.S3Config.AccessSecret.Clone(), 
SSECustomerKey: f.S3Config.SSECustomerKey.Clone(), }, GCSConfig: GCSFsConfig{ BaseGCSFsConfig: sdk.BaseGCSFsConfig{ Bucket: f.GCSConfig.Bucket, AutomaticCredentials: f.GCSConfig.AutomaticCredentials, StorageClass: f.GCSConfig.StorageClass, ACL: f.GCSConfig.ACL, KeyPrefix: f.GCSConfig.KeyPrefix, UploadPartSize: f.GCSConfig.UploadPartSize, UploadPartMaxTime: f.GCSConfig.UploadPartMaxTime, }, Credentials: f.GCSConfig.Credentials.Clone(), }, AzBlobConfig: AzBlobFsConfig{ BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ Container: f.AzBlobConfig.Container, AccountName: f.AzBlobConfig.AccountName, Endpoint: f.AzBlobConfig.Endpoint, KeyPrefix: f.AzBlobConfig.KeyPrefix, UploadPartSize: f.AzBlobConfig.UploadPartSize, UploadConcurrency: f.AzBlobConfig.UploadConcurrency, DownloadPartSize: f.AzBlobConfig.DownloadPartSize, DownloadConcurrency: f.AzBlobConfig.DownloadConcurrency, UseEmulator: f.AzBlobConfig.UseEmulator, AccessTier: f.AzBlobConfig.AccessTier, }, AccountKey: f.AzBlobConfig.AccountKey.Clone(), SASURL: f.AzBlobConfig.SASURL.Clone(), }, CryptConfig: CryptFsConfig{ OSFsConfig: sdk.OSFsConfig{ ReadBufferSize: f.CryptConfig.ReadBufferSize, WriteBufferSize: f.CryptConfig.WriteBufferSize, }, Passphrase: f.CryptConfig.Passphrase.Clone(), }, SFTPConfig: SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: f.SFTPConfig.Endpoint, Username: f.SFTPConfig.Username, Prefix: f.SFTPConfig.Prefix, DisableCouncurrentReads: f.SFTPConfig.DisableCouncurrentReads, BufferSize: f.SFTPConfig.BufferSize, EqualityCheckMode: f.SFTPConfig.EqualityCheckMode, }, Password: f.SFTPConfig.Password.Clone(), PrivateKey: f.SFTPConfig.PrivateKey.Clone(), KeyPassphrase: f.SFTPConfig.KeyPassphrase.Clone(), }, HTTPConfig: HTTPFsConfig{ BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ Endpoint: f.HTTPConfig.Endpoint, Username: f.HTTPConfig.Username, SkipTLSVerify: f.HTTPConfig.SkipTLSVerify, EqualityCheckMode: f.HTTPConfig.EqualityCheckMode, }, Password: f.HTTPConfig.Password.Clone(), APIKey: 
f.HTTPConfig.APIKey.Clone(), }, } if len(f.SFTPConfig.Fingerprints) > 0 { fs.SFTPConfig.Fingerprints = make([]string, len(f.SFTPConfig.Fingerprints)) copy(fs.SFTPConfig.Fingerprints, f.SFTPConfig.Fingerprints) } return fs } ================================================ FILE: internal/vfs/folder.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package vfs import ( "errors" "fmt" "strings" "github.com/rs/xid" "github.com/sftpgo/sdk" ) // BaseVirtualFolder defines the path for the virtual folder and the used quota limits. // The same folder can be shared among multiple users and each user can have different // quota limits or a different virtual path. 
type BaseVirtualFolder struct { ID int64 `json:"id"` Name string `json:"name"` MappedPath string `json:"mapped_path,omitempty"` Description string `json:"description,omitempty"` UsedQuotaSize int64 `json:"used_quota_size"` // Used quota as number of files UsedQuotaFiles int `json:"used_quota_files"` // Last quota update as unix timestamp in milliseconds LastQuotaUpdate int64 `json:"last_quota_update"` // list of usernames associated with this virtual folder Users []string `json:"users,omitempty"` // list of group names associated with this virtual folder Groups []string `json:"groups,omitempty"` // Filesystem configuration details FsConfig Filesystem `json:"filesystem"` } // GetEncryptionAdditionalData returns the additional data to use for AEAD func (v *BaseVirtualFolder) GetEncryptionAdditionalData() string { return fmt.Sprintf("folder_%v", v.Name) } // GetACopy returns a copy func (v *BaseVirtualFolder) GetACopy() BaseVirtualFolder { users := make([]string, len(v.Users)) copy(users, v.Users) groups := make([]string, len(v.Groups)) copy(groups, v.Groups) return BaseVirtualFolder{ ID: v.ID, Name: v.Name, Description: v.Description, MappedPath: v.MappedPath, UsedQuotaSize: v.UsedQuotaSize, UsedQuotaFiles: v.UsedQuotaFiles, LastQuotaUpdate: v.LastQuotaUpdate, Users: users, Groups: v.Groups, FsConfig: v.FsConfig.GetACopy(), } } // IsLocalOrLocalCrypted returns true if the folder provider is local or local encrypted func (v *BaseVirtualFolder) IsLocalOrLocalCrypted() bool { return v.FsConfig.Provider == sdk.LocalFilesystemProvider || v.FsConfig.Provider == sdk.CryptedFilesystemProvider } // hideConfidentialData hides folder confidential data func (v *BaseVirtualFolder) hideConfidentialData() { switch v.FsConfig.Provider { case sdk.S3FilesystemProvider: v.FsConfig.S3Config.HideConfidentialData() case sdk.GCSFilesystemProvider: v.FsConfig.GCSConfig.HideConfidentialData() case sdk.AzureBlobFilesystemProvider: v.FsConfig.AzBlobConfig.HideConfidentialData() case 
sdk.CryptedFilesystemProvider: v.FsConfig.CryptConfig.HideConfidentialData() case sdk.SFTPFilesystemProvider: v.FsConfig.SFTPConfig.HideConfidentialData() case sdk.HTTPFilesystemProvider: v.FsConfig.HTTPConfig.HideConfidentialData() } } // PrepareForRendering prepares a folder for rendering. // It hides confidential data and set to nil the empty secrets // so they are not serialized func (v *BaseVirtualFolder) PrepareForRendering() { v.hideConfidentialData() v.FsConfig.SetEmptySecretsIfNil() } // HasRedactedSecret returns true if the folder has a redacted secret func (v *BaseVirtualFolder) HasRedactedSecret() bool { return v.FsConfig.HasRedactedSecret() } // hasPathPlaceholder returns true if the folder has a path placeholder func (v *BaseVirtualFolder) hasPathPlaceholder() bool { placeholders := []string{"%username%", "%role%"} var config string switch v.FsConfig.Provider { case sdk.S3FilesystemProvider: config = v.FsConfig.S3Config.KeyPrefix case sdk.GCSFilesystemProvider: config = v.FsConfig.GCSConfig.KeyPrefix case sdk.AzureBlobFilesystemProvider: config = v.FsConfig.AzBlobConfig.KeyPrefix case sdk.SFTPFilesystemProvider: config = v.FsConfig.SFTPConfig.Prefix case sdk.LocalFilesystemProvider, sdk.CryptedFilesystemProvider: config = v.MappedPath } for _, placeholder := range placeholders { if strings.Contains(config, placeholder) { return true } } return false } // VirtualFolder defines a mapping between an SFTPGo virtual path and a // filesystem path outside the user home directory. // The specified paths must be absolute and the virtual path cannot be "/", // it must be a sub directory. The parent directory for the specified virtual // path must exist. SFTPGo will try to automatically create any missing // parent directory for the configured virtual folders at user login. type VirtualFolder struct { BaseVirtualFolder VirtualPath string `json:"virtual_path"` // Maximum size allowed as bytes. 
0 means unlimited, -1 included in user quota QuotaSize int64 `json:"quota_size"` // Maximum number of files allowed. 0 means unlimited, -1 included in user quota QuotaFiles int `json:"quota_files"` } // GetFilesystem returns the filesystem for this folder func (v *VirtualFolder) GetFilesystem(connectionID string, forbiddenSelfUsers []string) (Fs, error) { switch v.FsConfig.Provider { case sdk.S3FilesystemProvider: return NewS3Fs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.S3Config) case sdk.GCSFilesystemProvider: return NewGCSFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.GCSConfig) case sdk.AzureBlobFilesystemProvider: return NewAzBlobFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.AzBlobConfig) case sdk.CryptedFilesystemProvider: return NewCryptFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.CryptConfig) case sdk.SFTPFilesystemProvider: return NewSFTPFs(connectionID, v.VirtualPath, v.MappedPath, forbiddenSelfUsers, v.FsConfig.SFTPConfig) case sdk.HTTPFilesystemProvider: return NewHTTPFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.HTTPConfig) default: return NewOsFs(connectionID, v.MappedPath, v.VirtualPath, &v.FsConfig.OSConfig), nil } } // ScanQuota scans the folder and returns the number of files and their size func (v *VirtualFolder) ScanQuota() (int, int64, error) { if v.hasPathPlaceholder() { return 0, 0, errors.New("cannot scan quota: this folder has a path placeholder") } fs, err := v.GetFilesystem(xid.New().String(), nil) if err != nil { return 0, 0, err } defer fs.Close() return fs.ScanRootDirContents() } // IsIncludedInUserQuota returns true if the virtual folder is included in user quota func (v *VirtualFolder) IsIncludedInUserQuota() bool { return v.QuotaFiles == -1 && v.QuotaSize == -1 } // HasNoQuotaRestrictions returns true if no quota restrictions need to be applyed func (v *VirtualFolder) HasNoQuotaRestrictions(checkFiles bool) bool { if v.QuotaSize == 0 && (!checkFiles || v.QuotaFiles == 0) { 
return true } return false } // GetACopy returns a copy func (v *VirtualFolder) GetACopy() VirtualFolder { return VirtualFolder{ BaseVirtualFolder: v.BaseVirtualFolder.GetACopy(), VirtualPath: v.VirtualPath, QuotaSize: v.QuotaSize, QuotaFiles: v.QuotaFiles, } } ================================================ FILE: internal/vfs/gcsfs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !nogcs package vfs import ( "context" "errors" "fmt" "io" "mime" "net/http" "os" "path" "path/filepath" "strconv" "strings" "time" "cloud.google.com/go/storage" "github.com/pkg/sftp" "github.com/rs/xid" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" "google.golang.org/api/option" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( defaultGCSPageSize = 5000 ) var ( gcsDefaultFieldsSelection = []string{"Name", "Size", "Deleted", "Updated", "ContentType", "Metadata"} ) // GCSFs is a Fs implementation for Google Cloud Storage. 
type GCSFs struct { connectionID string localTempDir string // if not empty this fs is mouted as virtual folder in the specified path mountPath string config *GCSFsConfig svc *storage.Client ctxTimeout time.Duration ctxLongTimeout time.Duration } func init() { version.AddFeature("+gcs") } // NewGCSFs returns an GCSFs object that allows to interact with Google Cloud Storage func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) (Fs, error) { if localTempDir == "" { localTempDir = getLocalTempDir() } var err error fs := &GCSFs{ connectionID: connectionID, localTempDir: localTempDir, mountPath: getMountPath(mountPath), config: &config, ctxTimeout: 30 * time.Second, ctxLongTimeout: 300 * time.Second, } if err = fs.config.validate(); err != nil { return fs, err } ctx := context.Background() if fs.config.AutomaticCredentials > 0 { fs.svc, err = storage.NewClient(ctx, storage.WithJSONReads(), option.WithUserAgent(version.GetVersionHash()), ) } else { err = fs.config.Credentials.TryDecrypt() if err != nil { return fs, err } fs.svc, err = storage.NewClient(ctx, storage.WithJSONReads(), option.WithUserAgent(version.GetVersionHash()), option.WithAuthCredentialsJSON(option.ServiceAccount, []byte(fs.config.Credentials.GetPayload())), ) } return fs, err } // Name returns the name for the Fs implementation func (fs *GCSFs) Name() string { return fmt.Sprintf("%s bucket %q", gcsfsName, fs.config.Bucket) } // ConnectionID returns the connection ID associated to this Fs implementation func (fs *GCSFs) ConnectionID() string { return fs.connectionID } // Stat returns a FileInfo describing the named file func (fs *GCSFs) Stat(name string) (os.FileInfo, error) { if name == "" || name == "/" || name == "." 
{ return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil } if fs.config.KeyPrefix == name+"/" { return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil } return fs.getObjectStat(name) } // Lstat returns a FileInfo describing the named file func (fs *GCSFs) Lstat(name string) (os.FileInfo, error) { return fs.Stat(name) } // Open opens the named file for reading func (fs *GCSFs) Open(name string, offset int64) (File, PipeReader, func(), error) { r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { return nil, nil, nil, err } p := NewPipeReader(r) if readMetadata > 0 { attrs, err := fs.headObject(name) if err != nil { r.Close() w.Close() return nil, nil, nil, err } p.setMetadata(attrs.Metadata) } bkt := fs.svc.Bucket(fs.config.Bucket) obj := bkt.Object(name) ctx, cancelFn := context.WithCancel(context.Background()) objectReader, err := obj.NewRangeReader(ctx, offset, -1) if err == nil && offset > 0 && objectReader.Attrs.ContentEncoding == "gzip" { err = fmt.Errorf("range request is not possible for gzip content encoding, requested offset %d", offset) objectReader.Close() } if err != nil { r.Close() w.Close() cancelFn() return nil, nil, nil, err } go func() { defer cancelFn() defer objectReader.Close() n, err := io.Copy(w, objectReader) w.CloseWithError(err) //nolint:errcheck fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, n, err) metric.GCSTransferCompleted(n, 1, err) }() return nil, p, cancelFn, nil } // Create creates or opens the named file for writing func (fs *GCSFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) { if checks&CheckParentDir != 0 { _, err := fs.Stat(path.Dir(name)) if err != nil { return nil, nil, nil, err } } chunkSize := googleapi.DefaultUploadChunkSize if fs.config.UploadPartSize > 0 { chunkSize = int(fs.config.UploadPartSize) * 1024 * 1024 } r, w, err := createPipeFn(fs.localTempDir, int64(chunkSize+1024*1024)) if err != nil { return nil, nil, nil, err } var 
partialFileName string var attrs *storage.ObjectAttrs var statErr error bkt := fs.svc.Bucket(fs.config.Bucket) obj := bkt.Object(name) if flag == -1 { obj = obj.If(storage.Conditions{DoesNotExist: true}) } else { attrs, statErr = fs.headObject(name) if statErr == nil { obj = obj.If(storage.Conditions{GenerationMatch: attrs.Generation}) } else if fs.IsNotExist(statErr) { obj = obj.If(storage.Conditions{DoesNotExist: true}) } else { fsLog(fs, logger.LevelWarn, "unable to set precondition for %q, stat err: %v", name, statErr) } } ctx, cancelFn := context.WithCancel(context.Background()) var p PipeWriter var objectWriter *storage.Writer if checks&CheckResume != 0 { if statErr != nil { cancelFn() r.Close() w.Close() return nil, nil, nil, fmt.Errorf("unable to resume %q stat error: %w", name, statErr) } p = newPipeWriterAtOffset(w, attrs.Size) partialFileName = fs.getTempObject(name) partialObj := bkt.Object(partialFileName) partialObj = partialObj.If(storage.Conditions{DoesNotExist: true}) objectWriter = partialObj.NewWriter(ctx) } else { p = NewPipeWriter(w) objectWriter = obj.NewWriter(ctx) } objectWriter.ChunkSize = chunkSize if fs.config.UploadPartMaxTime > 0 { objectWriter.ChunkRetryDeadline = time.Duration(fs.config.UploadPartMaxTime) * time.Second } fs.setWriterAttrs(objectWriter, flag, name) go func() { defer cancelFn() n, err := io.Copy(objectWriter, r) closeErr := objectWriter.Close() if err == nil { err = closeErr } if err == nil && partialFileName != "" { partialObject := bkt.Object(partialFileName) partialObject = partialObject.If(storage.Conditions{GenerationMatch: objectWriter.Attrs().Generation}) err = fs.composeObjects(ctx, obj, partialObject) } r.CloseWithError(err) //nolint:errcheck p.Done(err) fsLog(fs, logger.LevelDebug, "upload completed, path: %q, acl: %q, readed bytes: %v, err: %+v", name, fs.config.ACL, n, err) metric.GCSTransferCompleted(n, 0, err) }() if uploadMode&8 != 0 { return nil, p, nil, nil } return nil, p, cancelFn, nil } // Rename 
renames (moves) source to target. func (fs *GCSFs) Rename(source, target string, checks int) (int, int64, error) { if source == target { return -1, -1, nil } if checks&CheckParentDir != 0 { _, err := fs.Stat(path.Dir(target)) if err != nil { return -1, -1, err } } fi, err := fs.getObjectStat(source) if err != nil { return -1, -1, err } return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0) } // Remove removes the named file or (empty) directory. func (fs *GCSFs) Remove(name string, isDir bool) error { if isDir { hasContents, err := fs.hasContents(name) if err != nil { return err } if hasContents { return fmt.Errorf("cannot remove non empty directory: %q", name) } if !strings.HasSuffix(name, "/") { name += "/" } } obj := fs.svc.Bucket(fs.config.Bucket).Object(name) attrs, statErr := fs.headObject(name) if statErr == nil { obj = obj.If(storage.Conditions{GenerationMatch: attrs.Generation}) } else { fsLog(fs, logger.LevelWarn, "unable to set precondition for deleting %q, stat err: %v", name, statErr) } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() err := obj.Delete(ctx) if isDir && fs.IsNotExist(err) { // we can have directories without a trailing "/" (created using v2.1.0 and before) ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() err = fs.svc.Bucket(fs.config.Bucket).Object(strings.TrimSuffix(name, "/")).Delete(ctx) } metric.GCSDeleteObjectCompleted(err) return err } // Mkdir creates a new directory with the specified name and default permissions func (fs *GCSFs) Mkdir(name string) error { _, err := fs.Stat(name) if !fs.IsNotExist(err) { return err } return fs.mkdirInternal(name) } // Symlink creates source as a symbolic link to target. 
func (*GCSFs) Symlink(_, _ string) error {
	return ErrVfsUnsupported
}

// Readlink returns the destination of the named symbolic link.
// Symbolic links do not exist on GCS, so this always fails with ErrVfsUnsupported.
func (*GCSFs) Readlink(_ string) (string, error) {
	return "", ErrVfsUnsupported
}

// Chown changes the numeric uid and gid of the named file.
// Ownership is not supported on GCS: ErrVfsUnsupported is always returned.
func (*GCSFs) Chown(_ string, _ int, _ int) error {
	return ErrVfsUnsupported
}

// Chmod changes the mode of the named file to mode.
// Permissions are not supported on GCS: ErrVfsUnsupported is always returned.
func (*GCSFs) Chmod(_ string, _ os.FileMode) error {
	return ErrVfsUnsupported
}

// Chtimes changes the access and modification times of the named file.
// Only mtime is stored, in milliseconds, inside the object metadata under
// lastModifiedField; the update is conditioned on the current metageneration.
// While uploading the call is a no-op.
func (fs *GCSFs) Chtimes(name string, _, mtime time.Time, isUploading bool) error {
	if isUploading {
		return nil
	}
	attrs, err := fs.headObject(name)
	if err != nil {
		return err
	}
	object := fs.svc.Bucket(fs.config.Bucket).Object(name).
		If(storage.Conditions{MetagenerationMatch: attrs.Metageneration})

	// keep the existing metadata and only overwrite the modification time
	metadata := attrs.Metadata
	if metadata == nil {
		metadata = make(map[string]string)
	}
	metadata[lastModifiedField] = strconv.FormatInt(mtime.UnixMilli(), 10)

	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancel()

	_, err = object.Update(ctx, storage.ObjectAttrsToUpdate{Metadata: metadata})
	return err
}

// Truncate changes the size of the named file.
// Truncate by path is not supported, while truncating an opened
// file is handled inside base transfer
func (*GCSFs) Truncate(_ string, _ int64) error {
	return ErrVfsUnsupported
}

// ReadDir reads the directory named by dirname and returns
// a list of directory entries.
func (fs *GCSFs) ReadDir(dirname string) (DirLister, error) { // dirname must be already cleaned prefix := fs.getPrefix(dirname) query := &storage.Query{Prefix: prefix, Delimiter: "/"} err := query.SetAttrSelection(gcsDefaultFieldsSelection) if err != nil { return nil, err } bkt := fs.svc.Bucket(fs.config.Bucket) return &gcsDirLister{ bucket: bkt, query: query, timeout: fs.ctxTimeout, prefix: prefix, prefixes: make(map[string]bool), }, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. // Resuming uploads is not supported on GCS func (*GCSFs) IsUploadResumeSupported() bool { return false } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (*GCSFs) IsConditionalUploadResumeSupported(_ int64) bool { return true } // IsAtomicUploadSupported returns true if atomic upload is supported. // S3 uploads are already atomic, we don't need to upload to a temporary // file func (*GCSFs) IsAtomicUploadSupported() bool { return false } // IsNotExist returns a boolean indicating whether the error is known to // report that a file or directory does not exist func (*GCSFs) IsNotExist(err error) bool { if err == nil { return false } if errors.Is(err, storage.ErrObjectNotExist) { return true } var apiErr *googleapi.Error if errors.As(err, &apiErr) { if apiErr.Code == http.StatusNotFound { return true } } return false } // IsPermission returns a boolean indicating whether the error is known to // report that permission is denied. 
func (*GCSFs) IsPermission(err error) bool { if err == nil { return false } var apiErr *googleapi.Error if errors.As(err, &apiErr) { if apiErr.Code == http.StatusForbidden || apiErr.Code == http.StatusUnauthorized { return true } } return false } // IsNotSupported returns true if the error indicate an unsupported operation func (*GCSFs) IsNotSupported(err error) bool { if err == nil { return false } return errors.Is(err, ErrVfsUnsupported) } // CheckRootPath creates the specified local root directory if it does not exists func (fs *GCSFs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) return osFs.CheckRootPath(username, uid, gid) } // ScanRootDirContents returns the number of files contained in the bucket, // and their size func (fs *GCSFs) ScanRootDirContents() (int, int64, error) { return fs.GetDirSize(fs.config.KeyPrefix) } // GetDirSize returns the number of files and the size for a folder // including any subfolders func (fs *GCSFs) GetDirSize(dirname string) (int, int64, error) { prefix := fs.getPrefix(dirname) numFiles := 0 size := int64(0) query := &storage.Query{Prefix: prefix} err := query.SetAttrSelection(gcsDefaultFieldsSelection) if err != nil { return numFiles, size, err } iteratePage := func(nextPageToken string) (string, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() bkt := fs.svc.Bucket(fs.config.Bucket) it := bkt.Objects(ctx, query) pager := iterator.NewPager(it, defaultGCSPageSize, nextPageToken) var objects []*storage.ObjectAttrs pageToken, err := pager.NextPage(&objects) if err != nil { return pageToken, err } for _, attrs := range objects { if !attrs.Deleted.IsZero() { continue } isDir := strings.HasSuffix(attrs.Name, "/") || attrs.ContentType == dirMimeType if isDir && attrs.Size == 0 { continue } numFiles++ size += attrs.Size } return pageToken, nil } 
pageToken := "" for { pageToken, err = iteratePage(pageToken) if err != nil { metric.GCSListObjectsCompleted(err) return numFiles, size, err } fsLog(fs, logger.LevelDebug, "scan in progress for %q, files: %d, size: %d", dirname, numFiles, size) if pageToken == "" { break } } metric.GCSListObjectsCompleted(nil) return numFiles, size, err } // GetAtomicUploadPath returns the path to use for an atomic upload. // GCS uploads are already atomic, we never call this method for GCS func (*GCSFs) GetAtomicUploadPath(_ string) string { return "" } // GetRelativePath returns the path for a file relative to the user's home dir. // This is the path as seen by SFTPGo users func (fs *GCSFs) GetRelativePath(name string) string { rel := path.Clean(name) if rel == "." { rel = "" } if !path.IsAbs(rel) { rel = "/" + rel } if fs.config.KeyPrefix != "" { if !strings.HasPrefix(rel, "/"+fs.config.KeyPrefix) { rel = "/" } rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) } if fs.mountPath != "" { rel = path.Join(fs.mountPath, rel) } return rel } // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root func (fs *GCSFs) Walk(root string, walkFn filepath.WalkFunc) error { prefix := fs.getPrefix(root) query := &storage.Query{Prefix: prefix} err := query.SetAttrSelection(gcsDefaultFieldsSelection) if err != nil { walkFn(root, nil, err) //nolint:errcheck return err } iteratePage := func(nextPageToken string) (string, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() bkt := fs.svc.Bucket(fs.config.Bucket) it := bkt.Objects(ctx, query) pager := iterator.NewPager(it, defaultGCSPageSize, nextPageToken) var objects []*storage.ObjectAttrs pageToken, err := pager.NextPage(&objects) if err != nil { walkFn(root, nil, err) //nolint:errcheck return pageToken, err } for _, attrs := range objects { if !attrs.Deleted.IsZero() { continue } name, isDir := 
fs.resolve(attrs.Name, prefix, attrs.ContentType) if name == "" { continue } objectModTime := attrs.Updated if val := getLastModified(attrs.Metadata); val > 0 { objectModTime = util.GetTimeFromMsecSinceEpoch(val) } err = walkFn(attrs.Name, NewFileInfo(name, isDir, attrs.Size, objectModTime, false), nil) if err != nil { return pageToken, err } } return pageToken, nil } pageToken := "" for { pageToken, err = iteratePage(pageToken) if err != nil { metric.GCSListObjectsCompleted(err) return err } if pageToken == "" { break } } walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), err) //nolint:errcheck metric.GCSListObjectsCompleted(err) return err } // Join joins any number of path elements into a single path func (*GCSFs) Join(elem ...string) string { return strings.TrimPrefix(path.Join(elem...), "/") } // HasVirtualFolders returns true if folders are emulated func (*GCSFs) HasVirtualFolders() bool { return true } // ResolvePath returns the matching filesystem path for the specified virtual path func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) { if fs.mountPath != "" { if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { virtualPath = after } } virtualPath = path.Clean("/" + virtualPath) return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil } // CopyFile implements the FsFileCopier interface func (fs *GCSFs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) { numFiles := 1 sizeDiff := srcInfo.Size() var conditions *storage.Conditions attrs, err := fs.headObject(target) if err == nil { sizeDiff -= attrs.Size numFiles = 0 conditions = &storage.Conditions{GenerationMatch: attrs.Generation} } else { if !fs.IsNotExist(err) { return 0, 0, err } conditions = &storage.Conditions{DoesNotExist: true} } if err := fs.copyFileInternal(source, target, conditions, srcInfo, true); err != nil { return 0, 0, err } return numFiles, sizeDiff, nil } func (fs *GCSFs) resolve(name, prefix, 
contentType string) (string, bool) { result := strings.TrimPrefix(name, prefix) isDir := strings.HasSuffix(result, "/") if isDir { result = strings.TrimSuffix(result, "/") } if contentType == dirMimeType { isDir = true } return result, isDir } // getObjectStat returns the stat result func (fs *GCSFs) getObjectStat(name string) (os.FileInfo, error) { attrs, err := fs.headObject(name) if err == nil { objSize := attrs.Size objectModTime := attrs.Updated if val := getLastModified(attrs.Metadata); val > 0 { objectModTime = util.GetTimeFromMsecSinceEpoch(val) } isDir := attrs.ContentType == dirMimeType || strings.HasSuffix(attrs.Name, "/") info := NewFileInfo(name, isDir, objSize, objectModTime, false) if !isDir { info.setMetadata(attrs.Metadata) } return info, nil } if !fs.IsNotExist(err) { return nil, err } // now check if this is a prefix (virtual directory) hasContents, err := fs.hasContents(name) if err != nil { return nil, err } if hasContents { return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil } // finally check if this is an object with a trailing / attrs, err = fs.headObject(name + "/") if err != nil { return nil, err } objectModTime := attrs.Updated if val := getLastModified(attrs.Metadata); val > 0 { objectModTime = util.GetTimeFromMsecSinceEpoch(val) } return NewFileInfo(name, true, attrs.Size, objectModTime, false), nil } func (fs *GCSFs) setWriterAttrs(objectWriter *storage.Writer, flag int, name string) { var contentType string if flag == -1 { contentType = dirMimeType } else { contentType = mime.TypeByExtension(path.Ext(name)) } if contentType != "" { objectWriter.ContentType = contentType } if fs.config.StorageClass != "" { objectWriter.StorageClass = fs.config.StorageClass } if fs.config.ACL != "" { objectWriter.PredefinedACL = fs.config.ACL } } func (fs *GCSFs) composeObjects(ctx context.Context, dst, partialObject *storage.ObjectHandle) error { fsLog(fs, logger.LevelDebug, "start object compose for partial file %q, destination %q", 
partialObject.ObjectName(), dst.ObjectName()) composer := dst.ComposerFrom(dst, partialObject) if fs.config.StorageClass != "" { composer.StorageClass = fs.config.StorageClass } if fs.config.ACL != "" { composer.PredefinedACL = fs.config.ACL } contentType := mime.TypeByExtension(path.Ext(dst.ObjectName())) if contentType != "" { composer.ContentType = contentType } _, err := composer.Run(ctx) fsLog(fs, logger.LevelDebug, "object compose for %q finished, err: %v", dst.ObjectName(), err) delCtx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() errDelete := partialObject.Delete(delCtx) metric.GCSDeleteObjectCompleted(errDelete) fsLog(fs, logger.LevelDebug, "deleted partial file %q after composing with %q, err: %v", partialObject.ObjectName(), dst.ObjectName(), errDelete) return err } func (fs *GCSFs) copyFileInternal(source, target string, conditions *storage.Conditions, srcInfo os.FileInfo, updateModTime bool, ) error { src := fs.svc.Bucket(fs.config.Bucket).Object(source) dst := fs.svc.Bucket(fs.config.Bucket).Object(target) if conditions != nil { dst = dst.If(*conditions) } else { attrs, err := fs.headObject(target) if err == nil { dst = dst.If(storage.Conditions{GenerationMatch: attrs.Generation}) } else if fs.IsNotExist(err) { dst = dst.If(storage.Conditions{DoesNotExist: true}) } else { fsLog(fs, logger.LevelWarn, "unable to set precondition for copy, target %q, stat err: %v", target, err) } } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout)) defer cancelFn() copier := dst.CopierFrom(src) if fs.config.StorageClass != "" { copier.StorageClass = fs.config.StorageClass } if fs.config.ACL != "" { copier.PredefinedACL = fs.config.ACL } contentType := mime.TypeByExtension(path.Ext(source)) if contentType != "" { copier.ContentType = contentType } metadata := getMetadata(srcInfo) if updateModTime && len(metadata) > 0 { delete(metadata, lastModifiedField) } if 
len(metadata) > 0 { copier.Metadata = metadata } _, err := copier.Run(ctx) metric.GCSCopyObjectCompleted(err) return err } func (fs *GCSFs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int, updateModTime bool, ) (int, int64, error) { var numFiles int var filesSize int64 if srcInfo.IsDir() { if renameMode == 0 { hasContents, err := fs.hasContents(source) if err != nil { return numFiles, filesSize, err } if hasContents { return numFiles, filesSize, fmt.Errorf("%w: cannot rename non empty directory: %q", ErrVfsUnsupported, source) } } if err := fs.mkdirInternal(target); err != nil { return numFiles, filesSize, err } if renameMode == 1 { files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime) numFiles += files filesSize += size if err != nil { return numFiles, filesSize, err } } } else { if err := fs.copyFileInternal(source, target, nil, srcInfo, updateModTime); err != nil { return numFiles, filesSize, err } numFiles++ filesSize += srcInfo.Size() } err := fs.Remove(source, srcInfo.IsDir()) if fs.IsNotExist(err) { err = nil } return numFiles, filesSize, err } func (fs *GCSFs) mkdirInternal(name string) error { if !strings.HasSuffix(name, "/") { name += "/" } _, w, _, err := fs.Create(name, -1, 0) if err != nil { return err } return w.Close() } func (fs *GCSFs) hasContents(name string) (bool, error) { result := false prefix := fs.getPrefix(name) query := &storage.Query{Prefix: prefix} err := query.SetAttrSelection(gcsDefaultFieldsSelection) if err != nil { return result, err } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() bkt := fs.svc.Bucket(fs.config.Bucket) it := bkt.Objects(ctx, query) // if we have a dir object with a trailing slash it will be returned so we set the size to 2 pager := iterator.NewPager(it, 2, "") var objects []*storage.ObjectAttrs _, err = pager.NextPage(&objects) if err != nil { metric.GCSListObjectsCompleted(err) 
return result, err } for _, attrs := range objects { name, _ := fs.resolve(attrs.Name, prefix, attrs.ContentType) // a dir object with a trailing slash will result in an empty name if name == "/" || name == "" { continue } result = true break } metric.GCSListObjectsCompleted(nil) return result, nil } func (fs *GCSFs) getPrefix(name string) string { prefix := "" if name != "" && name != "." && name != "/" { prefix = strings.TrimPrefix(name, "/") if !strings.HasSuffix(prefix, "/") { prefix += "/" } } return prefix } func (fs *GCSFs) headObject(name string) (*storage.ObjectAttrs, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() bkt := fs.svc.Bucket(fs.config.Bucket) obj := bkt.Object(name) attrs, err := obj.Attrs(ctx) metric.GCSHeadObjectCompleted(err) return attrs, err } // GetMimeType returns the content type func (fs *GCSFs) GetMimeType(name string) (string, error) { attrs, err := fs.headObject(name) if err != nil { return "", err } return attrs.ContentType, nil } // Close closes the fs func (fs *GCSFs) Close() error { return nil } // GetAvailableDiskSize returns the available size for the specified path func (*GCSFs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) { return nil, ErrStorageSizeUnavailable } func (*GCSFs) getTempObject(name string) string { dir := filepath.Dir(name) guid := xid.New().String() return filepath.Join(dir, ".sftpgo-partial."+guid+"."+filepath.Base(name)) } type gcsDirLister struct { baseDirLister bucket *storage.BucketHandle query *storage.Query timeout time.Duration nextPageToken string noMorePages bool prefix string prefixes map[string]bool metricUpdated bool } func (l *gcsDirLister) resolve(name, contentType string) (string, bool) { result := strings.TrimPrefix(name, l.prefix) isDir := strings.HasSuffix(result, "/") if isDir { result = strings.TrimSuffix(result, "/") } if contentType == dirMimeType { isDir = true } return result, isDir } func (l *gcsDirLister) 
Next(limit int) ([]os.FileInfo, error) { if limit <= 0 { return nil, errInvalidDirListerLimit } if len(l.cache) >= limit { return l.returnFromCache(limit), nil } if l.noMorePages { if !l.metricUpdated { l.metricUpdated = true metric.GCSListObjectsCompleted(nil) } return l.returnFromCache(limit), io.EOF } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) defer cancelFn() it := l.bucket.Objects(ctx, l.query) paginator := iterator.NewPager(it, defaultGCSPageSize, l.nextPageToken) var objects []*storage.ObjectAttrs pageToken, err := paginator.NextPage(&objects) if err != nil { metric.GCSListObjectsCompleted(err) return l.cache, err } for _, attrs := range objects { if attrs.Prefix != "" { name, _ := l.resolve(attrs.Prefix, attrs.ContentType) if name == "" { continue } if _, ok := l.prefixes[name]; ok { continue } l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) l.prefixes[name] = true } else { name, isDir := l.resolve(attrs.Name, attrs.ContentType) if name == "" { continue } if !attrs.Deleted.IsZero() { continue } if isDir { // check if the dir is already included, it will be sent as blob prefix if it contains at least one item if _, ok := l.prefixes[name]; ok { continue } l.prefixes[name] = true } modTime := attrs.Updated if val := getLastModified(attrs.Metadata); val > 0 { modTime = util.GetTimeFromMsecSinceEpoch(val) } info := NewFileInfo(name, isDir, attrs.Size, modTime, false) info.setMetadata(attrs.Metadata) l.cache = append(l.cache, info) } } l.nextPageToken = pageToken l.noMorePages = (l.nextPageToken == "") return l.returnFromCache(limit), nil } func (l *gcsDirLister) Close() error { clear(l.prefixes) return l.baseDirLister.Close() } ================================================ FILE: internal/vfs/gcsfs_disabled.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the 
terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//go:build nogcs

package vfs

import (
	"errors"

	"github.com/drakkan/sftpgo/v2/internal/version"
)

// init registers the "-gcs" feature marker. This file is only compiled with the
// nogcs build tag, so the marker presumably advertises that GCS support was
// left out of this build.
func init() {
	version.AddFeature("-gcs")
}

// NewGCSFs returns an error, GCS is disabled.
// This stub keeps the constructor's signature available in nogcs builds. It
// ignores all of its arguments and never returns a usable Fs.
func NewGCSFs(_, _, _ string, _ GCSFsConfig) (Fs, error) {
	return nil, errors.New("Google Cloud Storage disabled at build time")
}
================================================
FILE: internal/vfs/httpfs.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package vfs

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"mime"
	"net"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"strings"
	"time"

	"github.com/pkg/sftp"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

const (
	// httpFsName is the name for the HTTP Fs implementation
	httpFsName = "httpfs"
	// maxHTTPFsResponseSize caps (in bytes, 1 MiB) how much of a JSON reply is read
	maxHTTPFsResponseSize = 1048576
)

var (
	// supportedEndpointSchema lists the accepted endpoint URL prefixes
	supportedEndpointSchema = []string{"http://", "https://"}
)

// HTTPFsConfig defines the configuration for HTTP based filesystem
type HTTPFsConfig struct {
	sdk.BaseHTTPFsConfig
	Password *kms.Secret `json:"password,omitempty"`
	APIKey   *kms.Secret `json:"api_key,omitempty"`
}

// isUnixDomainSocket reports whether the endpoint uses the special "unix" host.
// It is only a prefix check; callers must still verify that the parsed host is
// exactly "unix".
func (c *HTTPFsConfig) isUnixDomainSocket() bool {
	return strings.HasPrefix(c.Endpoint, "http://unix") || strings.HasPrefix(c.Endpoint, "https://unix")
}

// HideConfidentialData hides confidential data
func (c *HTTPFsConfig) HideConfidentialData() {
	if c.Password != nil {
		c.Password.Hide()
	}
	if c.APIKey != nil {
		c.APIKey.Hide()
	}
}

// setNilSecretsIfEmpty replaces empty secrets with nil
func (c *HTTPFsConfig) setNilSecretsIfEmpty() {
	if c.Password != nil && c.Password.IsEmpty() {
		c.Password = nil
	}
	if c.APIKey != nil && c.APIKey.IsEmpty() {
		c.APIKey = nil
	}
}

// setEmptyCredentialsIfNil replaces nil secrets with empty ones, so that
// methods can be called on them safely
func (c *HTTPFsConfig) setEmptyCredentialsIfNil() {
	if c.Password == nil {
		c.Password = kms.NewEmptySecret()
	}
	if c.APIKey == nil {
		c.APIKey = kms.NewEmptySecret()
	}
}

// isEqual compares endpoint, username, TLS verification and secrets.
// Note: as a side effect it replaces nil secrets on the receiver with empty
// ones (other is a copy, so it is not affected).
func (c *HTTPFsConfig) isEqual(other HTTPFsConfig) bool {
	if c.Endpoint != other.Endpoint {
		return false
	}
	if c.Username != other.Username {
		return false
	}
	if c.SkipTLSVerify != other.SkipTLSVerify {
		return false
	}
	c.setEmptyCredentialsIfNil()
	other.setEmptyCredentialsIfNil()
	if !c.Password.IsEqual(other.Password) {
		return false
	}
	return c.APIKey.IsEqual(other.APIKey)
}

// isSameResource reports whether both configs point to the same backend: same
// endpoint and, if either side sets EqualityCheckMode > 0, also the same username.
func (c *HTTPFsConfig) isSameResource(other HTTPFsConfig) bool {
	if c.EqualityCheckMode > 0 || other.EqualityCheckMode > 0 {
		if c.Username != other.Username {
			return false
		}
	}
	return c.Endpoint == other.Endpoint
}

// validate returns an error if the configuration is not valid.
// It also normalizes the config: nil secrets become empty and trailing "/" are
// trimmed from the endpoint.
func (c *HTTPFsConfig) validate() error {
	c.setEmptyCredentialsIfNil()
	if c.Endpoint == "" {
		return util.NewI18nError(errors.New("httpfs: endpoint cannot be empty"), util.I18nErrorEndpointRequired)
	}
	c.Endpoint = strings.TrimRight(c.Endpoint, "/")
	endpointURL, err := url.Parse(c.Endpoint)
	if err != nil {
		return util.NewI18nError(fmt.Errorf("httpfs: invalid endpoint: %w", err), util.I18nErrorEndpointInvalid)
	}
	if !util.IsStringPrefixInSlice(c.Endpoint, supportedEndpointSchema) {
		return util.NewI18nError(
			errors.New("httpfs: invalid endpoint schema: http and https are supported"),
			util.I18nErrorEndpointInvalid,
		)
	}
	// unix domain socket endpoints carry the socket path in the query string
	if endpointURL.Host == "unix" {
		socketPath := endpointURL.Query().Get("socket_path")
		if !filepath.IsAbs(socketPath) {
			return util.NewI18nError(
				fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath),
				util.I18nErrorEndpointInvalid,
			)
		}
	}
	if !isEqualityCheckModeValid(c.EqualityCheckMode) {
		return errors.New("invalid equality_check_mode")
	}
	if c.Password.IsEncrypted() && !c.Password.IsValid() {
		return errors.New("httpfs: invalid encrypted password")
	}
	if !c.Password.IsEmpty() && !c.Password.IsValidInput() {
		return errors.New("httpfs: invalid password")
	}
	if c.APIKey.IsEncrypted() && !c.APIKey.IsValid() {
		return errors.New("httpfs: invalid encrypted API key")
	}
	if !c.APIKey.IsEmpty() && !c.APIKey.IsValidInput() {
		return errors.New("httpfs: invalid API key")
	}
	return nil
}

// ValidateAndEncryptCredentials validates the config and encrypts credentials if they are in plain text.
// Validation errors are wrapped as validation errors. The i18n message of the
// underlying error is kept when it has one.
func (c *HTTPFsConfig) ValidateAndEncryptCredentials(additionalData string) error {
	err := c.validate()
	if err != nil {
		var errI18n *util.I18nError
		errValidation := util.NewValidationError(fmt.Sprintf("could not validate HTTP fs config: %v", err))
		if errors.As(err, &errI18n) {
			return util.NewI18nError(errValidation, errI18n.Message)
		}
		return util.NewI18nError(errValidation, util.I18nErrorFsValidation)
	}
	if c.Password.IsPlain() {
		c.Password.SetAdditionalData(additionalData)
		if err := c.Password.Encrypt(); err != nil {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("could not encrypt HTTP fs password: %v", err)),
				util.I18nErrorFsValidation,
			)
		}
	}
	if c.APIKey.IsPlain() {
		c.APIKey.SetAdditionalData(additionalData)
		if err := c.APIKey.Encrypt(); err != nil {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("could not encrypt HTTP fs API key: %v", err)),
				util.I18nErrorFsValidation,
			)
		}
	}
	return nil
}

// HTTPFs is a Fs implementation for the SFTPGo HTTP filesystem backend
type HTTPFs struct {
	connectionID string
	localTempDir string
	// if not empty this fs is mounted as virtual folder in the specified path
	mountPath  string
	config     *HTTPFsConfig
	client     *http.Client
	ctxTimeout time.Duration
}

// NewHTTPFs returns an HTTPFs object that allows to interact with SFTPGo HTTP filesystem backends.
//
// Secrets are decrypted up front. Requests use a 30 second timeout. For
// "unix" endpoints, connections are dialed to the socket_path from the query
// string; the endpoint is then rewritten to drop the query and append the
// optional api_prefix to the path.
func NewHTTPFs(connectionID, localTempDir, mountPath string, config HTTPFsConfig) (Fs, error) {
	if localTempDir == "" {
		localTempDir = getLocalTempDir()
	}
	config.setEmptyCredentialsIfNil()
	if !config.Password.IsEmpty() {
		if err := config.Password.TryDecrypt(); err != nil {
			return nil, err
		}
	}
	if !config.APIKey.IsEmpty() {
		if err := config.APIKey.TryDecrypt(); err != nil {
			return nil, err
		}
	}
	fs := &HTTPFs{
		connectionID: connectionID,
		localTempDir: localTempDir,
		mountPath:    mountPath,
		config:       &config,
		ctxTimeout:   30 * time.Second,
	}
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.MaxResponseHeaderBytes = 1 << 16
	transport.WriteBufferSize = 1 << 16
	transport.ReadBufferSize = 1 << 16
	if fs.config.isUnixDomainSocket() {
		endpointURL, err := url.Parse(fs.config.Endpoint)
		if err != nil {
			return nil, err
		}
		if endpointURL.Host == "unix" {
			socketPath := endpointURL.Query().Get("socket_path")
			if !filepath.IsAbs(socketPath) {
				return nil, fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
			}
			// the dialers ignore the requested network/address and always
			// connect to the configured socket
			if endpointURL.Scheme == "https" {
				transport.DialTLSContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
					var tlsConfig *tls.Config
					var d tls.Dialer
					if config.SkipTLSVerify {
						tlsConfig = getInsecureTLSConfig()
					}
					d.Config = tlsConfig
					return d.DialContext(ctx, "unix", socketPath)
				}
			} else {
				transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
					var d net.Dialer
					return d.DialContext(ctx, "unix", socketPath)
				}
			}
			endpointURL.Path = path.Join(endpointURL.Path, endpointURL.Query().Get("api_prefix"))
			endpointURL.RawQuery = ""
			endpointURL.RawFragment = ""
			fs.config.Endpoint = endpointURL.String()
		}
	}
	if config.SkipTLSVerify {
		if transport.TLSClientConfig != nil {
			transport.TLSClientConfig.InsecureSkipVerify = true
		} else {
			transport.TLSClientConfig = getInsecureTLSConfig()
		}
	}
	fs.client = &http.Client{
		Transport: transport,
	}
	return fs, nil
}

// Name returns the name for the Fs implementation
func (fs *HTTPFs) Name() string {
	return fmt.Sprintf("%v %q", httpFsName, fs.config.Endpoint)
}

// ConnectionID returns the connection ID associated to this Fs implementation
func (fs *HTTPFs) ConnectionID() string {
	return fs.connectionID
}

// Stat returns a FileInfo describing the named file.
// It queries the "stat" endpoint and decodes at most maxHTTPFsResponseSize
// bytes of the JSON reply. NOTE(review): this presumably relies on
// sendHTTPRequest turning non-success statuses into errors; confirm.
func (fs *HTTPFs) Stat(name string) (os.FileInfo, error) {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "stat", name, "", "", nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize))
	if err != nil {
		return nil, err
	}
	var response statResponse
	err = json.Unmarshal(respBody, &response)
	if err != nil {
		return nil, err
	}
	return response.getFileInfo(), nil
}

// Lstat returns a FileInfo describing the named file.
// Symbolic links are not supported, so this is the same as Stat.
func (fs *HTTPFs) Lstat(name string) (os.FileInfo, error) {
	return fs.Stat(name)
}

// Open opens the named file for reading.
// The download runs in a background goroutine that copies the response body
// into a pipe. The caller reads from the returned PipeReader, and the
// returned cancel function aborts the request. The File return value is
// always nil.
func (fs *HTTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) {
	r, w, err := createPipeFn(fs.localTempDir, 0)
	if err != nil {
		return nil, nil, nil, err
	}
	p := NewPipeReader(r)
	ctx, cancelFn := context.WithCancel(context.Background())

	var queryString string
	if offset > 0 {
		queryString = fmt.Sprintf("?offset=%d", offset)
	}

	go func() {
		defer cancelFn()

		resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "open", name, queryString, "", nil)
		if err != nil {
			fsLog(fs, logger.LevelError, "download error, path %q, err: %v", name, err)
			// propagate the failure to the reader side of the pipe
			w.CloseWithError(err) //nolint:errcheck
			metric.HTTPFsTransferCompleted(0, 1, err)
			return
		}
		defer resp.Body.Close()

		n, err := io.Copy(w, resp.Body)
		w.CloseWithError(err) //nolint:errcheck
		fsLog(fs, logger.LevelDebug, "download completed, path %q size: %v, err: %+v", name, n, err)
		metric.HTTPFsTransferCompleted(n, 1, err)
	}()

	return nil, p, cancelFn, nil
}

// Create creates or opens the named file for writing.
// The caller writes to the returned PipeWriter. A background goroutine streams
// the pipe contents as the body of a "create" request, passing flag and checks
// in the query string. The File return value is always nil.
func (fs *HTTPFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) {
	r, w, err := createPipeFn(fs.localTempDir, 0)
	if err != nil {
		return nil, nil, nil, err
	}
	p := NewPipeWriter(w)
	ctx, cancelFn := context.WithCancel(context.Background())

	go func() {
		defer cancelFn()

		contentType := mime.TypeByExtension(path.Ext(name))
		queryString := fmt.Sprintf("?flags=%d&checks=%d", flag, checks)
		resp, err := fs.sendHTTPRequest(ctx, http.MethodPost, "create", name, queryString, contentType,
			&wrapReader{reader: r})
		if err != nil {
			fsLog(fs, logger.LevelError, "upload error, path %q, err: %v", name, err)
			r.CloseWithError(err) //nolint:errcheck
			// report the final upload result to the writer side
			p.Done(err)
			metric.HTTPFsTransferCompleted(0, 0, err)
			return
		}
		defer resp.Body.Close()

		r.CloseWithError(err) //nolint:errcheck
		p.Done(err)
		fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %d", name, r.GetReadedBytes())
		metric.HTTPFsTransferCompleted(r.GetReadedBytes(), 0, err)
	}()

	return nil, p, cancelFn, nil
}

// Rename renames (moves) source to target.
func (fs *HTTPFs) Rename(source, target string, checks int) (int, int64, error) { if source == target { return -1, -1, nil } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() queryString := fmt.Sprintf("?target=%s", url.QueryEscape(target)) resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "rename", source, queryString, "", nil) if err != nil { return -1, -1, err } defer resp.Body.Close() if checks&CheckUpdateModTime != 0 { fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck } return -1, -1, nil } // Remove removes the named file or (empty) directory. func (fs *HTTPFs) Remove(name string, _ bool) error { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := fs.sendHTTPRequest(ctx, http.MethodDelete, "remove", name, "", "", nil) if err != nil { return err } defer resp.Body.Close() return nil } // Mkdir creates a new directory with the specified name and default permissions func (fs *HTTPFs) Mkdir(name string) error { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := fs.sendHTTPRequest(ctx, http.MethodPost, "mkdir", name, "", "", nil) if err != nil { return err } defer resp.Body.Close() return nil } // Symlink creates source as a symbolic link to target. func (*HTTPFs) Symlink(_, _ string) error { return ErrVfsUnsupported } // Readlink returns the destination of the named symbolic link func (*HTTPFs) Readlink(_ string) (string, error) { return "", ErrVfsUnsupported } // Chown changes the numeric uid and gid of the named file. func (fs *HTTPFs) Chown(_ string, _ int, _ int) error { return ErrVfsUnsupported } // Chmod changes the mode of the named file to mode. 
func (fs *HTTPFs) Chmod(name string, mode os.FileMode) error { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() queryString := fmt.Sprintf("?mode=%d", mode) resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "chmod", name, queryString, "", nil) if err != nil { return err } defer resp.Body.Close() return nil } // Chtimes changes the access and modification times of the named file. func (fs *HTTPFs) Chtimes(name string, atime, mtime time.Time, _ bool) error { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() queryString := fmt.Sprintf("?access_time=%s&modification_time=%s", atime.UTC().Format(time.RFC3339), mtime.UTC().Format(time.RFC3339)) resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "chtimes", name, queryString, "", nil) if err != nil { return err } defer resp.Body.Close() return nil } // Truncate changes the size of the named file. // Truncate by path is not supported, while truncating an opened // file is handled inside base transfer func (fs *HTTPFs) Truncate(name string, size int64) error { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() queryString := fmt.Sprintf("?size=%d", size) resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "truncate", name, queryString, "", nil) if err != nil { return err } defer resp.Body.Close() return nil } // ReadDir reads the directory named by dirname and returns // a list of directory entries. 
func (fs *HTTPFs) ReadDir(dirname string) (DirLister, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "readdir", dirname, "", "", nil) if err != nil { return nil, err } defer resp.Body.Close() respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize*10)) if err != nil { return nil, err } var response []statResponse err = json.Unmarshal(respBody, &response) if err != nil { return nil, err } result := make([]os.FileInfo, 0, len(response)) for _, stat := range response { result = append(result, stat.getFileInfo()) } return &baseDirLister{result}, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. func (*HTTPFs) IsUploadResumeSupported() bool { return false } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (*HTTPFs) IsConditionalUploadResumeSupported(_ int64) bool { return false } // IsAtomicUploadSupported returns true if atomic upload is supported. func (*HTTPFs) IsAtomicUploadSupported() bool { return false } // IsNotExist returns a boolean indicating whether the error is known to // report that a file or directory does not exist func (*HTTPFs) IsNotExist(err error) bool { return errors.Is(err, fs.ErrNotExist) } // IsPermission returns a boolean indicating whether the error is known to // report that permission is denied. 
func (*HTTPFs) IsPermission(err error) bool { return errors.Is(err, fs.ErrPermission) } // IsNotSupported returns true if the error indicate an unsupported operation func (*HTTPFs) IsNotSupported(err error) bool { if err == nil { return false } return err == ErrVfsUnsupported } // CheckRootPath creates the specified local root directory if it does not exists func (fs *HTTPFs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) return osFs.CheckRootPath(username, uid, gid) } // ScanRootDirContents returns the number of files and their size func (fs *HTTPFs) ScanRootDirContents() (int, int64, error) { return fs.GetDirSize("/") } // CheckMetadata checks the metadata consistency func (*HTTPFs) CheckMetadata() error { return nil } // GetDirSize returns the number of files and the size for a folder // including any subfolders func (fs *HTTPFs) GetDirSize(dirname string) (int, int64, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "dirsize", dirname, "", "", nil) if err != nil { return 0, 0, err } defer resp.Body.Close() respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) if err != nil { return 0, 0, err } var response dirSizeResponse err = json.Unmarshal(respBody, &response) if err != nil { return 0, 0, err } return response.Files, response.Size, nil } // GetAtomicUploadPath returns the path to use for an atomic upload. func (*HTTPFs) GetAtomicUploadPath(_ string) string { return "" } // GetRelativePath returns the path for a file relative to the user's home dir. // This is the path as seen by SFTPGo users func (fs *HTTPFs) GetRelativePath(name string) string { rel := path.Clean(name) if rel == "." 
{ rel = "" } if !path.IsAbs(rel) { rel = "/" + rel } if fs.mountPath != "" { rel = path.Join(fs.mountPath, rel) } return rel } // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root. The result are unordered func (fs *HTTPFs) Walk(root string, walkFn filepath.WalkFunc) error { info, err := fs.Lstat(root) if err != nil { return walkFn(root, nil, err) } return fs.walk(root, info, walkFn) } // Join joins any number of path elements into a single path func (*HTTPFs) Join(elem ...string) string { return strings.TrimPrefix(path.Join(elem...), "/") } // HasVirtualFolders returns true if folders are emulated func (*HTTPFs) HasVirtualFolders() bool { return false } // ResolvePath returns the matching filesystem path for the specified virtual path func (fs *HTTPFs) ResolvePath(virtualPath string) (string, error) { if fs.mountPath != "" { if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { virtualPath = after } } return path.Clean("/" + virtualPath), nil } // GetMimeType returns the content type func (fs *HTTPFs) GetMimeType(name string) (string, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "stat", name, "", "", nil) if err != nil { return "", err } defer resp.Body.Close() respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) if err != nil { return "", err } var response mimeTypeResponse err = json.Unmarshal(respBody, &response) if err != nil { return "", err } return response.Mime, nil } // Close closes the fs func (fs *HTTPFs) Close() error { fs.client.CloseIdleConnections() return nil } // GetAvailableDiskSize returns the available size for the specified path func (fs *HTTPFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() 
resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "statvfs", dirName, "", "", nil) if err != nil { return nil, err } defer resp.Body.Close() respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) if err != nil { return nil, err } var response statVFSResponse err = json.Unmarshal(respBody, &response) if err != nil { return nil, err } return response.toSFTPStatVFS(), nil } func (fs *HTTPFs) sendHTTPRequest(ctx context.Context, method, base, name, queryString, contentType string, body io.Reader, ) (*http.Response, error) { url := fmt.Sprintf("%s/%s/%s%s", fs.config.Endpoint, base, url.PathEscape(name), queryString) req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } if contentType != "" { req.Header.Set("Content-Type", contentType) } if fs.config.APIKey.GetPayload() != "" { req.Header.Set("X-API-KEY", fs.config.APIKey.GetPayload()) } if fs.config.Username != "" || fs.config.Password.GetPayload() != "" { req.SetBasicAuth(fs.config.Username, fs.config.Password.GetPayload()) } resp, err := fs.client.Do(req.WithContext(ctx)) if err != nil { return nil, fmt.Errorf("unable to send HTTP request to URL %v: %w", url, err) } if err = getErrorFromResponseCode(resp.StatusCode); err != nil { resp.Body.Close() return nil, err } return resp, nil } // walk recursively descends path, calling walkFn. 
func (fs *HTTPFs) walk(filePath string, info fs.FileInfo, walkFn filepath.WalkFunc) error { if !info.IsDir() { return walkFn(filePath, info, nil) } lister, err := fs.ReadDir(filePath) err1 := walkFn(filePath, info, err) if err != nil || err1 != nil { if err == nil { lister.Close() } return err1 } defer lister.Close() for { files, err := lister.Next(ListerBatchSize) finished := errors.Is(err, io.EOF) if err != nil && !finished { return err } for _, fi := range files { objName := path.Join(filePath, fi.Name()) err = fs.walk(objName, fi, walkFn) if err != nil { return err } } if finished { return nil } } } func getErrorFromResponseCode(code int) error { switch code { case 401, 403: return os.ErrPermission case 404: return os.ErrNotExist case 501: return ErrVfsUnsupported case 200, 201: return nil default: return fmt.Errorf("unexpected response code: %v", code) } } func getInsecureTLSConfig() *tls.Config { return &tls.Config{ InsecureSkipVerify: true, } } type wrapReader struct { reader io.Reader } func (r *wrapReader) Read(p []byte) (n int, err error) { return r.reader.Read(p) } type statResponse struct { Name string `json:"name"` Size int64 `json:"size"` Mode uint32 `json:"mode"` LastModified time.Time `json:"last_modified"` } func (s *statResponse) getFileInfo() os.FileInfo { info := NewFileInfo(s.Name, false, s.Size, s.LastModified, false) info.SetMode(fs.FileMode(s.Mode)) return info } type dirSizeResponse struct { Files int `json:"files"` Size int64 `json:"size"` } type mimeTypeResponse struct { Mime string `json:"mime"` } type statVFSResponse struct { ID uint32 `json:"-"` Bsize uint64 `json:"bsize"` Frsize uint64 `json:"frsize"` Blocks uint64 `json:"blocks"` Bfree uint64 `json:"bfree"` Bavail uint64 `json:"bavail"` Files uint64 `json:"files"` Ffree uint64 `json:"ffree"` Favail uint64 `json:"favail"` Fsid uint64 `json:"fsid"` Flag uint64 `json:"flag"` Namemax uint64 `json:"namemax"` } func (s *statVFSResponse) toSFTPStatVFS() *sftp.StatVFS { return &sftp.StatVFS{ 
Bsize: s.Bsize, Frsize: s.Frsize, Blocks: s.Blocks, Bfree: s.Bfree, Bavail: s.Bavail, Files: s.Files, Ffree: s.Ffree, Favail: s.Ffree, Flag: s.Flag, Namemax: s.Namemax, } } ================================================ FILE: internal/vfs/osfs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package vfs import ( "bufio" "errors" "fmt" "io" "io/fs" "net/http" "os" "path" "path/filepath" "slices" "strings" "time" fscopy "github.com/otiai10/copy" "github.com/pkg/sftp" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/logger" ) const ( // osFsName is the name for the local Fs implementation osFsName = "osfs" ) type pathResolutionError struct { err string } func (e *pathResolutionError) Error() string { return fmt.Sprintf("Path resolution error: %s", e.err) } // OsFs is a Fs implementation that uses functions provided by the os package. 
type OsFs struct { name string connectionID string rootDir string // if not empty this fs is mouted as virtual folder in the specified path mountPath string localTempDir string readBufferSize int writeBufferSize int } // NewOsFs returns an OsFs object that allows to interact with local Os filesystem func NewOsFs(connectionID, rootDir, mountPath string, config *sdk.OSFsConfig) Fs { var readBufferSize, writeBufferSize int if config != nil { readBufferSize = config.ReadBufferSize * 1024 * 1024 writeBufferSize = config.WriteBufferSize * 1024 * 1024 } return &OsFs{ name: osFsName, connectionID: connectionID, rootDir: rootDir, mountPath: getMountPath(mountPath), localTempDir: getLocalTempDir(), readBufferSize: readBufferSize, writeBufferSize: writeBufferSize, } } // Name returns the name for the Fs implementation func (fs *OsFs) Name() string { return fs.name } // ConnectionID returns the SSH connection ID associated to this Fs implementation func (fs *OsFs) ConnectionID() string { return fs.connectionID } // Stat returns a FileInfo describing the named file func (fs *OsFs) Stat(name string) (os.FileInfo, error) { return os.Stat(name) } // Lstat returns a FileInfo describing the named file func (fs *OsFs) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) } // Open opens the named file for reading func (fs *OsFs) Open(name string, offset int64) (File, PipeReader, func(), error) { f, err := os.Open(name) if err != nil { return nil, nil, nil, err } if offset > 0 { _, err = f.Seek(offset, io.SeekStart) if err != nil { f.Close() return nil, nil, nil, err } } if fs.readBufferSize <= 0 { return f, nil, nil, err } r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { f.Close() return nil, nil, nil, err } p := NewPipeReader(r) go func() { br := bufio.NewReaderSize(f, fs.readBufferSize) n, err := doCopy(w, br, nil) w.CloseWithError(err) //nolint:errcheck f.Close() fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %v", name, n, err) }() 
return nil, p, nil, nil } // Create creates or opens the named file for writing func (fs *OsFs) Create(name string, flag, _ int) (File, PipeWriter, func(), error) { if !fs.useWriteBuffering(flag) { var err error var f *os.File if flag == 0 { f, err = os.Create(name) } else { f, err = os.OpenFile(name, flag, 0666) } return f, nil, nil, err } f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) if err != nil { return nil, nil, nil, err } r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { f.Close() return nil, nil, nil, err } p := NewPipeWriter(w) go func() { bw := bufio.NewWriterSize(f, fs.writeBufferSize) n, err := doCopy(bw, r, nil) errFlush := bw.Flush() if err == nil && errFlush != nil { err = errFlush } errClose := f.Close() if err == nil && errClose != nil { err = errClose } r.CloseWithError(err) //nolint:errcheck p.Done(err) fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %v", name, n, err) }() return nil, p, nil, nil } // Rename renames (moves) source to target func (fs *OsFs) Rename(source, target string, checks int) (int, int64, error) { if source == target { return -1, -1, nil } err := os.Rename(source, target) if err != nil && isCrossDeviceError(err) { fsLog(fs, logger.LevelError, "cross device error detected while renaming %q -> %q. 
Trying a copy and remove, this could take a long time", source, target) var readBufferSize uint if fs.readBufferSize > 0 { readBufferSize = uint(fs.readBufferSize) } err = fscopy.Copy(source, target, fscopy.Options{ OnSymlink: func(_ string) fscopy.SymlinkAction { return fscopy.Skip }, CopyBufferSize: readBufferSize, }) if err != nil { fsLog(fs, logger.LevelError, "cross device copy error: %v", err) return -1, -1, err } if checks&CheckUpdateModTime != 0 { fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck } err = os.RemoveAll(source) return -1, -1, err } if checks&CheckUpdateModTime != 0 && err == nil { fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck } return -1, -1, err } // Remove removes the named file or (empty) directory. func (*OsFs) Remove(name string, _ bool) error { return os.Remove(name) } // Mkdir creates a new directory with the specified name and default permissions func (*OsFs) Mkdir(name string) error { return os.Mkdir(name, os.ModePerm) } // Symlink creates source as a symbolic link to target. func (*OsFs) Symlink(source, target string) error { return os.Symlink(source, target) } // Readlink returns the destination of the named symbolic link // as absolute virtual path func (fs *OsFs) Readlink(name string) (string, error) { // we don't have to follow multiple links: // https://github.com/openssh/openssh-portable/blob/7bf2eb958fbb551e7d61e75c176bb3200383285d/sftp-server.c#L1329 resolved, err := os.Readlink(name) if err != nil { return "", err } resolved = filepath.Clean(resolved) if !filepath.IsAbs(resolved) { resolved = filepath.Join(filepath.Dir(name), resolved) } return fs.GetRelativePath(resolved), nil } // Chown changes the numeric uid and gid of the named file. 
func (*OsFs) Chown(name string, uid int, gid int) error {
	return os.Chown(name, uid, gid)
}

// Chmod sets the mode of the named file
func (*OsFs) Chmod(name string, mode os.FileMode) error {
	return os.Chmod(name, mode)
}

// Chtimes sets the access and modification times of the named file;
// the trailing flag is ignored by the local implementation
func (*OsFs) Chtimes(name string, atime, mtime time.Time, _ bool) error {
	return os.Chtimes(name, atime, mtime)
}

// Truncate resizes the named file
func (*OsFs) Truncate(name string, size int64) error {
	return os.Truncate(name, size)
}

// ReadDir opens dirname and returns a lister over its entries. An invalid
// name error is reported as os.ErrNotExist.
func (*OsFs) ReadDir(dirname string) (DirLister, error) {
	dir, err := os.Open(dirname)
	switch {
	case err == nil:
		return &osFsDirLister{dir}, nil
	case isInvalidNameError(err):
		return nil, os.ErrNotExist
	default:
		return nil, err
	}
}

// IsUploadResumeSupported reports that the local filesystem can resume uploads
func (*OsFs) IsUploadResumeSupported() bool {
	return true
}

// IsConditionalUploadResumeSupported reports that resuming uploads is
// supported regardless of the size
func (*OsFs) IsConditionalUploadResumeSupported(_ int64) bool {
	return true
}

// IsAtomicUploadSupported reports that atomic uploads are supported
func (*OsFs) IsAtomicUploadSupported() bool {
	return true
}

// IsNotExist reports whether err (or any error it wraps) signals that a
// file or directory does not exist
func (*OsFs) IsNotExist(err error) bool {
	notExist := errors.Is(err, fs.ErrNotExist)
	return notExist
}

// IsPermission returns a boolean indicating whether the error is known to
// report that permission is denied.
func (*OsFs) IsPermission(err error) bool { if _, ok := err.(*pathResolutionError); ok { return true } return errors.Is(err, fs.ErrPermission) } // IsNotSupported returns true if the error indicate an unsupported operation func (*OsFs) IsNotSupported(err error) bool { if err == nil { return false } return err == ErrVfsUnsupported } // CheckRootPath creates the root directory if it does not exists func (fs *OsFs) CheckRootPath(username string, uid int, gid int) bool { var err error if _, err = fs.Stat(fs.rootDir); fs.IsNotExist(err) { err = os.MkdirAll(fs.rootDir, os.ModePerm) if err == nil { SetPathPermissions(fs, fs.rootDir, uid, gid) } else { fsLog(fs, logger.LevelError, "error creating root directory %q for user %q: %v", fs.rootDir, username, err) } } return err == nil } // ScanRootDirContents returns the number of files contained in the root // directory and their size func (fs *OsFs) ScanRootDirContents() (int, int64, error) { return fs.GetDirSize(fs.rootDir) } // CheckMetadata checks the metadata consistency func (*OsFs) CheckMetadata() error { return nil } // GetAtomicUploadPath returns the path to use for an atomic upload func (*OsFs) GetAtomicUploadPath(name string) string { dir := filepath.Dir(name) if tempPath != "" { dir = tempPath } guid := xid.New().String() return filepath.Join(dir, ".sftpgo-upload."+guid+"."+filepath.Base(name)) } // GetRelativePath returns the path for a file relative to the user's home dir. // This is the path as seen by SFTPGo users func (fs *OsFs) GetRelativePath(name string) string { virtualPath := "/" if fs.mountPath != "" { virtualPath = fs.mountPath } rel, err := filepath.Rel(fs.rootDir, filepath.Clean(name)) if err != nil { return virtualPath } rel = filepath.ToSlash(rel) if rel == ".." || strings.HasPrefix(rel, "../") { return virtualPath } if rel == "." 
{ rel = "" } return path.Join(virtualPath, rel) } // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root func (*OsFs) Walk(root string, walkFn filepath.WalkFunc) error { return filepath.Walk(root, walkFn) } // Join joins any number of path elements into a single path func (*OsFs) Join(elem ...string) string { return filepath.Join(elem...) } // ResolvePath returns the matching filesystem path for the specified sftp path func (fs *OsFs) ResolvePath(virtualPath string) (string, error) { if !filepath.IsAbs(fs.rootDir) { return "", fmt.Errorf("invalid root path %q", fs.rootDir) } if fs.mountPath != "" { if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { virtualPath = after } } virtualPath = path.Clean("/" + virtualPath) r := filepath.Clean(filepath.Join(fs.rootDir, virtualPath)) p, err := filepath.EvalSymlinks(r) if isInvalidNameError(err) { err = os.ErrNotExist } isNotExist := fs.IsNotExist(err) if err != nil && !isNotExist { return "", err } else if isNotExist { // The requested path doesn't exist, so at this point we need to iterate up the // path chain until we hit a directory that _does_ exist and can be validated. 
_, err = fs.findFirstExistingDir(r) if err != nil { fsLog(fs, logger.LevelError, "error resolving non-existent path %q", err) } return r, err } err = fs.isSubDir(p) if err != nil { fsLog(fs, logger.LevelError, "Invalid path resolution, path %q original path %q resolved %q err: %v", p, virtualPath, r, err) } return r, err } // RealPath implements the FsRealPather interface func (fs *OsFs) RealPath(p string) (string, error) { linksWalked := 0 for { info, err := os.Lstat(p) if err != nil { if errors.Is(err, os.ErrNotExist) { return fs.GetRelativePath(p), nil } return "", err } if info.Mode()&os.ModeSymlink == 0 { return fs.GetRelativePath(p), nil } resolvedLink, err := os.Readlink(p) if err != nil { return "", err } resolvedLink = filepath.Clean(resolvedLink) if filepath.IsAbs(resolvedLink) { p = resolvedLink } else { p = filepath.Join(filepath.Dir(p), resolvedLink) } linksWalked++ if linksWalked > 10 { fsLog(fs, logger.LevelError, "unable to get real path, too many links: %d", linksWalked) return "", &pathResolutionError{err: "too many links"} } } } // GetDirSize returns the number of files and the size for a folder // including any subfolders func (fs *OsFs) GetDirSize(dirname string) (int, int64, error) { numFiles := 0 size := int64(0) isDir, err := isDirectory(fs, dirname) if err == nil && isDir { err = filepath.Walk(dirname, func(_ string, info os.FileInfo, err error) error { if err != nil { return err } if info != nil && info.Mode().IsRegular() { size += info.Size() numFiles++ if numFiles%1000 == 0 { fsLog(fs, logger.LevelDebug, "dirname %q scan in progress, files: %d, size: %d", dirname, numFiles, size) } } return err }) } return numFiles, size, err } // HasVirtualFolders returns true if folders are emulated func (*OsFs) HasVirtualFolders() bool { return false } func (fs *OsFs) findNonexistentDirs(filePath string) ([]string, error) { results := []string{} cleanPath := filepath.Clean(filePath) parent := filepath.Dir(cleanPath) _, err := os.Stat(parent) for 
fs.IsNotExist(err) { results = append(results, parent) parent = filepath.Dir(parent) if slices.Contains(results, parent) { break } _, err = os.Stat(parent) } if err != nil { return results, err } p, err := filepath.EvalSymlinks(parent) if err != nil { return results, err } err = fs.isSubDir(p) if err != nil { fsLog(fs, logger.LevelError, "error finding non existing dir: %v", err) } return results, err } func (fs *OsFs) findFirstExistingDir(path string) (string, error) { results, err := fs.findNonexistentDirs(path) if err != nil { fsLog(fs, logger.LevelError, "unable to find non existent dirs: %v", err) return "", err } var parent string if len(results) > 0 { lastMissingDir := results[len(results)-1] parent = filepath.Dir(lastMissingDir) } else { parent = fs.rootDir } p, err := filepath.EvalSymlinks(parent) if err != nil { return "", err } fileInfo, err := os.Stat(p) if err != nil { return "", err } if !fileInfo.IsDir() { return "", fmt.Errorf("resolved path is not a dir: %q", p) } err = fs.isSubDir(p) return p, err } func (fs *OsFs) isSubDir(sub string) error { // fs.rootDir must exist and it is already a validated absolute path parent, err := filepath.EvalSymlinks(fs.rootDir) if err != nil { fsLog(fs, logger.LevelError, "invalid root path %q: %v", fs.rootDir, err) return err } if parent == sub { return nil } if len(sub) < len(parent) { err = fmt.Errorf("path %q is not inside %q", sub, parent) return &pathResolutionError{err: err.Error()} } separator := string(os.PathSeparator) if parent == filepath.Dir(parent) { // parent is the root dir, on Windows we can have C:\, D:\ and so on here // so we still need the prefix check separator = "" } if !strings.HasPrefix(sub, parent+separator) { err = fmt.Errorf("path %q is not inside %q", sub, parent) return &pathResolutionError{err: err.Error()} } return nil } // GetMimeType returns the content type func (fs *OsFs) GetMimeType(name string) (string, error) { f, err := os.OpenFile(name, os.O_RDONLY, 0) if err != nil { return 
"", err } defer f.Close() var buf [512]byte n, err := io.ReadFull(f, buf[:]) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { return "", err } ctype := http.DetectContentType(buf[:n]) // Rewind file. _, err = f.Seek(0, io.SeekStart) return ctype, err } // Close closes the fs func (*OsFs) Close() error { return nil } // GetAvailableDiskSize returns the available size for the specified path func (*OsFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { return getStatFS(dirName) } func (fs *OsFs) useWriteBuffering(flag int) bool { if fs.writeBufferSize <= 0 { return false } if flag == 0 { return true } if flag&os.O_TRUNC == 0 { fsLog(fs, logger.LevelDebug, "truncate flag missing, buffering write not possible") return false } if flag&os.O_RDWR != 0 { fsLog(fs, logger.LevelDebug, "read and write flag found, buffering write not possible") return false } return true } type osFsDirLister struct { f *os.File } func (l *osFsDirLister) Next(limit int) ([]os.FileInfo, error) { if limit <= 0 { return nil, errInvalidDirListerLimit } return l.f.Readdir(limit) } func (l *osFsDirLister) Close() error { return l.f.Close() } ================================================ FILE: internal/vfs/s3fs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build !nos3

package vfs

import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/sha256"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"mime"
	"net"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"slices"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	"github.com/pkg/sftp"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/version"
)

const (
	// using this mime type for directories improves compatibility with s3fs-fuse
	s3DirMimeType         = "application/x-directory"
	s3TransferBufferSize  = 256 * 1024
	s3CopyObjectThreshold = 500 * 1024 * 1024
)

var (
	// content types recognized as "directory" markers
	s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"}
	// MaxKeys used for listings; its address is passed to the SDK, so it is a var
	s3DefaultPageSize = int32(1000)
)

// S3Fs is a Fs implementation for AWS S3 compatible object storages
type S3Fs struct {
	connectionID string
	localTempDir string
	// if not empty this fs is mounted as virtual folder in the specified path
	mountPath  string
	config     *S3FsConfig
	svc        *s3.Client
	ctxTimeout time.Duration
	// SSE-C parameters, set by NewS3Fs only when a customer key is configured
	sseCustomerKey    string
	sseCustomerKeyMD5 string
	sseCustomerAlgo   string
}

// init registers the "+s3" feature in the version info (this file is
// excluded when building with the nos3 tag)
func init() {
	version.AddFeature("+s3")
}

// NewS3Fs returns an S3Fs object that allows to interact with an s3 compatible
// object storage.
//
// An empty localTempDir falls back to getLocalTempDir(). Note that on error the
// partially initialized fs is still returned alongside the error.
func NewS3Fs(connectionID, localTempDir, mountPath string, s3Config S3FsConfig) (Fs, error) {
	if localTempDir == "" {
		localTempDir = getLocalTempDir()
	}

	fs := &S3Fs{
		connectionID: connectionID,
		localTempDir: localTempDir,
		mountPath:    getMountPath(mountPath),
		config:       &s3Config,
		ctxTimeout:   30 * time.Second,
	}
	if err := fs.config.validate(); err != nil {
		return fs, err
	}
	// bounds only the loading of the default AWS configuration
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	awsConfig, err := config.LoadDefaultConfig(ctx, config.WithHTTPClient(
		getAWSHTTPClient(0, 30*time.Second, fs.config.SkipTLSVerify)),
	)
	if err != nil {
		return fs, fmt.Errorf("unable to get AWS config: %w", err)
	}
	if fs.config.Region != "" {
		awsConfig.Region = fs.config.Region
	}
	// explicit static credentials replace the ones from the default chain
	if !fs.config.AccessSecret.IsEmpty() {
		if err := fs.config.AccessSecret.TryDecrypt(); err != nil {
			return fs, err
		}
		awsConfig.Credentials = aws.NewCredentialsCache(
			credentials.NewStaticCredentialsProvider(
				fs.config.AccessKey,
				fs.config.AccessSecret.GetPayload(),
				fs.config.SessionToken),
		)
	}
	if !fs.config.SSECustomerKey.IsEmpty() {
		if err := fs.config.SSECustomerKey.TryDecrypt(); err != nil {
			return fs, err
		}
		key := fs.config.SSECustomerKey.GetPayload()
		// a 32-byte payload is used verbatim as the key; any other length is
		// hashed with SHA-256 to obtain a 32-byte key. The MD5 is always
		// computed over the raw key bytes actually used.
		if len(key) == 32 {
			md5sumBinary := md5.Sum([]byte(key))
			fs.sseCustomerKey = base64.StdEncoding.EncodeToString([]byte(key))
			fs.sseCustomerKeyMD5 = base64.StdEncoding.EncodeToString(md5sumBinary[:])
		} else {
			keyHash := sha256.Sum256([]byte(key))
			md5sumBinary := md5.Sum(keyHash[:])
			fs.sseCustomerKey = base64.StdEncoding.EncodeToString(keyHash[:])
			fs.sseCustomerKeyMD5 = base64.StdEncoding.EncodeToString(md5sumBinary[:])
		}
		fs.sseCustomerAlgo = "AES256"
	}
	fs.setConfigDefaults()

	// when a role ARN is configured, the assume-role provider replaces any
	// credentials set above (the static ones are used to call STS)
	if fs.config.RoleARN != "" {
		client := sts.NewFromConfig(awsConfig)
		creds := stscreds.NewAssumeRoleProvider(client, fs.config.RoleARN)
		awsConfig.Credentials = creds
	}

	fs.svc = s3.NewFromConfig(awsConfig, func(o *s3.Options) {
		o.AppID = version.GetVersionHash()
		o.UsePathStyle = fs.config.ForcePathStyle
		// checksums are computed/validated only when the operation requires them
		o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
		o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
		if fs.config.Endpoint != "" {
			o.BaseEndpoint = aws.String(fs.config.Endpoint)
		}
	})
	return fs, nil
}

// Name returns the name for the Fs implementation
func (fs *S3Fs) Name() string {
	return fmt.Sprintf("%s bucket %q", s3fsName, fs.config.Bucket)
}
// ConnectionID returns the connection ID associated to this Fs implementation func (fs *S3Fs) ConnectionID() string { return fs.connectionID } // Stat returns a FileInfo describing the named file func (fs *S3Fs) Stat(name string) (os.FileInfo, error) { var result *FileInfo if name == "" || name == "/" || name == "." { return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil } if fs.config.KeyPrefix == name+"/" { return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil } obj, err := fs.headObject(name) if err == nil { // Some S3 providers (like SeaweedFS) remove the trailing '/' from object keys. // So we check some common content types to detect if this is a "directory". isDir := slices.Contains(s3DirMimeTypes, util.GetStringFromPointer(obj.ContentType)) if util.GetIntFromPointer(obj.ContentLength) == 0 && !isDir { _, err = fs.headObject(name + "/") isDir = err == nil } info := NewFileInfo(name, isDir, util.GetIntFromPointer(obj.ContentLength), util.GetTimeFromPointer(obj.LastModified), false) return info, nil } if !fs.IsNotExist(err) { return result, err } // now check if this is a prefix (virtual directory) hasContents, err := fs.hasContents(name) if err == nil && hasContents { return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil } else if err != nil { return nil, err } // the requested file may still be a directory as a zero bytes key // with a trailing forward slash (created using mkdir). 
// S3 doesn't return content type when listing objects, so we have // create "dirs" adding a trailing "/" to the key return fs.getStatForDir(name) } func (fs *S3Fs) getStatForDir(name string) (os.FileInfo, error) { var result *FileInfo obj, err := fs.headObject(name + "/") if err != nil { return result, err } return NewFileInfo(name, true, util.GetIntFromPointer(obj.ContentLength), util.GetTimeFromPointer(obj.LastModified), false), nil } // Lstat returns a FileInfo describing the named file func (fs *S3Fs) Lstat(name string) (os.FileInfo, error) { return fs.Stat(name) } // Open opens the named file for reading func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error) { attrs, err := fs.headObject(name) if err != nil { return nil, nil, nil, err } r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1) if err != nil { return nil, nil, nil, err } p := NewPipeReader(r) if readMetadata > 0 { p.setMetadata(attrs.Metadata) } ctx, cancelFn := context.WithCancel(context.Background()) go func() { defer cancelFn() err := fs.handleDownload(ctx, name, offset, w, attrs) w.CloseWithError(err) //nolint:errcheck fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %d, err: %+v", name, w.GetWrittenBytes(), err) metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err) }() return nil, p, cancelFn, nil } // Create creates or opens the named file for writing func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) { if checks&CheckParentDir != 0 { _, err := fs.Stat(path.Dir(name)) if err != nil { return nil, nil, nil, err } } r, w, err := createPipeFn(fs.localTempDir, fs.config.UploadPartSize+1024*1024) if err != nil { return nil, nil, nil, err } var p PipeWriter if checks&CheckResume != 0 { p = newPipeWriterAtOffset(w, 0) } else { p = NewPipeWriter(w) } ctx, cancelFn := context.WithCancel(context.Background()) go func() { defer cancelFn() var contentType string if 
flag == -1 { contentType = s3DirMimeType } else { contentType = mime.TypeByExtension(path.Ext(name)) } err := fs.handleUpload(ctx, r, name, contentType) r.CloseWithError(err) //nolint:errcheck p.Done(err) fsLog(fs, logger.LevelDebug, "upload completed, path: %q, acl: %q, readed bytes: %d, err: %+v", name, fs.config.ACL, r.GetReadedBytes(), err) metric.S3TransferCompleted(r.GetReadedBytes(), 0, err) }() if checks&CheckResume != 0 { readCh := make(chan error, 1) go func() { n, err := fs.downloadToWriter(name, p) pw := p.(*pipeWriterAtOffset) pw.offset = 0 pw.writeOffset = n readCh <- err }() err = <-readCh if err != nil { cancelFn() p.Close() fsLog(fs, logger.LevelDebug, "download before resume failed, writer closed and read cancelled") return nil, nil, nil, err } } if uploadMode&4 != 0 { return nil, p, nil, nil } return nil, p, cancelFn, nil } // Rename renames (moves) source to target. func (fs *S3Fs) Rename(source, target string, checks int) (int, int64, error) { if source == target { return -1, -1, nil } if checks&CheckParentDir != 0 { _, err := fs.Stat(path.Dir(target)) if err != nil { return -1, -1, err } } fi, err := fs.Stat(source) if err != nil { return -1, -1, err } return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0) } // Remove removes the named file or (empty) directory. 
func (fs *S3Fs) Remove(name string, isDir bool) error { if isDir { hasContents, err := fs.hasContents(name) if err != nil { return err } if hasContents { return fmt.Errorf("cannot remove non empty directory: %q", name) } if !strings.HasSuffix(name, "/") { name += "/" } } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() _, err := fs.svc.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(fs.config.Bucket), Key: aws.String(name), }) metric.S3DeleteObjectCompleted(err) return err } // Mkdir creates a new directory with the specified name and default permissions func (fs *S3Fs) Mkdir(name string) error { _, err := fs.Stat(name) if !fs.IsNotExist(err) { return err } return fs.mkdirInternal(name) } // Symlink creates source as a symbolic link to target. func (*S3Fs) Symlink(_, _ string) error { return ErrVfsUnsupported } // Readlink returns the destination of the named symbolic link func (*S3Fs) Readlink(_ string) (string, error) { return "", ErrVfsUnsupported } // Chown changes the numeric uid and gid of the named file. func (*S3Fs) Chown(_ string, _ int, _ int) error { return ErrVfsUnsupported } // Chmod changes the mode of the named file to mode. func (*S3Fs) Chmod(_ string, _ os.FileMode) error { return ErrVfsUnsupported } // Chtimes changes the access and modification times of the named file. func (fs *S3Fs) Chtimes(_ string, _, _ time.Time, _ bool) error { return ErrVfsUnsupported } // Truncate changes the size of the named file. // Truncate by path is not supported, while truncating an opened // file is handled inside base transfer func (*S3Fs) Truncate(_ string, _ int64) error { return ErrVfsUnsupported } // ReadDir reads the directory named by dirname and returns // a list of directory entries. 
func (fs *S3Fs) ReadDir(dirname string) (DirLister, error) { // dirname must be already cleaned prefix := fs.getPrefix(dirname) paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ Bucket: aws.String(fs.config.Bucket), Prefix: aws.String(prefix), Delimiter: aws.String("/"), MaxKeys: &s3DefaultPageSize, }) return &s3DirLister{ paginator: paginator, timeout: fs.ctxTimeout, prefix: prefix, prefixes: make(map[string]bool), }, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. // Resuming uploads is not supported on S3 func (*S3Fs) IsUploadResumeSupported() bool { return false } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (*S3Fs) IsConditionalUploadResumeSupported(size int64) bool { return size <= resumeMaxSize } // IsAtomicUploadSupported returns true if atomic upload is supported. // S3 uploads are already atomic, we don't need to upload to a temporary // file func (*S3Fs) IsAtomicUploadSupported() bool { return false } // IsNotExist returns a boolean indicating whether the error is known to // report that a file or directory does not exist func (*S3Fs) IsNotExist(err error) bool { if err == nil { return false } var re *awshttp.ResponseError if errors.As(err, &re) { if re.Response != nil { return re.Response.StatusCode == http.StatusNotFound } } return false } // IsPermission returns a boolean indicating whether the error is known to // report that permission is denied. 
func (*S3Fs) IsPermission(err error) bool { if err == nil { return false } var re *awshttp.ResponseError if errors.As(err, &re) { if re.Response != nil { return re.Response.StatusCode == http.StatusForbidden || re.Response.StatusCode == http.StatusUnauthorized } } return false } // IsNotSupported returns true if the error indicate an unsupported operation func (*S3Fs) IsNotSupported(err error) bool { if err == nil { return false } return errors.Is(err, ErrVfsUnsupported) } // CheckRootPath creates the specified local root directory if it does not exists func (fs *S3Fs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) return osFs.CheckRootPath(username, uid, gid) } // ScanRootDirContents returns the number of files contained in the bucket, // and their size func (fs *S3Fs) ScanRootDirContents() (int, int64, error) { return fs.GetDirSize(fs.config.KeyPrefix) } // GetDirSize returns the number of files and the size for a folder // including any subfolders func (fs *S3Fs) GetDirSize(dirname string) (int, int64, error) { prefix := fs.getPrefix(dirname) numFiles := 0 size := int64(0) paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ Bucket: aws.String(fs.config.Bucket), Prefix: aws.String(prefix), MaxKeys: &s3DefaultPageSize, }) for paginator.HasMorePages() { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() page, err := paginator.NextPage(ctx) if err != nil { metric.S3ListObjectsCompleted(err) return numFiles, size, err } for _, fileObject := range page.Contents { isDir := strings.HasSuffix(util.GetStringFromPointer(fileObject.Key), "/") objectSize := util.GetIntFromPointer(fileObject.Size) if isDir && objectSize == 0 { continue } numFiles++ size += objectSize } fsLog(fs, logger.LevelDebug, "scan in progress for %q, files: %d, size: %d", dirname, numFiles, size) } 
metric.S3ListObjectsCompleted(nil) return numFiles, size, nil } // GetAtomicUploadPath returns the path to use for an atomic upload. // S3 uploads are already atomic, we never call this method for S3 func (*S3Fs) GetAtomicUploadPath(_ string) string { return "" } // GetRelativePath returns the path for a file relative to the user's home dir. // This is the path as seen by SFTPGo users func (fs *S3Fs) GetRelativePath(name string) string { rel := path.Clean(name) if rel == "." { rel = "" } if !path.IsAbs(rel) { rel = "/" + rel } if fs.config.KeyPrefix != "" { if !strings.HasPrefix(rel, "/"+fs.config.KeyPrefix) { rel = "/" } rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) } if fs.mountPath != "" { rel = path.Join(fs.mountPath, rel) } return rel } // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root. The result are unordered func (fs *S3Fs) Walk(root string, walkFn filepath.WalkFunc) error { prefix := fs.getPrefix(root) paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ Bucket: aws.String(fs.config.Bucket), Prefix: aws.String(prefix), MaxKeys: &s3DefaultPageSize, }) for paginator.HasMorePages() { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() page, err := paginator.NextPage(ctx) if err != nil { metric.S3ListObjectsCompleted(err) walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), err) //nolint:errcheck return err } for _, fileObject := range page.Contents { name, isDir := fs.resolve(fileObject.Key, prefix) if name == "" { continue } err := walkFn(util.GetStringFromPointer(fileObject.Key), NewFileInfo(name, isDir, util.GetIntFromPointer(fileObject.Size), util.GetTimeFromPointer(fileObject.LastModified), false), nil) if err != nil { return err } } } metric.S3ListObjectsCompleted(nil) walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), nil) //nolint:errcheck return nil } // 
Join joins any number of path elements into a single path
func (*S3Fs) Join(elem ...string) string {
	// object keys never start with "/"
	return strings.TrimPrefix(path.Join(elem...), "/")
}

// HasVirtualFolders returns true if folders are emulated
func (*S3Fs) HasVirtualFolders() bool {
	return true
}

// ResolvePath returns the matching filesystem path for the specified virtual path.
// The mount path, if any, is stripped and the result is joined with the
// configured key prefix, without a leading "/".
func (fs *S3Fs) ResolvePath(virtualPath string) (string, error) {
	if fs.mountPath != "" {
		if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found {
			virtualPath = after
		}
	}
	virtualPath = path.Clean("/" + virtualPath)
	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
}

// CopyFile implements the FsFileCopier interface.
// It returns the number of new files (0 if target already existed, 1
// otherwise) and the size difference compared to any existing target.
func (fs *S3Fs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) {
	numFiles := 1
	sizeDiff := srcInfo.Size()
	attrs, err := fs.headObject(target)
	if err == nil {
		sizeDiff -= util.GetIntFromPointer(attrs.ContentLength)
		numFiles = 0
	} else if !fs.IsNotExist(err) {
		return 0, 0, err
	}
	if err := fs.copyFileInternal(source, target, srcInfo); err != nil {
		return 0, 0, err
	}
	return numFiles, sizeDiff, nil
}

// resolve strips prefix from the object key and reports whether the key
// denotes a directory (trailing "/"), returning the name without that slash.
func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) {
	result := strings.TrimPrefix(util.GetStringFromPointer(name), prefix)
	isDir := strings.HasSuffix(result, "/")
	if isDir {
		result = strings.TrimSuffix(result, "/")
	}
	return result, isDir
}

// setConfigDefaults fills in default part sizes (5 MiB) and concurrency (5)
// for uploads and downloads. Non-zero part sizes below 1 MiB are interpreted
// as MiB and converted to bytes.
func (fs *S3Fs) setConfigDefaults() {
	const defaultPartSize = 1024 * 1024 * 5
	const defaultConcurrency = 5

	if fs.config.UploadPartSize == 0 {
		fs.config.UploadPartSize = defaultPartSize
	} else if fs.config.UploadPartSize < 1024*1024 {
		fs.config.UploadPartSize *= 1024 * 1024
	}
	if fs.config.UploadConcurrency == 0 {
		fs.config.UploadConcurrency = defaultConcurrency
	}
	if fs.config.DownloadPartSize == 0 {
		fs.config.DownloadPartSize = defaultPartSize
	} else if fs.config.DownloadPartSize < 1024*1024 {
		fs.config.DownloadPartSize *= 1024 * 1024
	}
	if fs.config.DownloadConcurrency == 0 {
		fs.config.DownloadConcurrency = defaultConcurrency
	}
}

// copyFileInternal performs a server-side copy of source to target.
// Objects larger than s3CopyObjectThreshold use a multipart copy; smaller ones
// use a single CopyObject request. The content type is derived from the
// source extension. The SSE-C settings, if any, apply to both source and
// target.
func (fs *S3Fs) copyFileInternal(source, target string, srcInfo os.FileInfo) error {
	contentType := mime.TypeByExtension(path.Ext(source))
	copySource := pathEscape(fs.Join(fs.config.Bucket, source))

	if srcInfo.Size() > s3CopyObjectThreshold {
		fsLog(fs, logger.LevelDebug, "renaming file %q with size %d using multipart copy",
			source, srcInfo.Size())
		err := fs.doMultipartCopy(copySource, target, contentType, srcInfo.Size())
		metric.S3CopyObjectCompleted(err)
		return err
	}

	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	copyObject := &s3.CopyObjectInput{
		Bucket:                         aws.String(fs.config.Bucket),
		CopySource:                     aws.String(copySource),
		Key:                            aws.String(target),
		StorageClass:                   types.StorageClass(fs.config.StorageClass),
		ACL:                            types.ObjectCannedACL(fs.config.ACL),
		ContentType:                    util.NilIfEmpty(contentType),
		CopySourceSSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		CopySourceSSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		CopySourceSSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
		SSECustomerKey:                 util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm:           util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:              util.NilIfEmpty(fs.sseCustomerKeyMD5),
	}

	_, err := fs.svc.CopyObject(ctx, copyObject)

	metric.S3CopyObjectCompleted(err)
	return err
}

// renameInternal renames source to target by copy + delete, since S3 has no
// native rename. Directories are handled according to renameMode:
//   - 0: only empty directories can be renamed;
//   - 1: the contents are renamed recursively via doRecursiveRename.
//
// It returns the number of files moved and their total size. A not-exist
// error when deleting the source is ignored.
func (fs *S3Fs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int,
	updateModTime bool,
) (int, int64, error) {
	var numFiles int
	var filesSize int64

	if srcInfo.IsDir() {
		if renameMode == 0 {
			hasContents, err := fs.hasContents(source)
			if err != nil {
				return numFiles, filesSize, err
			}
			if hasContents {
				return numFiles, filesSize, fmt.Errorf("%w: cannot rename non empty directory: %q", ErrVfsUnsupported, source)
			}
		}
		if err := fs.mkdirInternal(target); err != nil {
			return numFiles, filesSize, err
		}
		if renameMode == 1 {
			files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime)
			numFiles += files
			filesSize += size
			if err != nil {
				return numFiles, filesSize, err
			}
		}
	} else {
		if err := fs.copyFileInternal(source, target, srcInfo); err != nil {
			return numFiles, filesSize, err
		}
		numFiles++
		filesSize += srcInfo.Size()
	}
	err := fs.Remove(source, srcInfo.IsDir())
	if fs.IsNotExist(err) {
		err = nil
	}
	return numFiles, filesSize, err
}

// mkdirInternal creates the directory marker object: name with a trailing "/"
// uploaded through Create with flag -1, which selects s3DirMimeType.
func (fs *S3Fs) mkdirInternal(name string) error {
	if !strings.HasSuffix(name, "/") {
		name += "/"
	}
	_, w, _, err := fs.Create(name, -1, 0)
	if err != nil {
		return err
	}
	return w.Close()
}

// hasContents reports whether the directory name contains any object other
// than its own marker. Two keys are enough: the marker itself (which resolves
// to an empty name) plus one more entry.
func (fs *S3Fs) hasContents(name string) (bool, error) {
	prefix := fs.getPrefix(name)
	maxKeys := int32(2)
	paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
		Bucket:  aws.String(fs.config.Bucket),
		Prefix:  aws.String(prefix),
		MaxKeys: &maxKeys,
	})

	if paginator.HasMorePages() {
		ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
		defer cancelFn()

		page, err := paginator.NextPage(ctx)
		metric.S3ListObjectsCompleted(err)
		if err != nil {
			return false, err
		}

		for _, obj := range page.Contents {
			name, _ := fs.resolve(obj.Key, prefix)
			if name == "" || name == "/" {
				continue
			}
			return true, nil
		}
		return false, nil
	}

	metric.S3ListObjectsCompleted(nil)
	return false, nil
}

// downloadPart fetches count bytes of name starting at start (HTTP Range
// request) into buf, then writes them to w at writeOffset.
// NOTE(review): buf must have room for count bytes — callers pass buffers of
// DownloadPartSize and count <= DownloadPartSize.
func (fs *S3Fs) downloadPart(ctx context.Context, name string, buf []byte, w io.WriterAt, start, count, writeOffset int64) error {
	if count == 0 {
		return nil
	}
	rangeHeader := fmt.Sprintf("bytes=%d-%d", start, start+count-1)

	resp, err := fs.svc.GetObject(ctx, &s3.GetObjectInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		Range:                &rangeHeader,
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	_, err = io.ReadAtLeast(resp.Body, buf, int(count))
	if err != nil {
		return err
	}

	return writeAtFull(w, buf, writeOffset, int(count))
}

// handleDownload downloads name from offset to the end of the object and
// writes it to writer. Parts of DownloadPartSize bytes are fetched
// concurrently (at most DownloadConcurrency at a time). Write offsets are
// relative to the requested offset. The first part error cancels the
// remaining parts and is returned.
func (fs *S3Fs) handleDownload(ctx context.Context, name string, offset int64, writer io.WriterAt, attrs *s3.HeadObjectOutput) error {
	contentLength := util.GetIntFromPointer(attrs.ContentLength)
	sizeToDownload := contentLength - offset
	if sizeToDownload < 0 {
		fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %d, offset: %d, size to download: %d",
			contentLength, offset, sizeToDownload)
		return errors.New("the requested offset exceeds the file size")
	}
	if sizeToDownload == 0 {
		fsLog(fs, logger.LevelDebug, "nothing to download, offset %d, content length %d", offset, contentLength)
		return nil
	}
	partSize := fs.config.DownloadPartSize
	guard := make(chan struct{}, fs.config.DownloadConcurrency)
	var blockCtxTimeout time.Duration
	if fs.config.DownloadPartMaxTime > 0 {
		blockCtxTimeout = time.Duration(fs.config.DownloadPartMaxTime) * time.Second
	} else {
		// default: one minute per MiB of part size
		blockCtxTimeout = time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
	}
	pool := newBufferAllocator(int(partSize))
	defer pool.free()

	finished := false
	var wg sync.WaitGroup
	var errOnce sync.Once
	var hasError atomic.Bool
	var poolError error

	poolCtx, poolCancel := context.WithCancel(ctx)
	defer poolCancel()

	for part := 0; !finished; part++ {
		start := offset
		end := offset + partSize
		if end >= contentLength {
			end = contentLength
			finished = true
		}
		writeOffset := int64(part) * partSize
		offset = end

		// the guard channel bounds the number of in-flight parts
		guard <- struct{}{}
		if hasError.Load() {
			fsLog(fs, logger.LevelDebug, "pool error, download for part %d not started", part)
			break
		}

		buf := pool.getBuffer()
		wg.Add(1)
		go func(start, end, writeOffset int64, buf []byte) {
			defer func() {
				pool.releaseBuffer(buf)
				<-guard
				wg.Done()
			}()

			innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout))
			defer cancelFn()

			err := fs.downloadPart(innerCtx, name, buf, writer, start, end-start, writeOffset)
			if err != nil {
				errOnce.Do(func() {
					fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
					hasError.Store(true)
					poolError = fmt.Errorf("multipart download error: %w", err)
					poolCancel()
				})
			}
		}(start, end, writeOffset, buf)
	}

	wg.Wait()
	close(guard)

	return poolError
}

// initiateMultipartUpload starts a multipart upload for name and returns its
// upload ID.
func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		StorageClass:         types.StorageClass(fs.config.StorageClass),
		ACL:                  types.ObjectCannedACL(fs.config.ACL),
		ContentType:          util.NilIfEmpty(contentType),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return "", fmt.Errorf("unable to create multipart upload request: %w", err)
	}
	uploadID := util.GetStringFromPointer(res.UploadId)
	if uploadID == "" {
		return "", errors.New("unable to get multipart upload ID")
	}
	return uploadID, nil
}

// uploadPart uploads data as part partNumber of the given multipart upload
// and returns the part ETag. The timeout is UploadPartMaxTime seconds if
// configured, otherwise one minute per MiB of UploadPartSize.
func (fs *S3Fs) uploadPart(ctx context.Context, name, uploadID string, partNumber int32, data []byte) (*string, error) {
	timeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
	if fs.config.UploadPartMaxTime > 0 {
		timeout = time.Duration(fs.config.UploadPartMaxTime) * time.Second
	}
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(timeout))
	defer cancelFn()

	resp, err := fs.svc.UploadPart(ctx, &s3.UploadPartInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		PartNumber:           &partNumber,
		UploadId:             aws.String(uploadID),
		Body:                 bytes.NewReader(data),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return nil, fmt.Errorf("unable to upload part number %d: %w", partNumber, err)
	}
	return resp.ETag, nil
}

// completeMultipartUpload finalizes the multipart upload with the given parts,
// which must be sorted by part number.
func (fs *S3Fs) completeMultipartUpload(ctx context.Context, name, uploadID string, completedParts []types.CompletedPart) error {
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	_, err := fs.svc.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
		Bucket:   aws.String(fs.config.Bucket),
		Key:      aws.String(name),
		UploadId: aws.String(uploadID),
		MultipartUpload: &types.CompletedMultipartUpload{
			Parts: completedParts,
		},
	})

	return err
}

// abortMultipartUpload aborts the given multipart upload. It uses a fresh
// context so it still runs when the upload context has been cancelled.
func (fs *S3Fs) abortMultipartUpload(name, uploadID string) error {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	_, err := fs.svc.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
		Bucket:   aws.String(fs.config.Bucket),
		Key:      aws.String(name),
		UploadId: aws.String(uploadID),
	})
	return err
}

// singlePartUpload uploads data with a single PutObject request; it is used
// when the whole content fits in the first part buffer.
func (fs *S3Fs) singlePartUpload(ctx context.Context, name, contentType string, data []byte) error {
	timeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
	if fs.config.UploadPartMaxTime > 0 {
		timeout = time.Duration(fs.config.UploadPartMaxTime) * time.Second
	}
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(timeout))
	defer cancelFn()

	contentLength := int64(len(data))
	_, err := fs.svc.PutObject(ctx, &s3.PutObjectInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		ACL:                  types.ObjectCannedACL(fs.config.ACL),
		Body:                 bytes.NewReader(data),
		ContentType:          util.NilIfEmpty(contentType),
		ContentLength:        &contentLength,
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
		StorageClass:         types.StorageClass(fs.config.StorageClass),
	})

	return err
}

// handleUpload reads reader and uploads its content to name.
// If the first UploadPartSize buffer hits EOF, a single PutObject is used.
// Otherwise a multipart upload is started and parts are uploaded concurrently
// (at most UploadConcurrency at a time). On the first read or upload error
// the pending parts are cancelled, the multipart upload is aborted and the
// error is returned.
func (fs *S3Fs) handleUpload(ctx context.Context, reader io.Reader, name, contentType string) error {
	pool := newBufferAllocator(int(fs.config.UploadPartSize))
	defer pool.free()

	firstBuf := pool.getBuffer()
	firstReadSize, err := readFill(reader, firstBuf)
	if err == io.EOF {
		return fs.singlePartUpload(ctx, name, contentType, firstBuf[:firstReadSize])
	}
	if err != nil {
		return err
	}

	uploadID, err := fs.initiateMultipartUpload(ctx, name, contentType)
	if err != nil {
		return err
	}
	guard := make(chan struct{}, fs.config.UploadConcurrency)
	finished := false

	var partMutex sync.Mutex
	var completedParts []types.CompletedPart
	var wg sync.WaitGroup
	var hasError atomic.Bool
	var poolErr error
	var errOnce sync.Once
	var partNumber int32

	poolCtx, poolCancel := context.WithCancel(ctx)
	defer poolCancel()

	// finalizeFailedUpload records err, cancels in-flight parts and aborts the
	// multipart upload; it is always invoked through errOnce.
	finalizeFailedUpload := func(err error) {
		fsLog(fs, logger.LevelError, "finalize failed multipart upload after error: %v", err)
		hasError.Store(true)
		poolErr = err
		poolCancel()
		if abortErr := fs.abortMultipartUpload(name, uploadID); abortErr != nil {
			fsLog(fs, logger.LevelError, "unable to abort multipart upload: %+v", abortErr)
		}
	}

	// uploadPart uploads one buffered part, records its ETag and releases the
	// buffer and the concurrency slot.
	uploadPart := func(partNum int32, buf []byte, bytesRead int) {
		defer func() {
			pool.releaseBuffer(buf)
			<-guard
			wg.Done()
		}()

		etag, err := fs.uploadPart(poolCtx, name, uploadID, partNum, buf[:bytesRead])
		if err != nil {
			errOnce.Do(func() {
				finalizeFailedUpload(err)
			})
			return
		}

		partMutex.Lock()
		completedParts = append(completedParts, types.CompletedPart{
			PartNumber: &partNum,
			ETag:       etag,
		})
		partMutex.Unlock()
	}

	partNumber = 1
	guard <- struct{}{}
	wg.Add(1)
	go uploadPart(partNumber, firstBuf, firstReadSize)

	for partNumber = 2; !finished; partNumber++ {
		buf := pool.getBuffer()

		n, err := readFill(reader, buf)
		if err == io.EOF {
			if n == 0 {
				pool.releaseBuffer(buf)
				break
			}
			finished = true
		} else if err != nil {
			pool.releaseBuffer(buf)
			errOnce.Do(func() {
				finalizeFailedUpload(err)
			})
			break
		}

		guard <- struct{}{}
		if hasError.Load() {
			fsLog(fs, logger.LevelError, "pool error, upload for part %d not started", partNumber)
			pool.releaseBuffer(buf)
			break
		}

		wg.Add(1)
		go uploadPart(partNumber, buf, n)
	}

	wg.Wait()
	close(guard)

	if poolErr != nil {
		return poolErr
	}

	// parts complete in arbitrary order, S3 requires them sorted by number
	sort.Slice(completedParts, func(i, j int) bool {
		getPartNumber := func(number *int32) int32 {
			if number == nil {
				return 0
			}
			return *number
		}

		return getPartNumber(completedParts[i].PartNumber) < getPartNumber(completedParts[j].PartNumber)
	})

	return fs.completeMultipartUpload(ctx, name, uploadID, completedParts)
}

// doMultipartCopy copies source to target with a server-side multipart copy.
// source must already be an escaped CopySource value ("bucket/key").
// Parts are copied concurrently, at most 10 at a time. If any part fails, the
// multipart upload is aborted and the first error is returned.
func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int64) error {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(target),
		StorageClass:         types.StorageClass(fs.config.StorageClass),
		ACL:                  types.ObjectCannedACL(fs.config.ACL),
		ContentType:          util.NilIfEmpty(contentType),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return fmt.Errorf("unable to create multipart copy request: %w", err)
	}
	uploadID := util.GetStringFromPointer(res.UploadId)
	if uploadID == "" {
		return errors.New("unable to get multipart copy upload ID")
	}
	// We use a 32 MB part size (500 MB for files larger than 100 GB) and copy
	// 10 parts in parallel. These values are arbitrary: we don't want to start
	// too many goroutines.
	maxPartSize := int64(32 * 1024 * 1024)
	if fileSize > int64(100*1024*1024*1024) {
		maxPartSize = int64(500 * 1024 * 1024)
	}
	guard := make(chan struct{}, 10)
	finished := false
	var completedParts []types.CompletedPart
	var partMutex sync.Mutex
	var wg sync.WaitGroup
	var hasError atomic.Bool
	var errOnce sync.Once
	var copyError error
	var partNumber int32
	var offset int64

	opCtx, opCancel := context.WithCancel(context.Background())
	defer opCancel()

	for partNumber = 1; !finished; partNumber++ {
		start := offset
		end := offset + maxPartSize
		if end >= fileSize {
			end = fileSize
			finished = true
		}
		offset = end

		guard <- struct{}{}
		if hasError.Load() {
			fsLog(fs, logger.LevelDebug, "previous multipart copy error, copy for part %d not started", partNumber)
			break
		}

		wg.Add(1)
		go func(partNum int32, partStart, partEnd int64) {
			defer func() {
				<-guard
				wg.Done()
			}()

			innerCtx, innerCancelFn := context.WithDeadline(opCtx, time.Now().Add(fs.ctxTimeout))
			defer innerCancelFn()

			partResp, err := fs.svc.UploadPartCopy(innerCtx, &s3.UploadPartCopyInput{
				Bucket:                         aws.String(fs.config.Bucket),
				CopySource:                     aws.String(source),
				Key:                            aws.String(target),
				PartNumber:                     &partNum,
				UploadId:                       aws.String(uploadID),
				CopySourceRange:                aws.String(fmt.Sprintf("bytes=%d-%d", partStart, partEnd-1)),
				CopySourceSSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
				CopySourceSSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
				CopySourceSSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
				SSECustomerKey:                 util.NilIfEmpty(fs.sseCustomerKey),
				SSECustomerAlgorithm:           util.NilIfEmpty(fs.sseCustomerAlgo),
				SSECustomerKeyMD5:              util.NilIfEmpty(fs.sseCustomerKeyMD5),
			})
			if err == nil && (partResp == nil || partResp.CopyPartResult == nil) {
				// CopyPartResult is an optional pointer in the SDK output. Treat a
				// missing result as a failed part instead of panicking on the ETag
				// dereference below.
				err = errors.New("empty copy part result")
			}
			if err != nil {
				errOnce.Do(func() {
					fsLog(fs, logger.LevelError, "unable to copy part number %d: %+v", partNum, err)
					hasError.Store(true)
					copyError = fmt.Errorf("error copying part number %d: %w", partNum, err)
					opCancel()

					if errAbort := fs.abortMultipartUpload(target, uploadID); errAbort != nil {
						fsLog(fs, logger.LevelError, "unable to abort multipart copy: %+v", errAbort)
					}
				})
				return
			}

			partMutex.Lock()
			completedParts = append(completedParts, types.CompletedPart{
				ETag:       partResp.CopyPartResult.ETag,
				PartNumber: &partNum,
			})
			partMutex.Unlock()
		}(partNumber, start, end)
	}

	wg.Wait()
	close(guard)

	if copyError != nil {
		return copyError
	}
	// parts complete in arbitrary order, S3 requires them sorted by number
	sort.Slice(completedParts, func(i, j int) bool {
		getPartNumber := func(number *int32) int32 {
			if number == nil {
				return 0
			}
			return *number
		}

		return getPartNumber(completedParts[i].PartNumber) < getPartNumber(completedParts[j].PartNumber)
	})

	completeCtx, completeCancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer completeCancelFn()

	_, err = fs.svc.CompleteMultipartUpload(completeCtx, &s3.CompleteMultipartUploadInput{
		Bucket:   aws.String(fs.config.Bucket),
		Key:      aws.String(target),
		UploadId: aws.String(uploadID),
		MultipartUpload: &types.CompletedMultipartUpload{
			Parts: completedParts,
		},
	})
	if err != nil {
		return fmt.Errorf("unable to complete multipart upload: %w", err)
	}
	return nil
}

// getPrefix converts a directory name into a listing prefix: no leading "/",
// exactly one trailing "/". The root ("", "." or "/") maps to "".
func (fs *S3Fs) getPrefix(name string) string {
	prefix := ""
	if name != "" && name != "." && name != "/" {
		prefix = strings.TrimPrefix(name, "/")
		if !strings.HasSuffix(prefix, "/") {
			prefix += "/"
		}
	}
	return prefix
}

// headObject issues a HEAD request for name, applying the SSE-C settings if
// any, and records the operation metric.
func (fs *S3Fs) headObject(name string) (*s3.HeadObjectOutput, error) {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	obj, err := fs.svc.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	metric.S3HeadObjectCompleted(err)
	return obj, err
}

// GetMimeType returns the content type stored with the object
func (fs *S3Fs) GetMimeType(name string) (string, error) {
	obj, err := fs.headObject(name)
	if err != nil {
		return "", err
	}
	return util.GetStringFromPointer(obj.ContentType), nil
}

// Close closes the fs
func (*S3Fs) Close() error {
	return nil
}

// GetAvailableDiskSize returns the available size for the specified path.
// This information is not available for S3, ErrStorageSizeUnavailable is
// always returned.
func (*S3Fs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) {
	return nil, ErrStorageSizeUnavailable
}

// downloadToWriter downloads the whole object name into w, bounded by
// preResumeTimeout, and returns the number of bytes written. It is used to
// prefill the pipe before resuming an upload.
func (fs *S3Fs) downloadToWriter(name string, w PipeWriter) (int64, error) {
	fsLog(fs, logger.LevelDebug, "starting download before resuming upload, path %q", name)
	attrs, err := fs.headObject(name)
	if err != nil {
		return 0, err
	}
	ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout)
	defer cancelFn()

	err = fs.handleDownload(ctx, name, 0, w, attrs)
	fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v",
		name, w.GetWrittenBytes(), err)
	metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
	return w.GetWrittenBytes(), err
}

// s3DirLister lists a single directory level page by page.
// prefixes tracks directory names already emitted, so that a common prefix
// and its marker object are returned only once.
type s3DirLister struct {
	baseDirLister
	paginator     *s3.ListObjectsV2Paginator
	timeout       time.Duration
	prefix        string
	prefixes      map[string]bool
	metricUpdated bool
}

// resolve strips the lister prefix from the key and reports whether the key
// denotes a directory (trailing "/"), returning the name without that slash.
func (l *s3DirLister) resolve(name *string) (string, bool) {
	result := strings.TrimPrefix(util.GetStringFromPointer(name), l.prefix)
	isDir := strings.HasSuffix(result, "/")
	if isDir {
		result = strings.TrimSuffix(result, "/")
	}
	return result, isDir
}

// Next returns up to limit entries via returnFromCache.
// A new page is fetched only when the cache holds fewer than limit entries.
// A call may therefore return fewer than limit entries without io.EOF.
// io.EOF is returned, together with any remaining cached entries, once there
// are no more pages.
func (l *s3DirLister) Next(limit int) ([]os.FileInfo, error) {
	if limit <= 0 {
		return nil, errInvalidDirListerLimit
	}
	if len(l.cache) >= limit {
		return l.returnFromCache(limit), nil
	}
	if !l.paginator.HasMorePages() {
		if !l.metricUpdated {
			l.metricUpdated = true
			metric.S3ListObjectsCompleted(nil)
		}
		return l.returnFromCache(limit), io.EOF
	}
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
	defer cancelFn()

	page, err := l.paginator.NextPage(ctx)
	if err != nil {
		metric.S3ListObjectsCompleted(err)
		return l.cache, err
	}
	for _, p := range page.CommonPrefixes {
		// prefixes have a trailing slash
		name, _ := l.resolve(p.Prefix)
		if name == "" {
			continue
		}
		if _, ok := l.prefixes[name]; ok {
			continue
		}
		l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
		l.prefixes[name] = true
	}
	for _, fileObject := range page.Contents {
		objectModTime := util.GetTimeFromPointer(fileObject.LastModified)
		objectSize := util.GetIntFromPointer(fileObject.Size)
		name, isDir := l.resolve(fileObject.Key)
		if name == "" || name == "/" {
			continue
		}
		if isDir {
			if _, ok := l.prefixes[name]; ok {
				continue
			}
			l.prefixes[name] = true
		}

		// a key with a trailing "/" is reported as a directory only if empty
		l.cache = append(l.cache, NewFileInfo(name, (isDir && objectSize == 0), objectSize, objectModTime, false))
	}
	return l.returnFromCache(limit), nil
}

// Close closes the lister
func (l *s3DirLister) Close() error {
	return l.baseDirLister.Close()
}

// getAWSHTTPClient builds the HTTP client used by the S3 SDK. It uses an 8
// second dial timeout, the given idle connection timeout and
// s3TransferBufferSize read/write buffers. TLS verification is skipped only
// if skipTLSVerify is set. An overall request timeout (in seconds) is applied
// only when timeout > 0.
func getAWSHTTPClient(timeout int, idleConnectionTimeout time.Duration, skipTLSVerify bool) *awshttp.BuildableClient {
	c := awshttp.NewBuildableClient().
		WithDialerOptions(func(d *net.Dialer) {
			d.Timeout = 8 * time.Second
		}).
		WithTransportOptions(func(tr *http.Transport) {
			tr.IdleConnTimeout = idleConnectionTimeout
			tr.WriteBufferSize = s3TransferBufferSize
			tr.ReadBufferSize = s3TransferBufferSize
			if skipTLSVerify {
				if tr.TLSClientConfig != nil {
					tr.TLSClientConfig.InsecureSkipVerify = skipTLSVerify
				} else {
					tr.TLSClientConfig = &tls.Config{
						MinVersion:         awshttp.DefaultHTTPTransportTLSMinVersion,
						InsecureSkipVerify: skipTLSVerify,
					}
				}
			}
		})
	if timeout > 0 {
		c = c.WithTimeout(time.Duration(timeout) * time.Second)
	}
	return c
}

// ideally we should simply use url.PathEscape:
//
// https://github.com/awsdocs/aws-doc-sdk-examples/blob/master/go/example_code/s3/s3_copy_object.go#L65
//
// but this causes issues with some vendors, see #483, the code below is copied from rclone.
// "+" is escaped explicitly because url.URL leaves it unescaped in paths.
func pathEscape(in string) string {
	var u url.URL
	u.Path = in
	return strings.ReplaceAll(u.String(), "+", "%2B")
}

================================================
FILE: internal/vfs/s3fs_disabled.go
================================================
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build nos3 package vfs import ( "errors" "github.com/drakkan/sftpgo/v2/internal/version" ) func init() { version.AddFeature("-s3") } // NewS3Fs returns an error, S3 is disabled func NewS3Fs(_, _, _ string, _ S3FsConfig) (Fs, error) { return nil, errors.New("S3 disabled at build time") } ================================================ FILE: internal/vfs/sftpfs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package vfs import ( "bufio" "bytes" "crypto/rsa" "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "io/fs" "net" "net/http" "os" "path" "path/filepath" "slices" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/pkg/sftp" "github.com/robfig/cron/v3" "github.com/rs/xid" "github.com/sftpgo/sdk" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) const ( // sftpFsName is the name for the SFTP Fs implementation sftpFsName = "sftpfs" logSenderSFTPCache = "sftpCache" maxSessionsPerConnection = 5 ) var ( // ErrSFTPLoop defines the error to return if an SFTP loop is detected ErrSFTPLoop = errors.New("SFTP loop or nested local SFTP folders detected") sftpConnsCache = newSFTPConnectionCache() ) // SFTPFsConfig defines the configuration for SFTP based filesystem type SFTPFsConfig struct { sdk.BaseSFTPFsConfig Password *kms.Secret `json:"password,omitempty"` PrivateKey *kms.Secret `json:"private_key,omitempty"` KeyPassphrase *kms.Secret `json:"key_passphrase,omitempty"` forbiddenSelfUsernames []string `json:"-"` } func (c *SFTPFsConfig) getKeySigner() (ssh.Signer, error) { privPayload := c.PrivateKey.GetPayload() if privPayload == "" { return nil, nil } if key := c.KeyPassphrase.GetPayload(); key != "" { return ssh.ParsePrivateKeyWithPassphrase([]byte(privPayload), []byte(key)) } return ssh.ParsePrivateKey([]byte(privPayload)) } // HideConfidentialData hides confidential data func (c *SFTPFsConfig) HideConfidentialData() { if c.Password != nil { c.Password.Hide() } if c.PrivateKey != nil { c.PrivateKey.Hide() } if c.KeyPassphrase != nil { c.KeyPassphrase.Hide() } } func (c *SFTPFsConfig) setNilSecretsIfEmpty() { if c.Password != nil && c.Password.IsEmpty() { c.Password = nil } if c.PrivateKey != nil && c.PrivateKey.IsEmpty() { c.PrivateKey = nil } if c.KeyPassphrase != nil && c.KeyPassphrase.IsEmpty() { 
c.KeyPassphrase = nil } } func (c *SFTPFsConfig) isEqual(other SFTPFsConfig) bool { if c.Endpoint != other.Endpoint { return false } if c.Username != other.Username { return false } if c.Prefix != other.Prefix { return false } if c.DisableCouncurrentReads != other.DisableCouncurrentReads { return false } if c.BufferSize != other.BufferSize { return false } if len(c.Fingerprints) != len(other.Fingerprints) { return false } for _, fp := range c.Fingerprints { if !slices.Contains(other.Fingerprints, fp) { return false } } c.setEmptyCredentialsIfNil() other.setEmptyCredentialsIfNil() if !c.Password.IsEqual(other.Password) { return false } if !c.KeyPassphrase.IsEqual(other.KeyPassphrase) { return false } return c.PrivateKey.IsEqual(other.PrivateKey) } func (c *SFTPFsConfig) setEmptyCredentialsIfNil() { if c.Password == nil { c.Password = kms.NewEmptySecret() } if c.PrivateKey == nil { c.PrivateKey = kms.NewEmptySecret() } if c.KeyPassphrase == nil { c.KeyPassphrase = kms.NewEmptySecret() } } func (c *SFTPFsConfig) isSameResource(other SFTPFsConfig) bool { if c.EqualityCheckMode > 0 || other.EqualityCheckMode > 0 { if c.Username != other.Username { return false } } return c.Endpoint == other.Endpoint } // validate returns an error if the configuration is not valid func (c *SFTPFsConfig) validate() error { c.setEmptyCredentialsIfNil() if c.Endpoint == "" { return util.NewI18nError(errors.New("endpoint cannot be empty"), util.I18nErrorEndpointRequired) } if !strings.Contains(c.Endpoint, ":") { c.Endpoint += ":22" } _, _, err := net.SplitHostPort(c.Endpoint) if err != nil { return util.NewI18nError(fmt.Errorf("invalid endpoint: %v", err), util.I18nErrorEndpointInvalid) } if c.Username == "" { return util.NewI18nError(errors.New("username cannot be empty"), util.I18nErrorFsUsernameRequired) } if c.BufferSize < 0 || c.BufferSize > 16 { return errors.New("invalid buffer_size, valid range is 0-16") } if !isEqualityCheckModeValid(c.EqualityCheckMode) { return errors.New("invalid 
equality_check_mode") } if err := c.validateCredentials(); err != nil { return err } if c.Prefix != "" { c.Prefix = util.CleanPath(c.Prefix) } else { c.Prefix = "/" } return c.validatePrivateKey() } func (c *SFTPFsConfig) validatePrivateKey() error { if c.PrivateKey.IsPlain() { signer, err := c.getKeySigner() if err != nil { return util.NewI18nError(fmt.Errorf("invalid private key: %w", err), util.I18nErrorPrivKeyInvalid) } if signer != nil { if key, ok := signer.PublicKey().(ssh.CryptoPublicKey); ok { cryptoKey := key.CryptoPublicKey() if rsaKey, ok := cryptoKey.(*rsa.PublicKey); ok { if size := rsaKey.N.BitLen(); size < 2048 { return util.NewI18nError( fmt.Errorf("rsa key with size %d not accepted, minimum 2048", size), util.I18nErrorKeySizeInvalid, ) } } } } } return nil } func (c *SFTPFsConfig) validateCredentials() error { if c.Password.IsEmpty() && c.PrivateKey.IsEmpty() { return util.NewI18nError(errors.New("credentials cannot be empty"), util.I18nErrorFsCredentialsRequired) } if c.Password.IsEncrypted() && !c.Password.IsValid() { return errors.New("invalid encrypted password") } if !c.Password.IsEmpty() && !c.Password.IsValidInput() { return errors.New("invalid password") } if c.PrivateKey.IsEncrypted() && !c.PrivateKey.IsValid() { return errors.New("invalid encrypted private key") } if !c.PrivateKey.IsEmpty() && !c.PrivateKey.IsValidInput() { return errors.New("invalid private key") } if c.KeyPassphrase.IsEncrypted() && !c.KeyPassphrase.IsValid() { return errors.New("invalid encrypted private key passphrase") } if !c.KeyPassphrase.IsEmpty() && !c.KeyPassphrase.IsValidInput() { return errors.New("invalid private key passphrase") } return nil } // ValidateAndEncryptCredentials validates the config and encrypts credentials if they are in plain text func (c *SFTPFsConfig) ValidateAndEncryptCredentials(additionalData string) error { if err := c.validate(); err != nil { var errI18n *util.I18nError errValidation := util.NewValidationError(fmt.Sprintf("could not 
validate SFTP fs config: %v", err)) if errors.As(err, &errI18n) { return util.NewI18nError(errValidation, errI18n.Message) } return util.NewI18nError(errValidation, util.I18nErrorFsValidation) } if c.Password.IsPlain() { c.Password.SetAdditionalData(additionalData) if err := c.Password.Encrypt(); err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs password: %v", err)), util.I18nErrorFsValidation, ) } } if c.PrivateKey.IsPlain() { c.PrivateKey.SetAdditionalData(additionalData) if err := c.PrivateKey.Encrypt(); err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs private key: %v", err)), util.I18nErrorFsValidation, ) } } if c.KeyPassphrase.IsPlain() { c.KeyPassphrase.SetAdditionalData(additionalData) if err := c.KeyPassphrase.Encrypt(); err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs private key passphrase: %v", err)), util.I18nErrorFsValidation, ) } } return nil } // getUniqueID returns an hash of the settings used to connect to the SFTP server func (c *SFTPFsConfig) getUniqueID(partition int) string { h := sha256.New() var b bytes.Buffer b.WriteString(c.Endpoint) b.WriteString(c.Username) b.WriteString(strings.Join(c.Fingerprints, "")) b.WriteString(strconv.FormatBool(c.DisableCouncurrentReads)) b.WriteString(strconv.FormatInt(c.BufferSize, 10)) b.WriteString(c.Password.GetPayload()) b.WriteString(c.PrivateKey.GetPayload()) b.WriteString(c.KeyPassphrase.GetPayload()) if allowSelfConnections != 0 { b.WriteString(strings.Join(c.forbiddenSelfUsernames, "")) } b.WriteString(strconv.Itoa(partition)) h.Write(b.Bytes()) return hex.EncodeToString(h.Sum(nil)) } // SFTPFs is a Fs implementation for SFTP backends type SFTPFs struct { connectionID string // if not empty this fs is mouted as virtual folder in the specified path mountPath string localTempDir string config *SFTPFsConfig conn *sftpConnection } // NewSFTPFs returns 
an SFTPFs object that allows to interact with an SFTP server func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUsernames []string, config SFTPFsConfig) (Fs, error) { if localTempDir == "" { localTempDir = getLocalTempDir() } if err := config.validate(); err != nil { return nil, err } if !config.Password.IsEmpty() { if err := config.Password.TryDecrypt(); err != nil { return nil, err } } if !config.PrivateKey.IsEmpty() { if err := config.PrivateKey.TryDecrypt(); err != nil { return nil, err } } if !config.KeyPassphrase.IsEmpty() { if err := config.KeyPassphrase.TryDecrypt(); err != nil { return nil, err } } conn, err := sftpConnsCache.Get(&config, connectionID) if err != nil { return nil, err } config.forbiddenSelfUsernames = forbiddenSelfUsernames sftpFs := &SFTPFs{ connectionID: connectionID, mountPath: getMountPath(mountPath), localTempDir: localTempDir, config: &config, conn: conn, } err = sftpFs.createConnection() if err != nil { sftpFs.Close() //nolint:errcheck } return sftpFs, err } // Name returns the name for the Fs implementation func (fs *SFTPFs) Name() string { return fmt.Sprintf(`%s %q@%q`, sftpFsName, fs.config.Username, fs.config.Endpoint) } // ConnectionID returns the connection ID associated to this Fs implementation func (fs *SFTPFs) ConnectionID() string { return fs.connectionID } // Stat returns a FileInfo describing the named file func (fs *SFTPFs) Stat(name string) (os.FileInfo, error) { client, err := fs.conn.getClient() if err != nil { return nil, err } return client.Stat(name) } // Lstat returns a FileInfo describing the named file func (fs *SFTPFs) Lstat(name string) (os.FileInfo, error) { client, err := fs.conn.getClient() if err != nil { return nil, err } return client.Lstat(name) } // Open opens the named file for reading func (fs *SFTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) { client, err := fs.conn.getClient() if err != nil { return nil, nil, nil, err } f, err := client.Open(name) if 
err != nil { return nil, nil, nil, err } if offset > 0 { _, err = f.Seek(offset, io.SeekStart) if err != nil { f.Close() return nil, nil, nil, err } } if fs.config.BufferSize == 0 { return f, nil, nil, nil } r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { f.Close() return nil, nil, nil, err } p := NewPipeReader(r) go func() { // if we enable buffering the client stalls //br := bufio.NewReaderSize(f, int(fs.config.BufferSize)*1024*1024) //n, err := fs.copy(w, br) n, err := io.Copy(w, f) w.CloseWithError(err) //nolint:errcheck f.Close() fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %v", name, n, err) }() return nil, p, nil, nil } // Create creates or opens the named file for writing func (fs *SFTPFs) Create(name string, flag, _ int) (File, PipeWriter, func(), error) { client, err := fs.conn.getClient() if err != nil { return nil, nil, nil, err } if fs.config.BufferSize == 0 { var f File if flag == 0 { f, err = client.Create(name) } else { f, err = client.OpenFile(name, flag) } return f, nil, nil, err } // buffering is enabled f, err := client.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC) if err != nil { return nil, nil, nil, err } r, w, err := createPipeFn(fs.localTempDir, 0) if err != nil { f.Close() return nil, nil, nil, err } p := NewPipeWriter(w) go func() { bw := bufio.NewWriterSize(f, int(fs.config.BufferSize)*1024*1024) // we don't use io.Copy since bufio.Writer implements io.WriterTo and // so it calls the sftp.File WriteTo method without buffering n, err := doCopy(bw, r, nil) errFlush := bw.Flush() if err == nil && errFlush != nil { err = errFlush } var errTruncate error if err != nil { errTruncate = f.Truncate(n) } errClose := f.Close() if err == nil && errClose != nil { err = errClose } r.CloseWithError(err) //nolint:errcheck p.Done(err) fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %v err truncate: %v", name, n, err, errTruncate) }() return nil, p, nil, nil } // Rename 
renames (moves) source to target. func (fs *SFTPFs) Rename(source, target string, checks int) (int, int64, error) { if source == target { return -1, -1, nil } client, err := fs.conn.getClient() if err != nil { return -1, -1, err } if _, ok := client.HasExtension("posix-rename@openssh.com"); ok { err := client.PosixRename(source, target) if checks&CheckUpdateModTime != 0 && err == nil { fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck } return -1, -1, err } err = client.Rename(source, target) if checks&CheckUpdateModTime != 0 && err == nil { fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck } return -1, -1, err } // Remove removes the named file or (empty) directory. func (fs *SFTPFs) Remove(name string, isDir bool) error { client, err := fs.conn.getClient() if err != nil { return err } if isDir { return client.RemoveDirectory(name) } return client.Remove(name) } // Mkdir creates a new directory with the specified name and default permissions func (fs *SFTPFs) Mkdir(name string) error { client, err := fs.conn.getClient() if err != nil { return err } return client.Mkdir(name) } // Symlink creates source as a symbolic link to target. func (fs *SFTPFs) Symlink(source, target string) error { client, err := fs.conn.getClient() if err != nil { return err } return client.Symlink(source, target) } // Readlink returns the destination of the named symbolic link func (fs *SFTPFs) Readlink(name string) (string, error) { client, err := fs.conn.getClient() if err != nil { return "", err } resolved, err := client.ReadLink(name) if err != nil { return resolved, err } resolved = path.Clean(strings.ReplaceAll(resolved, "\\", "/")) if !path.IsAbs(resolved) { // we assume that multiple links are not followed resolved = path.Join(path.Dir(name), resolved) } return fs.GetRelativePath(resolved), nil } // Chown changes the numeric uid and gid of the named file. 
func (fs *SFTPFs) Chown(name string, uid int, gid int) error { client, err := fs.conn.getClient() if err != nil { return err } return client.Chown(name, uid, gid) } // Chmod changes the mode of the named file to mode. func (fs *SFTPFs) Chmod(name string, mode os.FileMode) error { client, err := fs.conn.getClient() if err != nil { return err } return client.Chmod(name, mode) } // Chtimes changes the access and modification times of the named file. func (fs *SFTPFs) Chtimes(name string, atime, mtime time.Time, _ bool) error { client, err := fs.conn.getClient() if err != nil { return err } return client.Chtimes(name, atime, mtime) } // Truncate changes the size of the named file. func (fs *SFTPFs) Truncate(name string, size int64) error { client, err := fs.conn.getClient() if err != nil { return err } return client.Truncate(name, size) } // ReadDir reads the directory named by dirname and returns // a list of directory entries. func (fs *SFTPFs) ReadDir(dirname string) (DirLister, error) { client, err := fs.conn.getClient() if err != nil { return nil, err } files, err := client.ReadDir(dirname) if err != nil { return nil, err } return &baseDirLister{files}, nil } // IsUploadResumeSupported returns true if resuming uploads is supported. func (fs *SFTPFs) IsUploadResumeSupported() bool { return fs.config.BufferSize == 0 } // IsConditionalUploadResumeSupported returns if resuming uploads is supported // for the specified size func (fs *SFTPFs) IsConditionalUploadResumeSupported(_ int64) bool { return fs.IsUploadResumeSupported() } // IsAtomicUploadSupported returns true if atomic upload is supported. 
func (fs *SFTPFs) IsAtomicUploadSupported() bool {
	atomicAllowed := fs.config.BufferSize == 0
	return atomicAllowed
}

// IsNotExist reports whether err wraps fs.ErrNotExist.
func (*SFTPFs) IsNotExist(err error) bool {
	return errors.Is(err, fs.ErrNotExist)
}

// IsPermission reports whether err is a path resolution error or wraps
// fs.ErrPermission.
func (*SFTPFs) IsPermission(err error) bool {
	_, isResolutionErr := err.(*pathResolutionError)
	return isResolutionErr || errors.Is(err, fs.ErrPermission)
}

// IsNotSupported reports whether err is exactly ErrVfsUnsupported.
func (*SFTPFs) IsNotSupported(err error) bool {
	return err != nil && err == ErrVfsUnsupported
}

// CheckRootPath prepares the local directory used for temporary files in
// buffered mode and, when a prefix other than "/" is configured, creates it
// on the remote server. It returns false if the remote prefix cannot be
// created.
func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool {
	// local directory for temporary files in buffer mode
	NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil).CheckRootPath(username, uid, gid)
	if fs.config.Prefix == "/" {
		return true
	}
	client, err := fs.conn.getClient()
	if err != nil {
		return false
	}
	err = client.MkdirAll(fs.config.Prefix)
	if err == nil {
		return true
	}
	fsLog(fs, logger.LevelDebug, "error creating root directory %q for user %q: %v", fs.config.Prefix, username, err)
	return false
}

// ScanRootDirContents counts the files below the configured prefix and sums
// their sizes.
func (fs *SFTPFs) ScanRootDirContents() (int, int64, error) {
	return fs.GetDirSize(fs.config.Prefix)
}

// CheckMetadata is a no-op for this backend and always succeeds.
func (*SFTPFs) CheckMetadata() error {
	return nil
}

// GetAtomicUploadPath returns a hidden, uniquely named sibling of name to be
// used as the temporary target of an atomic upload.
func (*SFTPFs) GetAtomicUploadPath(name string) string {
	tempName := ".sftpgo-upload." + xid.New().String() + "." + path.Base(name)
	return path.Join(path.Dir(name), tempName)
}

// GetRelativePath returns the path for a file relative to the sftp prefix if any.
// This is the path as seen by SFTPGo users func (fs *SFTPFs) GetRelativePath(name string) string { rel := path.Clean(name) if rel == "." { rel = "" } if !path.IsAbs(rel) { // If we have a relative path we assume it is already relative to the virtual root rel = "/" + rel } else if fs.config.Prefix != "/" { prefixDir := fs.config.Prefix if !strings.HasSuffix(prefixDir, "/") { prefixDir += "/" } if rel == fs.config.Prefix { rel = "/" } else if after, found := strings.CutPrefix(rel, prefixDir); found { rel = path.Clean("/" + after) } else { // Absolute path outside of the configured prefix fsLog(fs, logger.LevelWarn, "path %q is an absolute path outside %q", name, fs.config.Prefix) rel = "/" } } if fs.mountPath != "" { rel = path.Join(fs.mountPath, rel) } return rel } // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root func (fs *SFTPFs) Walk(root string, walkFn filepath.WalkFunc) error { client, err := fs.conn.getClient() if err != nil { return err } walker := client.Walk(root) for walker.Step() { err := walker.Err() if err != nil { return err } err = walkFn(walker.Path(), walker.Stat(), err) if err != nil { return err } } return nil } // Join joins any number of path elements into a single path func (*SFTPFs) Join(elem ...string) string { return path.Join(elem...) 
} // HasVirtualFolders returns true if folders are emulated func (*SFTPFs) HasVirtualFolders() bool { return false } // ResolvePath returns the matching filesystem path for the specified virtual path func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) { if fs.mountPath != "" { if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { virtualPath = after } } virtualPath = path.Clean("/" + virtualPath) fsPath := fs.Join(fs.config.Prefix, virtualPath) if fs.config.Prefix != "/" && fsPath != "/" { // we need to check if this path is a symlink outside the given prefix // or a file/dir inside a dir symlinked outside the prefix var validatedPath string var err error validatedPath, err = fs.getRealPath(fsPath) isNotExist := fs.IsNotExist(err) if err != nil && !isNotExist { fsLog(fs, logger.LevelError, "Invalid path resolution, original path %v resolved %q err: %v", virtualPath, fsPath, err) return "", err } else if isNotExist { for fs.IsNotExist(err) { validatedPath = path.Dir(validatedPath) if validatedPath == "/" { err = nil break } validatedPath, err = fs.getRealPath(validatedPath) } if err != nil { fsLog(fs, logger.LevelError, "Invalid path resolution, dir %q original path %q resolved %q err: %v", validatedPath, virtualPath, fsPath, err) return "", err } } if err := fs.isSubDir(validatedPath); err != nil { fsLog(fs, logger.LevelError, "Invalid path resolution, dir %q original path %q resolved %q err: %v", validatedPath, virtualPath, fsPath, err) return "", err } } return fsPath, nil } // RealPath implements the FsRealPather interface func (fs *SFTPFs) RealPath(p string) (string, error) { client, err := fs.conn.getClient() if err != nil { return "", err } resolved, err := client.RealPath(p) if err != nil { return "", err } resolved = path.Clean(strings.ReplaceAll(resolved, "\\", "/")) if fs.config.Prefix != "/" { if err := fs.isSubDir(resolved); err != nil { fsLog(fs, logger.LevelError, "Invalid real path resolution, original path %q 
resolved %q err: %v", p, resolved, err) return "", err } } return fs.GetRelativePath(resolved), nil } // getRealPath returns the real remote path trying to resolve symbolic links if any func (fs *SFTPFs) getRealPath(name string) (string, error) { client, err := fs.conn.getClient() if err != nil { return "", err } linksWalked := 0 for { info, err := client.Lstat(name) if err != nil { return name, err } if info.Mode()&os.ModeSymlink == 0 { return name, nil } resolvedLink, err := client.ReadLink(name) if err != nil { return name, fmt.Errorf("unable to resolve link to %q: %w", name, err) } resolvedLink = strings.ReplaceAll(resolvedLink, "\\", "/") resolvedLink = path.Clean(resolvedLink) if path.IsAbs(resolvedLink) { name = resolvedLink } else { name = path.Join(path.Dir(name), resolvedLink) } linksWalked++ if linksWalked > 10 { fsLog(fs, logger.LevelError, "unable to get real path, too many links: %d", linksWalked) return "", &pathResolutionError{err: "too many links"} } } } func (fs *SFTPFs) isSubDir(name string) error { if name == fs.config.Prefix { return nil } if len(name) < len(fs.config.Prefix) { err := fmt.Errorf("path %q is not inside: %q", name, fs.config.Prefix) return &pathResolutionError{err: err.Error()} } if !strings.HasPrefix(name, fs.config.Prefix+"/") { err := fmt.Errorf("path %q is not inside: %q", name, fs.config.Prefix) return &pathResolutionError{err: err.Error()} } return nil } // GetDirSize returns the number of files and the size for a folder // including any subfolders func (fs *SFTPFs) GetDirSize(dirname string) (int, int64, error) { numFiles := 0 size := int64(0) client, err := fs.conn.getClient() if err != nil { return numFiles, size, err } isDir, err := isDirectory(fs, dirname) if err == nil && isDir { walker := client.Walk(dirname) for walker.Step() { err := walker.Err() if err != nil { return numFiles, size, err } if walker.Stat().Mode().IsRegular() { size += walker.Stat().Size() numFiles++ if numFiles%1000 == 0 { fsLog(fs, 
logger.LevelDebug, "dirname %q scan in progress, files: %d, size: %d", dirname, numFiles, size) } } } } return numFiles, size, err } // GetMimeType returns the content type func (fs *SFTPFs) GetMimeType(name string) (string, error) { client, err := fs.conn.getClient() if err != nil { return "", err } f, err := client.OpenFile(name, os.O_RDONLY) if err != nil { return "", err } defer f.Close() var buf [512]byte n, err := io.ReadFull(f, buf[:]) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { return "", err } ctype := http.DetectContentType(buf[:n]) // Rewind file. _, err = f.Seek(0, io.SeekStart) return ctype, err } // GetAvailableDiskSize returns the available size for the specified path func (fs *SFTPFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { client, err := fs.conn.getClient() if err != nil { return nil, err } if _, ok := client.HasExtension("statvfs@openssh.com"); !ok { return nil, ErrStorageSizeUnavailable } return client.StatVFS(dirName) } // Close the connection func (fs *SFTPFs) Close() error { fs.conn.RemoveSession(fs.connectionID) return nil } func (fs *SFTPFs) createConnection() error { err := fs.conn.OpenConnection() if err != nil { fsLog(fs, logger.LevelError, "error opening connection: %v", err) return err } return nil } type sftpConnection struct { config *SFTPFsConfig logSender string sshClient *ssh.Client sftpClient *sftp.Client mu sync.RWMutex isConnected bool sessions map[string]bool lastActivity time.Time signer ssh.Signer } func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection { c := &sftpConnection{ config: config, logSender: fmt.Sprintf(`%s "%s@%s"`, sftpFsName, config.Username, config.Endpoint), isConnected: false, sessions: map[string]bool{}, lastActivity: time.Now().UTC(), signer: nil, } c.sessions[sessionID] = true return c } func (c *sftpConnection) OpenConnection() error { c.mu.Lock() defer c.mu.Unlock() return c.openConnNoLock() } func (c *sftpConnection) openConnNoLock() 
error { if c.isConnected { logger.Debug(c.logSender, "", "reusing connection") return nil } logger.Debug(c.logSender, "", "try to open a new connection") clientConfig := &ssh.ClientConfig{ User: c.config.Username, HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error { fp := ssh.FingerprintSHA256(key) if slices.Contains(sftpFingerprints, fp) { if allowSelfConnections == 0 { logger.Log(logger.LevelError, c.logSender, "", "SFTP self connections not allowed") return ErrSFTPLoop } if slices.Contains(c.config.forbiddenSelfUsernames, c.config.Username) { logger.Log(logger.LevelError, c.logSender, "", "SFTP loop or nested local SFTP folders detected, username %q, forbidden usernames: %+v", c.config.Username, c.config.forbiddenSelfUsernames) return ErrSFTPLoop } } if len(c.config.Fingerprints) > 0 { for _, provided := range c.config.Fingerprints { if provided == fp { return nil } } return fmt.Errorf("invalid fingerprint %q", fp) } logger.Log(logger.LevelWarn, c.logSender, "", "login without host key validation, please provide at least a fingerprint!") return nil }, Timeout: 15 * time.Second, ClientVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)), } if c.signer != nil { clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(c.signer)) } if pwd := c.config.Password.GetPayload(); pwd != "" { clientConfig.Auth = append(clientConfig.Auth, ssh.Password(pwd)) } supportedAlgos := ssh.SupportedAlgorithms() insecureAlgos := ssh.InsecureAlgorithms() // add all available ciphers, KEXs and MACs, they are negotiated according to the order clientConfig.Ciphers = append(supportedAlgos.Ciphers, ssh.InsecureCipherAES128CBC) clientConfig.KeyExchanges = append(supportedAlgos.KeyExchanges, insecureAlgos.KeyExchanges...) clientConfig.MACs = append(supportedAlgos.MACs, insecureAlgos.MACs...) 
sshClient, err := ssh.Dial("tcp", c.config.Endpoint, clientConfig) if err != nil { return fmt.Errorf("sftpfs: unable to connect: %w", err) } sftpClient, err := sftp.NewClient(sshClient, c.getClientOptions()...) if err != nil { sshClient.Close() return fmt.Errorf("sftpfs: unable to create SFTP client: %w", err) } c.sshClient = sshClient c.sftpClient = sftpClient c.isConnected = true go c.Wait() return nil } func (c *sftpConnection) getClientOptions() []sftp.ClientOption { var options []sftp.ClientOption if c.config.DisableCouncurrentReads { options = append(options, sftp.UseConcurrentReads(false)) logger.Debug(c.logSender, "", "disabling concurrent reads") } if c.config.BufferSize > 0 { options = append(options, sftp.UseConcurrentWrites(true)) logger.Debug(c.logSender, "", "enabling concurrent writes") } return options } func (c *sftpConnection) getClient() (*sftp.Client, error) { c.mu.Lock() defer c.mu.Unlock() if c.isConnected { return c.sftpClient, nil } err := c.openConnNoLock() return c.sftpClient, err } func (c *sftpConnection) Wait() { done := make(chan struct{}) go func() { var watchdogInProgress atomic.Bool ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: if watchdogInProgress.Load() { logger.Error(c.logSender, "", "watchdog still in progress, closing hanging connection") c.sshClient.Close() return } go func() { watchdogInProgress.Store(true) defer watchdogInProgress.Store(false) _, err := c.sftpClient.Getwd() if err != nil { logger.Error(c.logSender, "", "watchdog error: %v", err) } }() case <-done: logger.Debug(c.logSender, "", "quitting watchdog") return } } }() // we wait on the sftp client otherwise if the channel is closed but not the connection // we don't detect the event. 
err := c.sftpClient.Wait() logger.Log(logger.LevelDebug, c.logSender, "", "sftp channel closed: %v", err) close(done) c.mu.Lock() defer c.mu.Unlock() c.isConnected = false if c.sshClient != nil { c.sshClient.Close() } } func (c *sftpConnection) Close() error { c.mu.Lock() defer c.mu.Unlock() logger.Debug(c.logSender, "", "closing connection") var sftpErr, sshErr error if c.sftpClient != nil { sftpErr = c.sftpClient.Close() } if c.sshClient != nil { sshErr = c.sshClient.Close() } if sftpErr != nil { return sftpErr } c.isConnected = false return sshErr } func (c *sftpConnection) AddSession(sessionID string) { c.mu.Lock() defer c.mu.Unlock() c.sessions[sessionID] = true logger.Debug(c.logSender, "", "added session %s, active sessions: %d", sessionID, len(c.sessions)) } func (c *sftpConnection) RemoveSession(sessionID string) { c.mu.Lock() defer c.mu.Unlock() delete(c.sessions, sessionID) logger.Debug(c.logSender, "", "removed session %s, active sessions: %d", sessionID, len(c.sessions)) if len(c.sessions) == 0 { c.lastActivity = time.Now().UTC() } } func (c *sftpConnection) ActiveSessions() int { c.mu.RLock() defer c.mu.RUnlock() return len(c.sessions) } func (c *sftpConnection) GetLastActivity() time.Time { c.mu.RLock() defer c.mu.RUnlock() if len(c.sessions) > 0 { return time.Now().UTC() } logger.Debug(c.logSender, "", "last activity %s", c.lastActivity) return c.lastActivity } type sftpConnectionsCache struct { scheduler *cron.Cron sync.Mutex items map[string]*sftpConnection } func newSFTPConnectionCache() *sftpConnectionsCache { c := &sftpConnectionsCache{ scheduler: cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger)), items: make(map[string]*sftpConnection), } _, err := c.scheduler.AddFunc("@every 1m", c.Cleanup) util.PanicOnError(err) c.scheduler.Start() return c } func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) (*sftpConnection, error) { partition := 0 key := config.getUniqueID(partition) c.Lock() defer 
c.Unlock() for { if val, ok := c.items[key]; ok { activeSessions := val.ActiveSessions() if activeSessions < maxSessionsPerConnection { logger.Debug(logSenderSFTPCache, "", "reusing connection for session ID %q, key %s, active sessions %d, active connections: %d", sessionID, key, activeSessions+1, len(c.items)) val.AddSession(sessionID) return val, nil } partition++ key = config.getUniqueID(partition) logger.Debug(logSenderSFTPCache, "", "connection full, generated new key for partition: %d, active sessions: %d, key: %s", partition, activeSessions, key) } else { conn := newSFTPConnection(config, sessionID) signer, err := config.getKeySigner() if err != nil { return nil, fmt.Errorf("sftpfs: unable to parse the private key: %w", err) } conn.signer = signer c.items[key] = conn logger.Debug(logSenderSFTPCache, "", "adding new connection for session ID %q, partition: %d, key: %s, active connections: %d", sessionID, partition, key, len(c.items)) return conn, nil } } } func (c *sftpConnectionsCache) Cleanup() { c.Lock() var connectionsToClose []*sftpConnection for k, conn := range c.items { if val := conn.GetLastActivity(); val.Before(time.Now().Add(-30 * time.Second)) { delete(c.items, k) logger.Debug(logSenderSFTPCache, "", "removed connection with key %s, last activity %s, active connections: %d", k, val, len(c.items)) connectionsToClose = append(connectionsToClose, conn) } } c.Unlock() for _, conn := range connectionsToClose { err := conn.Close() logger.Debug(logSenderSFTPCache, "", "connection closed, err: %v", err) } } ================================================ FILE: internal/vfs/statvfs_fallback.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !darwin && !linux && !freebsd package vfs import ( "github.com/pkg/sftp" "github.com/shirou/gopsutil/v3/disk" ) const bsize = uint64(4096) func getStatFS(path string) (*sftp.StatVFS, error) { usage, err := disk.Usage(path) if err != nil { return nil, err } // we assume block size = 4096 blocks := usage.Total / bsize bfree := usage.Free / bsize files := usage.InodesTotal ffree := usage.InodesFree if files == 0 { // these assumptions are wrong but still better than returning 0 files = blocks / 4 ffree = bfree / 4 } return &sftp.StatVFS{ Bsize: bsize, Frsize: bsize, Blocks: blocks, Bfree: bfree, Bavail: bfree, Files: files, Ffree: ffree, Favail: ffree, Namemax: 255, }, nil } ================================================ FILE: internal/vfs/statvfs_linux.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build linux package vfs import ( "github.com/pkg/sftp" "golang.org/x/sys/unix" ) func getStatFS(path string) (*sftp.StatVFS, error) { stat := unix.Statfs_t{} err := unix.Statfs(path, &stat) if err != nil { return nil, err } return &sftp.StatVFS{ Bsize: uint64(stat.Bsize), Frsize: uint64(stat.Frsize), Blocks: stat.Blocks, Bfree: stat.Bfree, Bavail: stat.Bavail, Files: stat.Files, Ffree: stat.Ffree, Favail: stat.Ffree, // not sure how to calculate Favail Flag: uint64(stat.Flags), Namemax: uint64(stat.Namelen), }, nil } ================================================ FILE: internal/vfs/statvfs_unix.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
//go:build freebsd || darwin package vfs import ( "github.com/pkg/sftp" "golang.org/x/sys/unix" ) func getStatFS(path string) (*sftp.StatVFS, error) { stat := unix.Statfs_t{} err := unix.Statfs(path, &stat) if err != nil { return nil, err } return &sftp.StatVFS{ Bsize: uint64(stat.Bsize), Frsize: uint64(stat.Bsize), Blocks: stat.Blocks, Bfree: stat.Bfree, Bavail: uint64(stat.Bavail), Files: stat.Files, Ffree: uint64(stat.Ffree), Favail: uint64(stat.Ffree), // not sure how to calculate Favail Flag: uint64(stat.Flags), Namemax: 255, // we use a conservative value here }, nil } ================================================ FILE: internal/vfs/sys_unix.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . //go:build !windows package vfs import ( "errors" "golang.org/x/sys/unix" ) func isCrossDeviceError(err error) bool { return errors.Is(err, unix.EXDEV) } func isInvalidNameError(_ error) bool { return false } ================================================ FILE: internal/vfs/sys_windows.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. 
// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package vfs import ( "errors" "golang.org/x/sys/windows" ) func isCrossDeviceError(err error) bool { return errors.Is(err, windows.ERROR_NOT_SAME_DEVICE) } func isInvalidNameError(err error) bool { if err == nil { return false } return errors.Is(err, windows.ERROR_INVALID_NAME) } ================================================ FILE: internal/vfs/vfs.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package vfs provides local and remote filesystems support
package vfs

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/eikenb/pipeat"
	"github.com/pkg/sftp"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

const (
	// dirMimeType is the conventional MIME type for directories
	// (presumably what GetMimeType implementations report for them).
	dirMimeType = "inode/directory"
	// Backend name prefixes, matched against Fs.Name() with strings.HasPrefix
	// to recognise object-storage backends.
	s3fsName     = "S3Fs"
	gcsfsName    = "GCSFs"
	azBlobFsName = "AzureBlobFs"
	// lastModifiedField is presumably the object metadata key that stores a
	// file's modification time on object storage backends — TODO confirm.
	lastModifiedField = "sftpgo_last_modified"
	// preResumeTimeout: usage is not visible here; presumably a time limit for
	// the work done before resuming an upload — TODO confirm.
	preResumeTimeout = 90 * time.Second
	// ListerBatchSize defines the default limit for DirLister implementations
	ListerBatchSize = 1000
)

// Additional checks for files.
// The values are distinct powers of two; presumably they are combined as bit
// flags in the checks argument of Fs.Create and Fs.Rename.
const (
	CheckParentDir     = 1
	CheckResume        = 2
	CheckUpdateModTime = 4
)

var (
	// validAzAccessTier lists the values accepted for AzBlobFsConfig.AccessTier;
	// the empty string is accepted too.
	validAzAccessTier = []string{"", "Archive", "Hot", "Cool"}
	// ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size
	ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend")
	// ErrVfsUnsupported defines the error for an unsupported VFS operation
	ErrVfsUnsupported = errors.New("not supported")
	// errInvalidDirListerLimit is returned when a DirLister is asked for a
	// non-positive number of entries.
	errInvalidDirListerLimit = errors.New("dir lister: invalid limit, must be > 0")
	// The following package-level settings are set through the Set* functions
	// below.
	tempPath             string
	sftpFingerprints     []string
	allowSelfConnections int
	renameMode           int
	readMetadata         int
	resumeMaxSize        int64
	// uploadMode is a bit mask: HasImplicitAtomicUploads checks bit 4 for S3,
	// 8 for GCS and 16 for Azure Blob.
	uploadMode int
)

var (
	// createPipeFn creates the pipe used to stream data to and from remote
	// backends; it is a variable, presumably so it can be replaced (e.g. in
	// tests). The size argument is ignored by this default implementation.
	createPipeFn = func(dirPath string, _ int64) (pipeReaderAt, pipeWriterAt, error) {
		return pipeat.PipeInDir(dirPath)
	}
)

// SetAllowSelfConnections sets the desired behaviour for self connections
func SetAllowSelfConnections(value int) {
	allowSelfConnections = value
}

// SetTempPath sets the path for temporary files
func SetTempPath(fsPath string) {
	tempPath = fsPath
}

// GetTempPath returns the path for temporary files
func GetTempPath() string {
	return tempPath
}

// SetSFTPFingerprints sets the SFTP host key fingerprints
func SetSFTPFingerprints(fp []string) {
	sftpFingerprints = fp
}

// SetRenameMode sets the rename mode
func SetRenameMode(val int) {
	renameMode = val
}

// SetReadMetadataMode sets the read metadata mode
func SetReadMetadataMode(val int) {
	readMetadata = val
}

// SetResumeMaxSize sets the max size allowed for resuming uploads for backends
// with immutable objects
func SetResumeMaxSize(val int64) {
	resumeMaxSize = val
}

// SetUploadMode sets the upload mode.
// The value is treated as a bit mask (see HasImplicitAtomicUploads).
func SetUploadMode(val int) {
	uploadMode = val
}

// Fs defines the interface for filesystem backends.
// Optional capabilities are exposed through the FsRealPather and FsFileCopier
// interfaces.
type Fs interface {
	Name() string
	ConnectionID() string
	Stat(name string) (os.FileInfo, error)
	Lstat(name string) (os.FileInfo, error)
	Open(name string, offset int64) (File, PipeReader, func(), error)
	Create(name string, flag, checks int) (File, PipeWriter, func(), error)
	Rename(source, target string, checks int) (int, int64, error)
	Remove(name string, isDir bool) error
	Mkdir(name string) error
	Symlink(source, target string) error
	Chown(name string, uid int, gid int) error
	Chmod(name string, mode os.FileMode) error
	Chtimes(name string, atime, mtime time.Time, isUploading bool) error
	Truncate(name string, size int64) error
	ReadDir(dirname string) (DirLister, error)
	Readlink(name string) (string, error)
	IsUploadResumeSupported() bool
	IsConditionalUploadResumeSupported(size int64) bool
	IsAtomicUploadSupported() bool
	CheckRootPath(username string, uid int, gid int) bool
	ResolvePath(virtualPath string) (string, error)
	IsNotExist(err error) bool
	IsPermission(err error) bool
	IsNotSupported(err error) bool
	ScanRootDirContents() (int, int64, error)
	GetDirSize(dirname string) (int, int64, error)
	GetAtomicUploadPath(name string) string
	GetRelativePath(name string) string
	Walk(root string, walkFn filepath.WalkFunc) error
	Join(elem ...string) string
	HasVirtualFolders() bool
	GetMimeType(name string) (string, error)
	GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error)
	Close() error
}

// FsRealPather is a Fs that implements the RealPath method.
type FsRealPather interface { Fs RealPath(p string) (string, error) } // FsFileCopier is a Fs that implements the CopyFile method. type FsFileCopier interface { Fs CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) } // File defines an interface representing a SFTPGo file type File interface { io.Reader io.Writer io.Closer io.ReaderAt io.WriterAt io.Seeker Stat() (os.FileInfo, error) Name() string Truncate(size int64) error } // PipeWriter defines an interface representing a SFTPGo pipe writer type PipeWriter interface { io.Writer io.WriterAt io.Closer Done(err error) GetWrittenBytes() int64 } // PipeReader defines an interface representing a SFTPGo pipe reader type PipeReader interface { io.Reader io.ReaderAt io.Closer setMetadata(value map[string]string) setMetadataFromPointerVal(value map[string]*string) Metadata() map[string]string } type pipeReaderAt interface { Read(p []byte) (int, error) ReadAt(p []byte, offset int64) (int, error) GetReadedBytes() int64 Close() error CloseWithError(err error) error } type pipeWriterAt interface { Write(p []byte) (int, error) WriteAt(p []byte, offset int64) (int, error) GetWrittenBytes() int64 Close() error CloseWithError(err error) error } // DirLister defines an interface for a directory lister type DirLister interface { Next(limit int) ([]os.FileInfo, error) Close() error } // Metadater defines an interface to implement to return metadata for a file type Metadater interface { Metadata() map[string]string } type baseDirLister struct { cache []os.FileInfo } func (l *baseDirLister) Next(limit int) ([]os.FileInfo, error) { if limit <= 0 { return nil, errInvalidDirListerLimit } if len(l.cache) >= limit { return l.returnFromCache(limit), nil } return l.returnFromCache(limit), io.EOF } func (l *baseDirLister) returnFromCache(limit int) []os.FileInfo { if len(l.cache) >= limit { result := l.cache[:limit] l.cache = l.cache[limit:] return result } result := l.cache l.cache = nil return result } func (l 
*baseDirLister) Close() error { l.cache = nil return nil } // QuotaCheckResult defines the result for a quota check type QuotaCheckResult struct { HasSpace bool AllowedSize int64 AllowedFiles int UsedSize int64 UsedFiles int QuotaSize int64 QuotaFiles int } // GetRemainingSize returns the remaining allowed size func (q *QuotaCheckResult) GetRemainingSize() int64 { if q.QuotaSize > 0 { return q.QuotaSize - q.UsedSize } return 0 } // GetRemainingFiles returns the remaining allowed files func (q *QuotaCheckResult) GetRemainingFiles() int { if q.QuotaFiles > 0 { return q.QuotaFiles - q.UsedFiles } return 0 } // S3FsConfig defines the configuration for S3 based filesystem type S3FsConfig struct { sdk.BaseS3FsConfig AccessSecret *kms.Secret `json:"access_secret,omitempty"` SSECustomerKey *kms.Secret `json:"sse_customer_key,omitempty"` } // HideConfidentialData hides confidential data func (c *S3FsConfig) HideConfidentialData() { if c.AccessSecret != nil { c.AccessSecret.Hide() } if c.SSECustomerKey != nil { c.SSECustomerKey.Hide() } } func (c *S3FsConfig) isEqual(other S3FsConfig) bool { if c.Bucket != other.Bucket { return false } if c.KeyPrefix != other.KeyPrefix { return false } if c.Region != other.Region { return false } if c.AccessKey != other.AccessKey { return false } if c.RoleARN != other.RoleARN { return false } if c.Endpoint != other.Endpoint { return false } if c.StorageClass != other.StorageClass { return false } if c.ACL != other.ACL { return false } if !c.areMultipartFieldsEqual(other) { return false } if c.ForcePathStyle != other.ForcePathStyle { return false } if c.SkipTLSVerify != other.SkipTLSVerify { return false } return c.isSecretEqual(other) } func (c *S3FsConfig) areMultipartFieldsEqual(other S3FsConfig) bool { if c.UploadPartSize != other.UploadPartSize { return false } if c.UploadConcurrency != other.UploadConcurrency { return false } if c.DownloadConcurrency != other.DownloadConcurrency { return false } if c.DownloadPartSize != 
other.DownloadPartSize { return false } if c.DownloadPartMaxTime != other.DownloadPartMaxTime { return false } if c.UploadPartMaxTime != other.UploadPartMaxTime { return false } return true } func (c *S3FsConfig) isSecretEqual(other S3FsConfig) bool { if c.SSECustomerKey == nil { c.SSECustomerKey = kms.NewEmptySecret() } if other.SSECustomerKey == nil { other.SSECustomerKey = kms.NewEmptySecret() } if !c.SSECustomerKey.IsEqual(other.SSECustomerKey) { return false } if c.AccessSecret == nil { c.AccessSecret = kms.NewEmptySecret() } if other.AccessSecret == nil { other.AccessSecret = kms.NewEmptySecret() } return c.AccessSecret.IsEqual(other.AccessSecret) } func (c *S3FsConfig) checkCredentials() error { if c.AccessKey == "" && !c.AccessSecret.IsEmpty() { return util.NewI18nError( errors.New("access_key cannot be empty with access_secret not empty"), util.I18nErrorAccessKeyRequired, ) } if c.AccessSecret.IsEmpty() && c.AccessKey != "" { return util.NewI18nError( errors.New("access_secret cannot be empty with access_key not empty"), util.I18nErrorAccessSecretRequired, ) } if c.AccessSecret.IsEncrypted() && !c.AccessSecret.IsValid() { return errors.New("invalid encrypted access_secret") } if !c.AccessSecret.IsEmpty() && !c.AccessSecret.IsValidInput() { return errors.New("invalid access_secret") } if c.SSECustomerKey.IsEncrypted() && !c.SSECustomerKey.IsValid() { return errors.New("invalid encrypted sse_customer_key") } if !c.SSECustomerKey.IsEmpty() && !c.SSECustomerKey.IsValidInput() { return errors.New("invalid sse_customer_key") } return nil } // ValidateAndEncryptCredentials validates the configuration and encrypts access secret if it is in plain text func (c *S3FsConfig) ValidateAndEncryptCredentials(additionalData string) error { if err := c.validate(); err != nil { var errI18n *util.I18nError errValidation := util.NewValidationError(fmt.Sprintf("could not validate s3config: %v", err)) if errors.As(err, &errI18n) { return util.NewI18nError(errValidation, 
errI18n.Message) } return util.NewI18nError(errValidation, util.I18nErrorFsValidation) } if c.AccessSecret.IsPlain() { c.AccessSecret.SetAdditionalData(additionalData) err := c.AccessSecret.Encrypt() if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt s3 access secret: %v", err)), util.I18nErrorFsValidation, ) } } if c.SSECustomerKey.IsPlain() { c.SSECustomerKey.SetAdditionalData(additionalData) err := c.SSECustomerKey.Encrypt() if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt s3 SSE customer key: %v", err)), util.I18nErrorFsValidation, ) } } return nil } func (c *S3FsConfig) checkPartSizeAndConcurrency() error { if c.UploadPartSize != 0 && (c.UploadPartSize < 5 || c.UploadPartSize > 2000) { return util.NewI18nError( errors.New("upload_part_size cannot be != 0, lower than 5 (MB) or greater than 2000 (MB)"), util.I18nErrorULPartSizeInvalid, ) } if c.UploadConcurrency < 0 || c.UploadConcurrency > 64 { return util.NewI18nError( fmt.Errorf("invalid upload concurrency: %v", c.UploadConcurrency), util.I18nErrorULConcurrencyInvalid, ) } if c.DownloadPartSize != 0 && (c.DownloadPartSize < 5 || c.DownloadPartSize > 2000) { return util.NewI18nError( errors.New("download_part_size cannot be != 0, lower than 5 (MB) or greater than 2000 (MB)"), util.I18nErrorDLPartSizeInvalid, ) } if c.DownloadConcurrency < 0 || c.DownloadConcurrency > 64 { return util.NewI18nError( fmt.Errorf("invalid download concurrency: %v", c.DownloadConcurrency), util.I18nErrorDLConcurrencyInvalid, ) } return nil } func (c *S3FsConfig) isSameResource(other S3FsConfig) bool { if c.Bucket != other.Bucket { return false } if c.Endpoint != other.Endpoint { return false } return c.Region == other.Region } // validate returns an error if the configuration is not valid func (c *S3FsConfig) validate() error { if c.AccessSecret == nil { c.AccessSecret = kms.NewEmptySecret() } if c.SSECustomerKey == nil { c.SSECustomerKey = 
kms.NewEmptySecret() } if c.Bucket == "" { return util.NewI18nError(errors.New("bucket cannot be empty"), util.I18nErrorBucketRequired) } // the region may be embedded within the endpoint for some S3 compatible // object storage, for example B2 if c.Endpoint == "" && c.Region == "" { return util.NewI18nError(errors.New("region cannot be empty"), util.I18nErrorRegionRequired) } if err := c.checkCredentials(); err != nil { return err } if c.KeyPrefix != "" { if strings.HasPrefix(c.KeyPrefix, "/") { return util.NewI18nError(errors.New("key_prefix cannot start with /"), util.I18nErrorKeyPrefixInvalid) } c.KeyPrefix = path.Clean(c.KeyPrefix) if !strings.HasSuffix(c.KeyPrefix, "/") { c.KeyPrefix += "/" } } c.StorageClass = strings.TrimSpace(c.StorageClass) c.ACL = strings.TrimSpace(c.ACL) return c.checkPartSizeAndConcurrency() } // GCSFsConfig defines the configuration for Google Cloud Storage based filesystem type GCSFsConfig struct { sdk.BaseGCSFsConfig Credentials *kms.Secret `json:"credentials,omitempty"` } // HideConfidentialData hides confidential data func (c *GCSFsConfig) HideConfidentialData() { if c.Credentials != nil { c.Credentials.Hide() } } // ValidateAndEncryptCredentials validates the configuration and encrypts credentials if they are in plain text func (c *GCSFsConfig) ValidateAndEncryptCredentials(additionalData string) error { if err := c.validate(); err != nil { var errI18n *util.I18nError errValidation := util.NewValidationError(fmt.Sprintf("could not validate GCS config: %v", err)) if errors.As(err, &errI18n) { return util.NewI18nError(errValidation, errI18n.Message) } return util.NewI18nError(errValidation, util.I18nErrorFsValidation) } if c.Credentials.IsPlain() { c.Credentials.SetAdditionalData(additionalData) err := c.Credentials.Encrypt() if err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt GCS credentials: %v", err)), util.I18nErrorFsValidation, ) } } return nil } func (c *GCSFsConfig) isEqual(other 
GCSFsConfig) bool { if c.Bucket != other.Bucket { return false } if c.KeyPrefix != other.KeyPrefix { return false } if c.AutomaticCredentials != other.AutomaticCredentials { return false } if c.StorageClass != other.StorageClass { return false } if c.ACL != other.ACL { return false } if c.UploadPartSize != other.UploadPartSize { return false } if c.UploadPartMaxTime != other.UploadPartMaxTime { return false } if c.Credentials == nil { c.Credentials = kms.NewEmptySecret() } if other.Credentials == nil { other.Credentials = kms.NewEmptySecret() } return c.Credentials.IsEqual(other.Credentials) } func (c *GCSFsConfig) isSameResource(other GCSFsConfig) bool { return c.Bucket == other.Bucket } // validate returns an error if the configuration is not valid func (c *GCSFsConfig) validate() error { //nolint:gocyclo if c.Credentials == nil || c.AutomaticCredentials == 1 { c.Credentials = kms.NewEmptySecret() } if c.Bucket == "" { return util.NewI18nError(errors.New("bucket cannot be empty"), util.I18nErrorBucketRequired) } if c.KeyPrefix != "" { if strings.HasPrefix(c.KeyPrefix, "/") { return util.NewI18nError(errors.New("key_prefix cannot start with /"), util.I18nErrorKeyPrefixInvalid) } c.KeyPrefix = path.Clean(c.KeyPrefix) if !strings.HasSuffix(c.KeyPrefix, "/") { c.KeyPrefix += "/" } } if c.Credentials.IsEncrypted() && !c.Credentials.IsValid() { return errors.New("invalid encrypted credentials") } if c.AutomaticCredentials == 0 && !c.Credentials.IsValidInput() { return util.NewI18nError(errors.New("invalid credentials"), util.I18nErrorFsCredentialsRequired) } c.StorageClass = strings.TrimSpace(c.StorageClass) c.ACL = strings.TrimSpace(c.ACL) if c.UploadPartSize < 0 || c.UploadPartSize > 2000 { c.UploadPartSize = 0 } if c.UploadPartMaxTime < 0 { c.UploadPartMaxTime = 0 } return nil } // AzBlobFsConfig defines the configuration for Azure Blob Storage based filesystem type AzBlobFsConfig struct { sdk.BaseAzBlobFsConfig // Storage Account Key leave blank to use SAS URL. 
// The access key is stored encrypted based on the kms configuration AccountKey *kms.Secret `json:"account_key,omitempty"` // Shared access signature URL, leave blank if using account/key SASURL *kms.Secret `json:"sas_url,omitempty"` } // HideConfidentialData hides confidential data func (c *AzBlobFsConfig) HideConfidentialData() { if c.AccountKey != nil { c.AccountKey.Hide() } if c.SASURL != nil { c.SASURL.Hide() } } func (c *AzBlobFsConfig) isEqual(other AzBlobFsConfig) bool { if c.Container != other.Container { return false } if c.AccountName != other.AccountName { return false } if c.Endpoint != other.Endpoint { return false } if c.SASURL.IsEmpty() { c.SASURL = kms.NewEmptySecret() } if other.SASURL.IsEmpty() { other.SASURL = kms.NewEmptySecret() } if !c.SASURL.IsEqual(other.SASURL) { return false } if c.KeyPrefix != other.KeyPrefix { return false } if c.UploadPartSize != other.UploadPartSize { return false } if c.UploadConcurrency != other.UploadConcurrency { return false } if c.DownloadPartSize != other.DownloadPartSize { return false } if c.DownloadConcurrency != other.DownloadConcurrency { return false } if c.UseEmulator != other.UseEmulator { return false } if c.AccessTier != other.AccessTier { return false } return c.isSecretEqual(other) } func (c *AzBlobFsConfig) isSecretEqual(other AzBlobFsConfig) bool { if c.AccountKey == nil { c.AccountKey = kms.NewEmptySecret() } if other.AccountKey == nil { other.AccountKey = kms.NewEmptySecret() } return c.AccountKey.IsEqual(other.AccountKey) } // ValidateAndEncryptCredentials validates the configuration and encrypts access secret if it is in plain text func (c *AzBlobFsConfig) ValidateAndEncryptCredentials(additionalData string) error { if err := c.validate(); err != nil { var errI18n *util.I18nError errValidation := util.NewValidationError(fmt.Sprintf("could not validate Azure Blob config: %v", err)) if errors.As(err, &errI18n) { return util.NewI18nError(errValidation, errI18n.Message) } return 
util.NewI18nError(errValidation, util.I18nErrorFsValidation) } if c.AccountKey.IsPlain() { c.AccountKey.SetAdditionalData(additionalData) if err := c.AccountKey.Encrypt(); err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob account key: %v", err)), util.I18nErrorFsValidation, ) } } if c.SASURL.IsPlain() { c.SASURL.SetAdditionalData(additionalData) if err := c.SASURL.Encrypt(); err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob SAS URL: %v", err)), util.I18nErrorFsValidation, ) } } return nil } func (c *AzBlobFsConfig) checkCredentials() error { if c.SASURL.IsPlain() { _, err := url.Parse(c.SASURL.GetPayload()) if err != nil { return util.NewI18nError(err, util.I18nErrorSASURLInvalid) } return nil } if c.SASURL.IsEncrypted() && !c.SASURL.IsValid() { return errors.New("invalid encrypted sas_url") } if !c.SASURL.IsEmpty() { return nil } if c.AccountName == "" { return util.NewI18nError(errors.New("account name is required"), util.I18nErrorAccountNameRequired) } if c.AccountKey.IsEncrypted() && !c.AccountKey.IsValid() { return errors.New("invalid encrypted account_key") } return nil } func (c *AzBlobFsConfig) checkPartSizeAndConcurrency() error { if c.UploadPartSize < 0 || c.UploadPartSize > 2000 { return util.NewI18nError( fmt.Errorf("invalid upload part size: %v", c.UploadPartSize), util.I18nErrorULPartSizeInvalid, ) } if c.UploadConcurrency < 0 || c.UploadConcurrency > 64 { return util.NewI18nError( fmt.Errorf("invalid upload concurrency: %v", c.UploadConcurrency), util.I18nErrorULConcurrencyInvalid, ) } if c.DownloadPartSize < 0 || c.DownloadPartSize > 2000 { return util.NewI18nError( fmt.Errorf("invalid download part size: %v", c.DownloadPartSize), util.I18nErrorDLPartSizeInvalid, ) } if c.DownloadConcurrency < 0 || c.DownloadConcurrency > 64 { return util.NewI18nError( fmt.Errorf("invalid upload concurrency: %v", c.DownloadConcurrency), 
util.I18nErrorDLConcurrencyInvalid, ) } return nil } func (c *AzBlobFsConfig) tryDecrypt() error { if err := c.AccountKey.TryDecrypt(); err != nil { return fmt.Errorf("unable to decrypt account key: %w", err) } if err := c.SASURL.TryDecrypt(); err != nil { return fmt.Errorf("unable to decrypt SAS URL: %w", err) } return nil } func (c *AzBlobFsConfig) isSameResource(other AzBlobFsConfig) bool { if c.AccountName != other.AccountName { return false } if c.Endpoint != other.Endpoint { return false } if c.SASURL == nil { c.SASURL = kms.NewEmptySecret() } if other.SASURL == nil { other.SASURL = kms.NewEmptySecret() } return c.SASURL.GetPayload() == other.SASURL.GetPayload() } // validate returns an error if the configuration is not valid func (c *AzBlobFsConfig) validate() error { if c.AccountKey == nil { c.AccountKey = kms.NewEmptySecret() } if c.SASURL == nil { c.SASURL = kms.NewEmptySecret() } // container could be embedded within SAS URL we check this at runtime if c.SASURL.IsEmpty() && c.Container == "" { return util.NewI18nError(errors.New("container cannot be empty"), util.I18nErrorContainerRequired) } if err := c.checkCredentials(); err != nil { return err } if c.KeyPrefix != "" { if strings.HasPrefix(c.KeyPrefix, "/") { return util.NewI18nError(errors.New("key_prefix cannot start with /"), util.I18nErrorKeyPrefixInvalid) } c.KeyPrefix = path.Clean(c.KeyPrefix) if !strings.HasSuffix(c.KeyPrefix, "/") { c.KeyPrefix += "/" } } if err := c.checkPartSizeAndConcurrency(); err != nil { return err } if !slices.Contains(validAzAccessTier, c.AccessTier) { return fmt.Errorf("invalid access tier %q, valid values: \"''%v\"", c.AccessTier, strings.Join(validAzAccessTier, ", ")) } return nil } // CryptFsConfig defines the configuration to store local files as encrypted type CryptFsConfig struct { sdk.OSFsConfig Passphrase *kms.Secret `json:"passphrase,omitempty"` } // HideConfidentialData hides confidential data func (c *CryptFsConfig) HideConfidentialData() { if c.Passphrase 
!= nil { c.Passphrase.Hide() } } func (c *CryptFsConfig) isEqual(other CryptFsConfig) bool { if c.Passphrase == nil { c.Passphrase = kms.NewEmptySecret() } if other.Passphrase == nil { other.Passphrase = kms.NewEmptySecret() } return c.Passphrase.IsEqual(other.Passphrase) } // ValidateAndEncryptCredentials validates the configuration and encrypts the passphrase if it is in plain text func (c *CryptFsConfig) ValidateAndEncryptCredentials(additionalData string) error { if err := c.validate(); err != nil { var errI18n *util.I18nError errValidation := util.NewValidationError(fmt.Sprintf("could not validate crypt fs config: %v", err)) if errors.As(err, &errI18n) { return util.NewI18nError(errValidation, errI18n.Message) } return util.NewI18nError(errValidation, util.I18nErrorFsValidation) } if c.Passphrase.IsPlain() { c.Passphrase.SetAdditionalData(additionalData) if err := c.Passphrase.Encrypt(); err != nil { return util.NewI18nError( util.NewValidationError(fmt.Sprintf("could not encrypt Crypt fs passphrase: %v", err)), util.I18nErrorFsValidation, ) } } return nil } func (c *CryptFsConfig) isSameResource(other CryptFsConfig) bool { return c.Passphrase.GetPayload() == other.Passphrase.GetPayload() } // validate returns an error if the configuration is not valid func (c *CryptFsConfig) validate() error { if c.Passphrase == nil || c.Passphrase.IsEmpty() { return util.NewI18nError(errors.New("invalid passphrase"), util.I18nErrorPassphraseRequired) } if !c.Passphrase.IsValidInput() { return util.NewI18nError(errors.New("passphrase cannot be empty or invalid"), util.I18nErrorPassphraseRequired) } if c.Passphrase.IsEncrypted() && !c.Passphrase.IsValid() { return errors.New("invalid encrypted passphrase") } return nil } // pipeWriter defines a wrapper for a pipeWriterAt. 
type pipeWriter struct { pipeWriterAt err error done chan bool } // NewPipeWriter initializes a new PipeWriter func NewPipeWriter(w pipeWriterAt) PipeWriter { return &pipeWriter{ pipeWriterAt: w, err: nil, done: make(chan bool), } } // Close waits for the upload to end, closes the pipeWriterAt and returns an error if any. func (p *pipeWriter) Close() error { p.pipeWriterAt.Close() //nolint:errcheck // the returned error is always null <-p.done return p.err } // Done unlocks other goroutines waiting on Close(). // It must be called when the upload ends func (p *pipeWriter) Done(err error) { p.err = err p.done <- true } func newPipeWriterAtOffset(w pipeWriterAt, offset int64) PipeWriter { return &pipeWriterAtOffset{ pipeWriter: &pipeWriter{ pipeWriterAt: w, err: nil, done: make(chan bool), }, offset: offset, writeOffset: offset, } } type pipeWriterAtOffset struct { *pipeWriter offset int64 writeOffset int64 } func (p *pipeWriterAtOffset) WriteAt(buf []byte, off int64) (int, error) { if off < p.offset { return 0, fmt.Errorf("invalid offset %d, minimum accepted %d", off, p.offset) } return p.pipeWriter.WriteAt(buf, off-p.offset) } func (p *pipeWriterAtOffset) Write(buf []byte) (int, error) { n, err := p.WriteAt(buf, p.writeOffset) p.writeOffset += int64(n) return n, err } // NewPipeReader initializes a new PipeReader func NewPipeReader(r pipeReaderAt) PipeReader { return &pipeReader{ pipeReaderAt: r, } } // pipeReader defines a wrapper for pipeat.PipeReaderAt. 
type pipeReader struct { pipeReaderAt mu sync.RWMutex metadata map[string]string } func (p *pipeReader) setMetadata(value map[string]string) { p.mu.Lock() defer p.mu.Unlock() p.metadata = value } func (p *pipeReader) setMetadataFromPointerVal(value map[string]*string) { p.mu.Lock() defer p.mu.Unlock() if len(value) == 0 { p.metadata = nil return } p.metadata = map[string]string{} for k, v := range value { val := util.GetStringFromPointer(v) if val != "" { p.metadata[k] = val } } } // Metadata implements the Metadater interface func (p *pipeReader) Metadata() map[string]string { p.mu.RLock() defer p.mu.RUnlock() if len(p.metadata) == 0 { return nil } result := make(map[string]string) for k, v := range p.metadata { result[k] = v } return result } func isEqualityCheckModeValid(mode int) bool { return mode >= 0 || mode <= 1 } // isDirectory checks if a path exists and is a directory func isDirectory(fs Fs, path string) (bool, error) { fileInfo, err := fs.Stat(path) if err != nil { return false, err } return fileInfo.IsDir(), err } // IsLocalOsFs returns true if fs is a local filesystem implementation func IsLocalOsFs(fs Fs) bool { return fs.Name() == osFsName } // IsCryptOsFs returns true if fs is an encrypted local filesystem implementation func IsCryptOsFs(fs Fs) bool { return fs.Name() == cryptFsName } // IsSFTPFs returns true if fs is an SFTP filesystem func IsSFTPFs(fs Fs) bool { return strings.HasPrefix(fs.Name(), sftpFsName) } // IsHTTPFs returns true if fs is an HTTP filesystem func IsHTTPFs(fs Fs) bool { return strings.HasPrefix(fs.Name(), httpFsName) } // IsBufferedLocalOrSFTPFs returns true if this is a buffered SFTP or local filesystem func IsBufferedLocalOrSFTPFs(fs Fs) bool { if osFs, ok := fs.(*OsFs); ok { return osFs.writeBufferSize > 0 } if !IsSFTPFs(fs) { return false } return !fs.IsUploadResumeSupported() } // FsOpenReturnsFile returns true if fs.Open returns a *os.File handle func FsOpenReturnsFile(fs Fs) bool { if osFs, ok := fs.(*OsFs); ok { 
return osFs.readBufferSize == 0 } if sftpFs, ok := fs.(*SFTPFs); ok { return sftpFs.config.BufferSize == 0 } return false } // IsLocalOrSFTPFs returns true if fs is local or SFTP func IsLocalOrSFTPFs(fs Fs) bool { return IsLocalOsFs(fs) || IsSFTPFs(fs) } // HasTruncateSupport returns true if the fs supports truncate files func HasTruncateSupport(fs Fs) bool { return IsLocalOsFs(fs) || IsSFTPFs(fs) || IsHTTPFs(fs) } // IsRenameAtomic returns true if renaming a directory is supposed to be atomic func IsRenameAtomic(fs Fs) bool { if strings.HasPrefix(fs.Name(), s3fsName) { return false } if strings.HasPrefix(fs.Name(), gcsfsName) { return false } if strings.HasPrefix(fs.Name(), azBlobFsName) { return false } return true } // HasImplicitAtomicUploads returns true if the fs don't persists partial files on error func HasImplicitAtomicUploads(fs Fs) bool { if strings.HasPrefix(fs.Name(), s3fsName) { return uploadMode&4 == 0 } if strings.HasPrefix(fs.Name(), gcsfsName) { return uploadMode&8 == 0 } if strings.HasPrefix(fs.Name(), azBlobFsName) { return uploadMode&16 == 0 } return false } // HasOpenRWSupport returns true if the fs can open a file // for reading and writing at the same time func HasOpenRWSupport(fs Fs) bool { if IsLocalOsFs(fs) { return true } if IsSFTPFs(fs) && fs.IsUploadResumeSupported() { return true } return false } // IsLocalOrCryptoFs returns true if fs is local or local encrypted func IsLocalOrCryptoFs(fs Fs) bool { return IsLocalOsFs(fs) || IsCryptOsFs(fs) } // SetPathPermissions calls fs.Chown. 
// It does nothing for local filesystem on windows func SetPathPermissions(fs Fs, path string, uid int, gid int) { if uid == -1 && gid == -1 { return } if IsLocalOsFs(fs) { if runtime.GOOS == "windows" { return } } if err := fs.Chown(path, uid, gid); err != nil { fsLog(fs, logger.LevelWarn, "error chowning path %v: %v", path, err) } } // IsUploadResumeSupported returns true if resuming uploads is supported func IsUploadResumeSupported(fs Fs, size int64) bool { if fs.IsUploadResumeSupported() { return true } return fs.IsConditionalUploadResumeSupported(size) } func getLastModified(metadata map[string]string) int64 { if val, ok := metadata[lastModifiedField]; ok && val != "" { lastModified, err := strconv.ParseInt(val, 10, 64) if err == nil { return lastModified } } return 0 } func getAzureLastModified(metadata map[string]*string) int64 { for k, v := range metadata { if strings.EqualFold(k, lastModifiedField) { if val := util.GetStringFromPointer(v); val != "" { lastModified, err := strconv.ParseInt(val, 10, 64) if err == nil { return lastModified } } return 0 } } return 0 } func validateOSFsConfig(config *sdk.OSFsConfig) error { if config.ReadBufferSize < 0 || config.ReadBufferSize > 10 { return fmt.Errorf("invalid read buffer size must be between 0 and 10 MB") } if config.WriteBufferSize < 0 || config.WriteBufferSize > 10 { return fmt.Errorf("invalid write buffer size must be between 0 and 10 MB") } return nil } func doCopy(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { if buf == nil { buf = make([]byte, 32768) } for { nr, er := src.Read(buf) if nr > 0 { nw, ew := dst.Write(buf[0:nr]) if nw < 0 || nr < nw { nw = 0 if ew == nil { ew = errors.New("invalid write") } } written += int64(nw) if ew != nil { err = ew break } if nr != nw { err = io.ErrShortWrite break } } if er != nil { if er != io.EOF { err = er } break } } return written, err } func getMountPath(mountPath string) string { if mountPath == "/" { return "" } return mountPath } func 
getLocalTempDir() string { if tempPath != "" { return tempPath } return filepath.Clean(os.TempDir()) } func doRecursiveRename(fs Fs, source, target string, renameFn func(string, string, os.FileInfo, int, bool) (int, int64, error), recursion int, updateModTime bool, ) (int, int64, error) { var numFiles int var filesSize int64 if recursion > util.MaxRecursion { return numFiles, filesSize, util.ErrRecursionTooDeep } recursion++ lister, err := fs.ReadDir(source) if err != nil { return numFiles, filesSize, err } defer lister.Close() for { entries, err := lister.Next(ListerBatchSize) finished := errors.Is(err, io.EOF) if err != nil && !finished { return numFiles, filesSize, err } for _, info := range entries { sourceEntry := fs.Join(source, info.Name()) targetEntry := fs.Join(target, info.Name()) files, size, err := renameFn(sourceEntry, targetEntry, info, recursion, updateModTime) if err != nil { if fs.IsNotExist(err) { fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err) continue } return numFiles, filesSize, err } numFiles += files filesSize += size } if finished { return numFiles, filesSize, nil } } } // copied from rclone func readFill(r io.Reader, buf []byte) (n int, err error) { var nn int for n < len(buf) && err == nil { nn, err = r.Read(buf[n:]) n += nn } return n, err } func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) error { written := 0 for written < count { n, err := w.WriteAt(buf[written:count], offset+int64(written)) written += n if err != nil { return err } } return nil } type bytesReaderWrapper struct { *bytes.Reader } func (b *bytesReaderWrapper) Close() error { return nil } type bufferAllocator struct { sync.Mutex available [][]byte bufferSize int finalized bool } func newBufferAllocator(size int) *bufferAllocator { return &bufferAllocator{ bufferSize: size, finalized: false, } } func (b *bufferAllocator) getBuffer() []byte { b.Lock() defer b.Unlock() if len(b.available) > 0 { var result []byte truncLength := 
len(b.available) - 1 result = b.available[truncLength] b.available[truncLength] = nil b.available = b.available[:truncLength] return result } return make([]byte, b.bufferSize) } func (b *bufferAllocator) releaseBuffer(buf []byte) { b.Lock() defer b.Unlock() if b.finalized || len(buf) != b.bufferSize { return } b.available = append(b.available, buf) } func (b *bufferAllocator) free() { b.Lock() defer b.Unlock() b.available = nil b.finalized = true } func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) { logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...) } ================================================ FILE: internal/webdavd/file.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package webdavd import ( "context" "encoding/xml" "errors" "io" "mime" "net/http" "os" "path" "slices" "sync/atomic" "time" "github.com/drakkan/webdav" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) var ( errTransferAborted = errors.New("transfer aborted") lastModifiedProps = []string{"Win32LastModifiedTime", "getlastmodified"} ) type webDavFile struct { *common.BaseTransfer writer io.WriteCloser reader io.ReadCloser info os.FileInfo startOffset int64 isFinished bool readTried atomic.Bool } func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader) *webDavFile { var writer io.WriteCloser var reader io.ReadCloser if baseTransfer.File != nil { writer = baseTransfer.File reader = baseTransfer.File } else if pipeWriter != nil { writer = pipeWriter } else if pipeReader != nil { reader = pipeReader } f := &webDavFile{ BaseTransfer: baseTransfer, writer: writer, reader: reader, isFinished: false, startOffset: 0, info: nil, } f.readTried.Store(false) return f } type webDavFileInfo struct { os.FileInfo Fs vfs.Fs virtualPath string fsPath string } // ContentType implements webdav.ContentTyper interface func (fi *webDavFileInfo) ContentType(_ context.Context) (string, error) { extension := path.Ext(fi.virtualPath) if ctype, ok := customMimeTypeMapping[extension]; ok { return ctype, nil } if extension == "" || extension == ".dat" { return "application/octet-stream", nil } contentType := mime.TypeByExtension(extension) if contentType != "" { return contentType, nil } contentType = mimeTypeCache.getMimeFromCache(extension) if contentType != "" { return contentType, nil } contentType, err := fi.Fs.GetMimeType(fi.fsPath) if contentType != "" { mimeTypeCache.addMimeToCache(extension, contentType) return contentType, err } return "", 
webdav.ErrNotImplemented } // Readdir reads directory entries from the handle func (f *webDavFile) Readdir(_ int) ([]os.FileInfo, error) { return nil, webdav.ErrNotImplemented } // ReadDir implements the FileDirLister interface func (f *webDavFile) ReadDir() (webdav.DirLister, error) { if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) { return nil, f.Connection.GetPermissionDeniedError() } lister, err := f.Connection.ListDir(f.GetVirtualPath()) if err != nil { return nil, err } return &webDavDirLister{ DirLister: lister, fs: f.Fs, virtualDirPath: f.GetVirtualPath(), fsDirPath: f.GetFsPath(), }, nil } // Stat the handle func (f *webDavFile) Stat() (os.FileInfo, error) { if f.GetType() == common.TransferDownload && !f.Connection.User.HasPerm(dataprovider.PermListItems, path.Dir(f.GetVirtualPath())) { return nil, f.Connection.GetPermissionDeniedError() } f.Lock() errUpload := f.ErrTransfer f.Unlock() if f.GetType() == common.TransferUpload && errUpload == nil { info := &webDavFileInfo{ FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, f.BytesReceived.Load(), time.Now(), false), Fs: f.Fs, virtualPath: f.GetVirtualPath(), fsPath: f.GetFsPath(), } return info, nil } info, err := f.Fs.Stat(f.GetFsPath()) if err != nil { return nil, f.Connection.GetFsError(f.Fs, err) } if vfs.IsCryptOsFs(f.Fs) { info = f.Fs.(*vfs.CryptFs).ConvertFileInfo(info) } fi := &webDavFileInfo{ FileInfo: info, Fs: f.Fs, virtualPath: f.GetVirtualPath(), fsPath: f.GetFsPath(), } return fi, nil } func (f *webDavFile) checkFirstRead() error { if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) { return f.Connection.GetPermissionDeniedError() } transferQuota := f.GetTransferQuota() if !transferQuota.HasDownloadSpace() { f.Connection.Log(logger.LevelInfo, "denying file read due to quota limits") return f.Connection.GetReadQuotaExceededError() } if ok, policy := f.Connection.User.IsFileAllowed(f.GetVirtualPath()); !ok { 
f.Connection.Log(logger.LevelWarn, "reading file %q is not allowed", f.GetVirtualPath()) return f.Connection.GetErrorForDeniedFile(policy) } _, err := common.ExecutePreAction(f.Connection, common.OperationPreDownload, f.GetFsPath(), f.GetVirtualPath(), 0, 0) if err != nil { f.Connection.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", f.GetVirtualPath(), err) return f.Connection.GetPermissionDeniedError() } f.readTried.Store(true) return nil } // Read reads the contents to downloads. func (f *webDavFile) Read(p []byte) (n int, err error) { if f.AbortTransfer.Load() { return 0, errTransferAborted } if !f.readTried.Load() { if err := f.checkFirstRead(); err != nil { return 0, err } } f.Connection.UpdateLastActivity() // the file is read sequentially we don't need to check for concurrent reads and so // lock the transfer while opening the remote file if f.reader == nil { if f.GetType() != common.TransferDownload { f.TransferError(common.ErrOpUnsupported) return 0, common.ErrOpUnsupported } file, r, cancelFn, e := f.Fs.Open(f.GetFsPath(), 0) f.Lock() if e == nil { if file != nil { f.File = file f.writer = f.File f.reader = f.File } else if r != nil { f.reader = r } f.SetCancelFn(cancelFn) } f.ErrTransfer = e f.startOffset = 0 f.Unlock() if e != nil { return 0, f.Connection.GetFsError(f.Fs, e) } } n, err = f.reader.Read(p) f.BytesSent.Add(int64(n)) if err == nil { err = f.CheckRead() } if err != nil && err != io.EOF { f.TransferError(err) err = f.ConvertError(err) return } f.HandleThrottle() return } // Write writes the uploaded contents. 
func (f *webDavFile) Write(p []byte) (n int, err error) { if f.AbortTransfer.Load() { return 0, errTransferAborted } f.Connection.UpdateLastActivity() n, err = f.writer.Write(p) f.BytesReceived.Add(int64(n)) if err == nil { err = f.CheckWrite() } if err != nil { f.TransferError(err) err = f.ConvertError(err) return } f.HandleThrottle() return } func (f *webDavFile) updateStatInfo() error { if f.info != nil { return nil } info, err := f.Fs.Stat(f.GetFsPath()) if err != nil { return err } if vfs.IsCryptOsFs(f.Fs) { info = f.Fs.(*vfs.CryptFs).ConvertFileInfo(info) } f.info = info return nil } func (f *webDavFile) updateTransferQuotaOnSeek() { transferQuota := f.GetTransferQuota() if transferQuota.HasSizeLimits() { go func(ulSize, dlSize int64, user dataprovider.User) { dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck }(f.BytesReceived.Load(), f.BytesSent.Load(), f.Connection.User) } } func (f *webDavFile) checkFile() error { if f.File == nil && vfs.FsOpenReturnsFile(f.Fs) { file, _, _, err := f.Fs.Open(f.GetFsPath(), 0) if err != nil { f.Connection.Log(logger.LevelWarn, "could not open file %q for seeking: %v", f.GetFsPath(), err) f.TransferError(err) return err } f.File = file f.reader = file f.writer = file } return nil } func (f *webDavFile) seekFile(offset int64, whence int) (int64, error) { ret, err := f.File.Seek(offset, whence) if err != nil { f.TransferError(err) } return ret, err } // Seek sets the offset for the next Read or Write on the writer to offset, // interpreted according to whence: 0 means relative to the origin of the file, // 1 means relative to the current offset, and 2 means relative to the end. // It returns the new offset and an error, if any. 
func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { f.Connection.UpdateLastActivity() if err := f.checkFile(); err != nil { return 0, err } if f.File != nil { return f.seekFile(offset, whence) } if f.GetType() == common.TransferDownload { readOffset := f.startOffset + f.BytesSent.Load() if offset == 0 && readOffset == 0 { switch whence { case io.SeekStart: return 0, nil case io.SeekEnd: if err := f.updateStatInfo(); err != nil { return 0, err } return f.info.Size(), nil } } // close the reader and create a new one at startByte if f.reader != nil { f.reader.Close() //nolint:errcheck f.reader = nil } startByte := int64(0) f.BytesReceived.Store(0) f.BytesSent.Store(0) f.updateTransferQuotaOnSeek() switch whence { case io.SeekStart: startByte = offset case io.SeekCurrent: startByte = readOffset + offset case io.SeekEnd: if err := f.updateStatInfo(); err != nil { f.TransferError(err) return 0, err } startByte = f.info.Size() - offset } _, r, cancelFn, err := f.Fs.Open(f.GetFsPath(), startByte) f.Lock() if err == nil { f.startOffset = startByte f.reader = r } f.ErrTransfer = err f.SetCancelFn(cancelFn) f.Unlock() return startByte, err } return 0, common.ErrOpUnsupported } // Close closes the open directory or the current transfer func (f *webDavFile) Close() error { if err := f.setFinished(); err != nil { return err } err := f.closeIO() if f.isTransfer() { errBaseClose := f.BaseTransfer.Close() if errBaseClose != nil { err = errBaseClose } } else { f.Connection.RemoveTransfer(f.BaseTransfer) } return f.Connection.GetFsError(f.Fs, err) } func (f *webDavFile) closeIO() error { var err error if f.File != nil { err = f.File.Close() } else if f.writer != nil { err = f.writer.Close() f.Lock() // we set ErrTransfer here so quota is not updated, in this case the uploads are atomic if err != nil && f.ErrTransfer == nil { f.ErrTransfer = err } f.Unlock() } else if f.reader != nil { err = f.reader.Close() if metadater, ok := f.reader.(vfs.Metadater); ok { 
f.SetMetadata(metadater.Metadata()) } } return err } func (f *webDavFile) setFinished() error { f.Lock() defer f.Unlock() if f.isFinished { return common.ErrTransferClosed } f.isFinished = true return nil } func (f *webDavFile) isTransfer() bool { if f.GetType() == common.TransferDownload { return f.readTried.Load() } return true } // DeadProps returns a copy of the dead properties held. // We always return nil for now, we only support the last modification time // and it is already included in "live" properties func (f *webDavFile) DeadProps() (map[xml.Name]webdav.Property, error) { return nil, nil } // Patch patches the dead properties held. // In our minimal implementation we just support Win32LastModifiedTime and // getlastmodified to set the the modification time. // We ignore any other property and just return an OK response if the patch sets // the modification time, otherwise a Forbidden response func (f *webDavFile) Patch(patches []webdav.Proppatch) ([]webdav.Propstat, error) { resp := make([]webdav.Propstat, 0, len(patches)) hasError := false for _, patch := range patches { status := http.StatusForbidden pstat := webdav.Propstat{} for _, p := range patch.Props { if status == http.StatusForbidden && !hasError { if !patch.Remove && slices.Contains(lastModifiedProps, p.XMLName.Local) { parsed, err := parseTime(util.BytesToString(p.InnerXML)) if err != nil { f.Connection.Log(logger.LevelWarn, "unsupported last modification time: %q, err: %v", p.InnerXML, err) hasError = true continue } attrs := &common.StatAttributes{ Flags: common.StatAttrTimes, Atime: parsed, Mtime: parsed, } if err := f.Connection.SetStat(f.GetVirtualPath(), attrs); err != nil { f.Connection.Log(logger.LevelWarn, "unable to set modification time for %q, err :%v", f.GetVirtualPath(), err) hasError = true continue } status = http.StatusOK } } pstat.Props = append(pstat.Props, webdav.Property{XMLName: p.XMLName}) } pstat.Status = status resp = append(resp, pstat) } return resp, nil } type 
webDavDirLister struct { vfs.DirLister fs vfs.Fs virtualDirPath string fsDirPath string } func (l *webDavDirLister) Next(limit int) ([]os.FileInfo, error) { files, err := l.DirLister.Next(limit) for idx := range files { info := files[idx] files[idx] = &webDavFileInfo{ FileInfo: info, Fs: l.fs, virtualPath: path.Join(l.virtualDirPath, info.Name()), fsPath: l.fs.Join(l.fsDirPath, info.Name()), } } return files, err } ================================================ FILE: internal/webdavd/handler.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package webdavd import ( "context" "net/http" "os" "path" "strconv" "strings" "time" "github.com/drakkan/webdav" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) // Connection details for a WebDav connection. 
type Connection struct { *common.BaseConnection request *http.Request rc *http.ResponseController } func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection { rc := http.NewResponseController(w) responseControllerDeadlines(rc, time.Time{}, time.Time{}) return &Connection{ BaseConnection: conn, request: r, rc: rc, } } func (c *Connection) getModificationTime() time.Time { if c.request == nil { return time.Time{} } if val := c.request.Header.Get("X-OC-Mtime"); val != "" { if unixTime, err := strconv.ParseInt(val, 10, 64); err == nil { return time.Unix(unixTime, 0) } } return time.Time{} } // GetClientVersion returns the connected client's version. func (c *Connection) GetClientVersion() string { if c.request != nil { return c.request.UserAgent() } return "" } // GetLocalAddress returns local connection address func (c *Connection) GetLocalAddress() string { return util.GetHTTPLocalAddress(c.request) } // GetRemoteAddress returns the connected client's address func (c *Connection) GetRemoteAddress() string { if c.request != nil { return c.request.RemoteAddr } return "" } // Disconnect closes the active transfer func (c *Connection) Disconnect() error { if c.rc != nil { responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second)) } return c.SignalTransfersAbort() } // GetCommand returns the request method func (c *Connection) GetCommand() string { if c.request != nil { return strings.ToUpper(c.request.Method) } return "" } // Mkdir creates a directory using the connection filesystem func (c *Connection) Mkdir(_ context.Context, name string, _ os.FileMode) error { c.UpdateLastActivity() name = util.CleanPath(name) return c.CreateDir(name, true) } // Rename renames a file or a directory func (c *Connection) Rename(_ context.Context, oldName, newName string) error { c.UpdateLastActivity() oldName = util.CleanPath(oldName) newName = util.CleanPath(newName) err := c.BaseConnection.Rename(oldName, 
newName) if err == nil { if mtime := c.getModificationTime(); !mtime.IsZero() { attrs := &common.StatAttributes{ Flags: common.StatAttrTimes, Atime: mtime, Mtime: mtime, } setStatErr := c.SetStat(newName, attrs) c.Log(logger.LevelDebug, "mtime header found for %q, value: %s, err: %v", newName, mtime, setStatErr) } } return err } // Stat returns a FileInfo describing the named file/directory, or an error, // if any happens func (c *Connection) Stat(_ context.Context, name string) (os.FileInfo, error) { c.UpdateLastActivity() name = util.CleanPath(name) if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) { return nil, c.GetPermissionDeniedError() } fi, err := c.DoStat(name, 0, true) if err != nil { return nil, err } return fi, err } // RemoveAll removes path and any children it contains. // If the path does not exist, RemoveAll returns nil (no error). func (c *Connection) RemoveAll(_ context.Context, name string) error { c.UpdateLastActivity() name = util.CleanPath(name) return c.BaseConnection.RemoveAll(name) } // OpenFile opens the named file with specified flag. 
// This method is used for uploads and downloads but also for Stat and Readdir func (c *Connection) OpenFile(_ context.Context, name string, flag int, _ os.FileMode) (webdav.File, error) { c.UpdateLastActivity() if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { c.Log(logger.LevelInfo, "denying transfer due to count limits") return nil, c.GetPermissionDeniedError() } name = util.CleanPath(name) fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { return nil, err } if flag == os.O_RDONLY || c.request.Method == "PROPPATCH" { // Download, Stat, Readdir or simply open/close return c.getFile(fs, p, name) } return c.putFile(fs, p, name) } func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) { var cancelFn func() // we open the file when we receive the first read so we only open the file if necessary baseTransfer := common.NewBaseTransfer(nil, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload, 0, 0, 0, 0, false, fs, c.GetTransferQuota()) return newWebDavFile(baseTransfer, nil, nil), nil } func (c *Connection) putFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) { if ok, _ := c.User.IsFileAllowed(virtualPath); !ok { c.Log(logger.LevelWarn, "writing file %q is not allowed", virtualPath) return nil, c.GetPermissionDeniedError() } filePath := fsPath if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { filePath = fs.GetAtomicUploadPath(fsPath) } stat, statErr := fs.Lstat(fsPath) if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualPath)) { return nil, c.GetPermissionDeniedError() } return c.handleUploadToNewFile(fs, fsPath, filePath, virtualPath) } if statErr != nil { c.Log(logger.LevelError, "error performing file stat %q: %+v", fsPath, statErr) return nil, c.GetFsError(fs, statErr) } // This happen if we upload a file that has the same name of an 
existing directory if stat.IsDir() { c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", fsPath) return nil, c.GetOpUnsupportedError() } if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualPath)) { return nil, c.GetPermissionDeniedError() } return c.handleUploadToExistingFile(fs, fsPath, filePath, stat.Size(), virtualPath) } func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (webdav.File, error) { diskQuota, transferQuota := c.HasSpace(true, false, requestPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, 0, 0); err != nil { c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) return nil, c.GetPermissionDeniedError() } file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.GetCreateChecks(requestPath, true, false)) if err != nil { c.Log(logger.LevelError, "error creating file %q: %+v", resolvedPath, err) return nil, c.GetFsError(fs, err) } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) mtime := c.getModificationTime() baseTransfer.SetTimes(resolvedPath, mtime, mtime) return newWebDavFile(baseTransfer, w, nil), nil } func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePath string, fileSize int64, requestPath string, ) (webdav.File, error) { var err error diskQuota, transferQuota := c.HasSpace(false, false, requestPath) if !diskQuota.HasSpace || 
!transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, os.O_TRUNC); err != nil { c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) return nil, c.GetPermissionDeniedError() } // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported()) if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { _, _, err = fs.Rename(resolvedPath, filePath, 0) if err != nil { c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v", resolvedPath, filePath, err) return nil, c.GetFsError(fs, err) } } file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.GetCreateChecks(requestPath, false, false)) if err != nil { c.Log(logger.LevelError, "error creating file %q: %+v", resolvedPath, err) return nil, c.GetFsError(fs, err) } initialSize := int64(0) truncatedSize := int64(0) // bytes truncated and not included in quota if vfs.HasTruncateSupport(fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false) } else { dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { initialSize = fileSize truncatedSize = fileSize } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) mtime := c.getModificationTime() 
baseTransfer.SetTimes(resolvedPath, mtime, mtime) return newWebDavFile(baseTransfer, w, nil), nil } ================================================ FILE: internal/webdavd/internal_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package webdavd import ( "context" "crypto/tls" "crypto/x509" "encoding/xml" "fmt" "io" "net/http" "net/http/httptest" "os" "path" "path/filepath" "runtime" "testing" "time" "github.com/drakkan/webdav" "github.com/eikenb/pipeat" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) const ( testFile = "test_dav_file" webDavCert = `-----BEGIN CERTIFICATE----- MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 
3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY /8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI dV4vKmHUzwK/eIx+8Ay3neE= -----END CERTIFICATE-----` webDavKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= -----END EC PRIVATE KEY-----` caCRT = `-----BEGIN CERTIFICATE----- MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U 4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY /tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX 
y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w kK8u1YiiVenmoQ== -----END CERTIFICATE-----` caKey = `-----BEGIN RSA PRIVATE KEY----- MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj 7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY 00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz +465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc 9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP 
L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM 0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN +jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 /hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz 1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN 38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ 2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== -----END RSA PRIVATE KEY-----` caCRL = `-----BEGIN X509 CRL----- MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg 
ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV +jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn w7bXJGfJcIMrrKs= -----END X509 CRL-----` client1Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 
QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au MU3Bo0A= -----END CERTIFICATE-----` client1Key = `-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi 85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe 
yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== -----END RSA PRIVATE KEY-----` // client 2 crt is revoked client2Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB 4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv /ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR E0+r2+4= -----END CERTIFICATE-----` client2Key = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu 
4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl /owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz 3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D 8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl -----END RSA PRIVATE KEY-----` osWindows = "windows" ) // MockOsFs mockable OsFs type MockOsFs struct { vfs.Fs err error isAtomicUploadSupported bool reader *pipeat.PipeReaderAt } // Name returns the name for the Fs implementation func (fs *MockOsFs) Name() string { return "mockOsFs" } // Open returns nil func (fs *MockOsFs) Open(name string, offset int64) (vfs.File, vfs.PipeReader, func(), error) { if fs.reader != nil { return nil, vfs.NewPipeReader(fs.reader), nil, nil } return fs.Fs.Open(name, offset) } // IsUploadResumeSupported returns true if resuming uploads is supported func (*MockOsFs) IsUploadResumeSupported() bool { return false } // IsAtomicUploadSupported returns true if atomic upload is 
supported func (fs *MockOsFs) IsAtomicUploadSupported() bool {
	// Reports the value passed as atomicUpload to newMockOsFs.
	return fs.isAtomicUploadSupported
}

// Remove removes the named file or (empty) directory.
// When the mock was built with a non-nil err, that error is returned and
// nothing is removed. The boolean argument is ignored.
func (fs *MockOsFs) Remove(name string, _ bool) error {
	if fs.err != nil {
		return fs.err
	}
	return os.Remove(name)
}

// Rename renames (moves) source to target.
// It always delegates to os.Rename: unlike Remove and GetMimeType, the
// configured err is not consulted, and the int argument is ignored.
// NOTE(review): the two -1 results presumably mean "count/size unknown" to
// vfs.Fs callers; confirm against the vfs.Fs Rename contract.
func (fs *MockOsFs) Rename(source, target string, _ int) (int, int64, error) {
	err := os.Rename(source, target)
	return -1, -1, err
}

// GetMimeType returns the content type.
// The path argument is ignored. The mock returns the configured err when it
// is set, otherwise the fixed value "application/custom-mime" (which
// TestContentType asserts).
func (fs *MockOsFs) GetMimeType(_ string) (string, error) {
	if fs.err != nil {
		return "", fs.err
	}
	return "application/custom-mime", nil
}

// newMockOsFs builds a MockOsFs backed by vfs.NewOsFs(connectionID, rootDir, ...).
// Its arguments configure the mock as follows:
//   - atomicUpload is what IsAtomicUploadSupported reports.
//   - reader, when non-nil, is returned by Open as a pipe reader instead of
//     opening the real file.
//   - err, when non-nil, is returned by Remove and GetMimeType.
func newMockOsFs(atomicUpload bool, connectionID, rootDir string, reader *pipeat.PipeReaderAt, err error) vfs.Fs {
	return &MockOsFs{
		Fs:                      vfs.NewOsFs(connectionID, rootDir, "", nil),
		isAtomicUploadSupported: atomicUpload,
		reader:                  reader,
		err:                     err,
	}
}

// TestUserInvalidParams checks that validateUser rejects a user whose home
// directory is the unusable value "invalid", and that the error message
// quotes that directory.
func TestUserInvalidParams(t *testing.T) {
	u := &dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "username",
			HomeDir:  "invalid",
		},
	}
	c := &Configuration{
		Bindings: []Binding{
			{
				Port: 9000,
			},
		},
	}
	server := webDavServer{
		config:  c,
		binding: c.Bindings[0],
	}
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", u.Username), nil)
	assert.NoError(t, err)
	_, err = server.validateUser(u, req, dataprovider.LoginMethodPassword)
	if assert.Error(t, err) {
		assert.EqualError(t, err, fmt.Sprintf("cannot login user with invalid home dir: %q", u.HomeDir))
	}
	// Presumably exercises the TLS branch of writeLog. There is nothing to
	// assert here; the call only has to complete.
	req.TLS = &tls.ConnectionState{}
	writeLog(req, http.StatusOK, nil)
}

// TestAllowedProxyUnixDomainSocket checks the behavior when the binding
// address is a filesystem path (a unix domain socket).
// Even though two ProxyAllowed addresses are configured, the test expects
// parseAllowedProxy to produce exactly one allowHeadersFrom entry, and that
// entry must return true for a nil argument.
func TestAllowedProxyUnixDomainSocket(t *testing.T) {
	b := Binding{
		Address:      filepath.Join(os.TempDir(), "sock"),
		ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"},
	}
	err := b.parseAllowedProxy()
	assert.NoError(t, err)
	if assert.Len(t, b.allowHeadersFrom, 1) {
		assert.True(t, b.allowHeadersFrom[0](nil))
	}
}

// TestProxyListenerWrapper checks that listenerWrapper returns nil when
// ProxyMode is 0 and a non-nil wrapper when ProxyMode is 1.
func TestProxyListenerWrapper(t *testing.T) {
	b := Binding{
		ProxyMode: 0,
	}
	require.Nil(t, b.listenerWrapper())
	b.ProxyMode = 1
	require.NotNil(t, b.listenerWrapper())
}

func
TestRemoteAddress(t *testing.T) { remoteAddr1 := "100.100.100.100" remoteAddr2 := "172.172.172.172" c := &Configuration{ Bindings: []Binding{ { Port: 9000, ProxyAllowed: []string{remoteAddr2, "10.8.0.0/30"}, }, }, } server := webDavServer{ config: c, binding: c.Bindings[0], } err := server.binding.parseAllowedProxy() assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "/", nil) assert.NoError(t, err) assert.Empty(t, req.RemoteAddr) trueClientIP := "True-Client-IP" cfConnectingIP := "CF-Connecting-IP" xff := "X-Forwarded-For" xRealIP := "X-Real-IP" req.Header.Set(trueClientIP, remoteAddr1) ip := util.GetRealIP(req, trueClientIP, 0) assert.Equal(t, remoteAddr1, ip) ip = util.GetRealIP(req, trueClientIP, 2) assert.Empty(t, ip) req.Header.Del(trueClientIP) req.Header.Set(cfConnectingIP, remoteAddr1) ip = util.GetRealIP(req, cfConnectingIP, 0) assert.Equal(t, remoteAddr1, ip) req.Header.Del(cfConnectingIP) req.Header.Set(xff, remoteAddr1) ip = util.GetRealIP(req, xff, 0) assert.Equal(t, remoteAddr1, ip) // this will be ignored, remoteAddr1 is not allowed to se this header req.Header.Set(xff, remoteAddr2) req.RemoteAddr = remoteAddr1 ip = server.checkRemoteAddress(req) assert.Equal(t, remoteAddr1, ip) req.RemoteAddr = "" ip = server.checkRemoteAddress(req) assert.Empty(t, ip) req.Header.Set(xff, fmt.Sprintf("%v , %v", remoteAddr2, remoteAddr1)) ip = util.GetRealIP(req, xff, 1) assert.Equal(t, remoteAddr2, ip) req.RemoteAddr = remoteAddr2 req.Header.Set(xff, fmt.Sprintf("%v,%v", "12.34.56.78", "172.16.2.4")) server.binding.ClientIPHeaderDepth = 1 server.binding.ClientIPProxyHeader = xff ip = server.checkRemoteAddress(req) assert.Equal(t, "12.34.56.78", ip) assert.Equal(t, ip, req.RemoteAddr) req.RemoteAddr = remoteAddr2 req.Header.Set(xff, fmt.Sprintf("%v,%v", "12.34.56.79", "172.16.2.5")) server.binding.ClientIPHeaderDepth = 0 ip = server.checkRemoteAddress(req) assert.Equal(t, "172.16.2.5", ip) assert.Equal(t, ip, req.RemoteAddr) req.RemoteAddr = 
"10.8.0.2" req.Header.Set(xff, remoteAddr1) ip = server.checkRemoteAddress(req) assert.Equal(t, remoteAddr1, ip) assert.Equal(t, ip, req.RemoteAddr) req.RemoteAddr = "10.8.0.3" req.Header.Set(xff, "not an ip") ip = server.checkRemoteAddress(req) assert.Equal(t, "10.8.0.3", ip) assert.Equal(t, ip, req.RemoteAddr) req.Header.Del(xff) req.RemoteAddr = "" req.Header.Set(xRealIP, remoteAddr1) ip = util.GetRealIP(req, "x-real-ip", 0) assert.Equal(t, remoteAddr1, ip) req.RemoteAddr = "" } func TestConnWithNilRequest(t *testing.T) { c := &Connection{} assert.Empty(t, c.GetClientVersion()) assert.Empty(t, c.GetCommand()) assert.Empty(t, c.GetRemoteAddress()) assert.True(t, c.getModificationTime().IsZero()) } func TestResolvePathErrors(t *testing.T) { ctx := context.Background() user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: "invalid", }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) connection := &Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } err := connection.Mkdir(ctx, "", os.ModePerm) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrGenericFailure.Error()) } err = connection.Rename(ctx, "oldName", "newName") if assert.Error(t, err) { assert.EqualError(t, err, common.ErrGenericFailure.Error()) } _, err = connection.Stat(ctx, "name") if assert.Error(t, err) { assert.EqualError(t, err, common.ErrGenericFailure.Error()) } err = connection.RemoveAll(ctx, "") if assert.Error(t, err) { assert.EqualError(t, err, common.ErrGenericFailure.Error()) } _, err = connection.OpenFile(ctx, "", 0, os.ModePerm) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrGenericFailure.Error()) } if runtime.GOOS != osWindows { user.HomeDir = filepath.Clean(os.TempDir()) connection.User = user fs := vfs.NewOsFs("connID", connection.User.HomeDir, "", nil) subDir := "sub" testTxtFile := 
"file.txt" err = os.MkdirAll(filepath.Join(os.TempDir(), subDir, subDir), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(filepath.Join(os.TempDir(), subDir, subDir, testTxtFile), []byte("content"), os.ModePerm) assert.NoError(t, err) err = os.Chmod(filepath.Join(os.TempDir(), subDir, subDir), 0001) assert.NoError(t, err) err = os.WriteFile(filepath.Join(os.TempDir(), testTxtFile), []byte("test content"), os.ModePerm) assert.NoError(t, err) err = connection.Rename(ctx, testTxtFile, path.Join(subDir, subDir, testTxtFile)) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrPermissionDenied.Error()) } _, err = connection.putFile(fs, filepath.Join(connection.User.HomeDir, subDir, subDir, testTxtFile), path.Join(subDir, subDir, testTxtFile)) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrPermissionDenied.Error()) } err = os.Chmod(filepath.Join(os.TempDir(), subDir, subDir), os.ModePerm) assert.NoError(t, err) err = os.RemoveAll(filepath.Join(os.TempDir(), subDir)) assert.NoError(t, err) err = os.Remove(filepath.Join(os.TempDir(), testTxtFile)) assert.NoError(t, err) } } func TestFileAccessErrors(t *testing.T) { ctx := context.Background() user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Clean(os.TempDir()), }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) connection := &Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } missingPath := "missing path" fsMissingPath := filepath.Join(user.HomeDir, missingPath) err := connection.RemoveAll(ctx, missingPath) assert.ErrorIs(t, err, os.ErrNotExist) davFile, err := connection.getFile(fs, fsMissingPath, missingPath) assert.NoError(t, err) buf := make([]byte, 64) _, err = davFile.Read(buf) assert.ErrorIs(t, err, os.ErrNotExist) err = davFile.Close() assert.ErrorIs(t, err, os.ErrNotExist) p := 
filepath.Join(user.HomeDir, "adir", missingPath) _, err = connection.handleUploadToNewFile(fs, p, p, path.Join("adir", missingPath)) assert.ErrorIs(t, err, os.ErrNotExist) _, err = connection.handleUploadToExistingFile(fs, p, "_"+p, 0, path.Join("adir", missingPath)) if assert.Error(t, err) { assert.ErrorIs(t, err, os.ErrNotExist) } fs = newMockOsFs(false, fs.ConnectionID(), user.HomeDir, nil, nil) _, err = connection.handleUploadToExistingFile(fs, p, p, 0, path.Join("adir", missingPath)) assert.ErrorIs(t, err, os.ErrNotExist) f, err := os.CreateTemp("", "temp") assert.NoError(t, err) err = f.Close() assert.NoError(t, err) davFile, err = connection.handleUploadToExistingFile(fs, f.Name(), f.Name(), 123, f.Name()) if assert.NoError(t, err) { transfer := davFile.(*webDavFile) transfers := connection.GetTransfers() if assert.Equal(t, 1, len(transfers)) { assert.Equal(t, transfers[0].ID, transfer.GetID()) assert.Equal(t, int64(123), transfer.InitialSize) err = transfer.Close() assert.NoError(t, err) assert.Equal(t, 0, len(connection.GetTransfers())) } // test PROPPATCH date parsing error pstats, err := transfer.Patch([]webdav.Proppatch{ { Props: []webdav.Property{ { XMLName: xml.Name{ Space: "DAV", Local: "getlastmodified", }, InnerXML: []byte(`Wid, 04 Nov 2020 13:25:51 GMT`), }, }, }, }) assert.NoError(t, err) for _, pstat := range pstats { assert.Equal(t, http.StatusForbidden, pstat.Status) } err = os.Remove(f.Name()) assert.NoError(t, err) // the file is deleted PROPPATCH should fail pstats, err = transfer.Patch([]webdav.Proppatch{ { Props: []webdav.Property{ { XMLName: xml.Name{ Space: "DAV", Local: "getlastmodified", }, InnerXML: []byte(`Wed, 04 Nov 2020 13:25:51 GMT`), }, }, }, }) assert.NoError(t, err) for _, pstat := range pstats { assert.Equal(t, http.StatusForbidden, pstat.Status) } } } func TestCheckRequestMethodWithPrefix(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Clean(os.TempDir()), Permissions: 
map[string][]string{ "/": {dataprovider.PermAny}, }, }, } fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) connection := &Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } server := webDavServer{ binding: Binding{ Prefix: "/dav", }, } req, err := http.NewRequest(http.MethodGet, "/../dav", nil) require.NoError(t, err) server.checkRequestMethod(context.Background(), req, connection) require.Equal(t, "PROPFIND", req.Method) require.Equal(t, "1", req.Header.Get("Depth")) } func TestContentType(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Clean(os.TempDir()), }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) connection := &Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } testFilePath := filepath.Join(user.HomeDir, testFile) ctx := context.Background() baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown", common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) fs = newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir(), nil, nil) err := os.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) davFile := newWebDavFile(baseTransfer, nil, nil) davFile.Fs = fs fi, err := davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) assert.NoError(t, err) assert.Equal(t, "application/custom-mime", ctype) } _, err = davFile.Readdir(-1) assert.ErrorIs(t, err, webdav.ErrNotImplemented) _, err = davFile.ReadDir() assert.Error(t, err) err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown1", common.TransferDownload, 0, 0, 0, 0, false, fs, 
dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "", nil) fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) assert.NoError(t, err) assert.Equal(t, "text/plain; charset=utf-8", ctype) } err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "", nil) fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) assert.NoError(t, err) assert.Equal(t, "application/octet-stream", ctype) } err = davFile.Close() assert.NoError(t, err) for i := 0; i < 2; i++ { // the second time the cache will be used baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".custom", common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "", nil) fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) assert.NoError(t, err) assert.Equal(t, "text/plain; charset=utf-8", ctype) } err = davFile.Close() assert.NoError(t, err) } baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".sftpgo", common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) fs = newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir(), nil, os.ErrInvalid) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = fs fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) assert.NoError(t, err) assert.Equal(t, "application/sftpgo", ctype) } 
err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown2", common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) fs = newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir(), nil, os.ErrInvalid) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = fs fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) assert.EqualError(t, err, webdav.ErrNotImplemented.Error(), "unexpected content type %q", ctype) } cache := mimeCache{ maxSize: 10, mimeTypes: map[string]string{}, } cache.addMimeToCache("", "") cache.RLock() assert.Len(t, cache.mimeTypes, 0) cache.RUnlock() err = os.Remove(testFilePath) assert.NoError(t, err) } func TestTransferReadWriteErrors(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Clean(os.TempDir()), }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) connection := &Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } testFilePath := filepath.Join(user.HomeDir, testFile) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile := newWebDavFile(baseTransfer, nil, nil) p := make([]byte, 1) _, err := davFile.Read(p) assert.EqualError(t, err, common.ErrOpUnsupported.Error()) r, w, err := pipeat.Pipe() assert.NoError(t, err) davFile = newWebDavFile(baseTransfer, nil, vfs.NewPipeReader(r)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) err = r.Close() assert.NoError(t, err) err = w.Close() assert.NoError(t, 
err) err = davFile.BaseTransfer.Close() assert.Error(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Read(p) assert.True(t, fs.IsNotExist(err)) _, err = davFile.Stat() assert.True(t, fs.IsNotExist(err)) err = davFile.Close() assert.Error(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) err = os.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) f, err := os.Open(testFilePath) if assert.NoError(t, err) { err = f.Close() assert.NoError(t, err) } davFile = newWebDavFile(baseTransfer, nil, nil) davFile.reader = f err = davFile.Close() assert.EqualError(t, err, common.ErrGenericFailure.Error()) err = davFile.Close() assert.EqualError(t, err, common.ErrTransferClosed.Error()) _, err = davFile.Read(p) assert.Error(t, err) info, err := davFile.Stat() if assert.NoError(t, err) { assert.Equal(t, int64(0), info.Size()) } err = davFile.Close() assert.Error(t, err) r, w, err = pipeat.Pipe() assert.NoError(t, err) mockFs := newMockOsFs(false, fs.ConnectionID(), user.HomeDir, r, nil) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, mockFs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) writeContent := []byte("content\r\n") go func() { n, err := w.Write(writeContent) assert.NoError(t, err) assert.Equal(t, len(writeContent), n) err = w.Close() assert.NoError(t, err) }() p = make([]byte, 64) n, err := davFile.Read(p) assert.EqualError(t, err, io.EOF.Error()) assert.Equal(t, len(writeContent), n) err = davFile.Close() assert.NoError(t, err) baseTransfer = 
common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.writer = f err = davFile.Close() assert.EqualError(t, err, common.ErrGenericFailure.Error()) err = os.Remove(testFilePath) assert.NoError(t, err) } func TestTransferSeek(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Clean(os.TempDir()), }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := newMockOsFs(true, "connID", user.HomeDir, nil, nil) connection := &Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } testFilePath := filepath.Join(user.HomeDir, testFile) testFileContents := []byte("content") baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile := newWebDavFile(baseTransfer, nil, nil) _, err := davFile.Seek(0, io.SeekStart) assert.EqualError(t, err, common.ErrOpUnsupported.Error()) err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekCurrent) assert.True(t, fs.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) err = os.WriteFile(testFilePath, testFileContents, os.ModePerm) assert.NoError(t, err) f, err := os.Open(testFilePath) if assert.NoError(t, err) { err = f.Close() assert.NoError(t, err) } baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, 
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekStart) assert.Error(t, err) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) res, err := davFile.Seek(0, io.SeekStart) assert.NoError(t, err) assert.Equal(t, int64(0), res) err = davFile.Close() assert.NoError(t, err) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile = newWebDavFile(baseTransfer, nil, nil) res, err = davFile.Seek(0, io.SeekEnd) assert.NoError(t, err) assert.Equal(t, int64(len(testFileContents)), res) err = davFile.updateStatInfo() assert.NoError(t, err) err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekEnd) assert.True(t, fs.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) fs = vfs.NewOsFs(fs.ConnectionID(), user.GetHomeDir(), "", nil) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekEnd) assert.True(t, fs.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, 
dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.reader = f r, _, err := pipeat.Pipe() assert.NoError(t, err) davFile.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), r, nil) res, err = davFile.Seek(2, io.SeekStart) assert.NoError(t, err) assert.Equal(t, int64(2), res) err = davFile.Close() assert.NoError(t, err) r, _, err = pipeat.Pipe() assert.NoError(t, err) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), r, nil) res, err = davFile.Seek(2, io.SeekEnd) assert.NoError(t, err) assert.Equal(t, int64(5), res) err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), nil, nil) res, err = davFile.Seek(2, io.SeekEnd) assert.True(t, fs.IsNotExist(err)) assert.Equal(t, int64(0), res) err = davFile.Close() assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(testFilePath) assert.NoError(t, err) } func TestBasicUsersCache(t *testing.T) { username := "webdav_internal_test" password := "pwd" u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, Password: password, HomeDir: filepath.Join(os.TempDir(), username), Status: 1, ExpirationDate: 0, }, } u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} err := dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) user, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) c := &Configuration{ Bindings: []Binding{ { Port: 9000, }, }, Cache: Cache{ Users: UsersCacheConfig{ MaxSize: 50, ExpirationTime: 1, 
}, }, } dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) server := webDavServer{ config: c, binding: c.Bindings[0], } req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil) assert.NoError(t, err) ipAddr := "127.0.0.1" _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled assert.Error(t, err) now := time.Now() req.SetBasicAuth(username, password) _, isCached, _, loginMethod, err := server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) // now the user should be cached cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.False(t, cachedUser.IsExpired()) assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) // authenticate must return the cached user now authUser, isCached, _, _, err := server.authenticate(req, ipAddr) assert.NoError(t, err) assert.True(t, isCached) assert.Equal(t, cachedUser.User, authUser) } // a wrong password must fail req.SetBasicAuth(username, "wrong") _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error()) req.SetBasicAuth(username, password) // force cached user expiration cachedUser.Expiration = now dataprovider.CacheWebDAVUser(cachedUser) cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.True(t, cachedUser.IsExpired()) } // now authenticate should get the user from the data provider and update the cache _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.False(t, cachedUser.IsExpired()) } // cache is not invalidated after a user modification if the fs does not change 
err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(username) assert.True(t, ok) folderName := "testFolder" f := &vfs.BaseVirtualFolder{ Name: folderName, MappedPath: filepath.Join(os.TempDir(), "mapped"), } err = dataprovider.AddFolder(f, "", "", "") assert.NoError(t, err) user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: "/vdir", }) err = dataprovider.UpdateUser(&user, "", "", "") assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(username) assert.False(t, ok) _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) _, ok = dataprovider.GetCachedWebDAVUser(username) assert.True(t, ok) // cache is invalidated after user deletion err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(username) assert.False(t, ok) err = dataprovider.DeleteFolder(folderName, "", "", "") assert.NoError(t, err) err = os.RemoveAll(u.GetHomeDir()) assert.NoError(t, err) } func TestCachedUserWithFolders(t *testing.T) { username := "webdav_internal_folder_test" password := "dav_pwd" folderName := "test_folder" u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, Password: password, HomeDir: filepath.Join(os.TempDir(), username), Status: 1, ExpirationDate: 0, }, } u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: "/vpath", }) f := &vfs.BaseVirtualFolder{ Name: folderName, MappedPath: filepath.Join(os.TempDir(), folderName), } err := dataprovider.AddFolder(f, "", "", "") assert.NoError(t, err) err = dataprovider.AddUser(&u, "", "", "") 
assert.NoError(t, err) user, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) c := &Configuration{ Bindings: []Binding{ { Port: 9000, }, }, Cache: Cache{ Users: UsersCacheConfig{ MaxSize: 50, ExpirationTime: 1, }, }, } dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) server := webDavServer{ config: c, binding: c.Bindings[0], } req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil) assert.NoError(t, err) ipAddr := "127.0.0.1" _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled assert.Error(t, err) now := time.Now() req.SetBasicAuth(username, password) _, isCached, _, loginMethod, err := server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) // now the user should be cached cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.False(t, cachedUser.IsExpired()) assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) // authenticate must return the cached user now authUser, isCached, _, _, err := server.authenticate(req, ipAddr) assert.NoError(t, err) assert.True(t, isCached) assert.Equal(t, cachedUser.User, authUser) } folder, err := dataprovider.GetFolderByName(folderName) assert.NoError(t, err) // updating a used folder should invalidate the cache only if the fs changed err = dataprovider.UpdateFolder(&folder, folder.Users, folder.Groups, "", "", "") assert.NoError(t, err) _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.True(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.False(t, cachedUser.IsExpired()) } // changing the folder path should invalidate the cache folder.MappedPath = filepath.Join(os.TempDir(), "anotherpath") err = 
dataprovider.UpdateFolder(&folder, folder.Users, folder.Groups, "", "", "") assert.NoError(t, err) _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.False(t, cachedUser.IsExpired()) } err = dataprovider.DeleteFolder(folderName, "", "", "") assert.NoError(t, err) // removing a used folder should invalidate the cache _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.False(t, cachedUser.IsExpired()) } err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(username) assert.False(t, ok) err = os.RemoveAll(u.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(folder.MappedPath) assert.NoError(t, err) } func TestUsersCacheSizeAndExpiration(t *testing.T) { username := "webdav_internal_test" password := "pwd" u := dataprovider.User{ BaseUser: sdk.BaseUser{ HomeDir: filepath.Join(os.TempDir(), username), Status: 1, ExpirationDate: 0, }, } u.Username = username + "1" u.Password = password + "1" u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} err := dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) user1, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) u.Username = username + "2" u.Password = password + "2" err = dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) user2, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) u.Username = username + "3" u.Password = password + "3" err = dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) user3, err := 
dataprovider.UserExists(u.Username, "") assert.NoError(t, err) u.Username = username + "4" u.Password = password + "4" err = dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) user4, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) c := &Configuration{ Bindings: []Binding{ { Port: 9000, }, }, Cache: Cache{ Users: UsersCacheConfig{ MaxSize: 3, ExpirationTime: 1, }, }, } dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) server := webDavServer{ config: c, binding: c.Bindings[0], } ipAddr := "127.0.1.1" req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user1.Username, password+"1") _, isCached, _, loginMehod, err := server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user2.Username, password+"2") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user3.Username, password+"3") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) // the first 3 users are now cached _, ok := dataprovider.GetCachedWebDAVUser(user1.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) assert.True(t, ok) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user4.Username, password+"4") _, isCached, _, loginMehod, err = 
server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) // user1, the first cached, should be removed now _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) assert.False(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) assert.True(t, ok) // a sleep ensures that expiration times are different time.Sleep(20 * time.Millisecond) // user1 logins, user2 should be removed req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user1.Username, password+"1") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) assert.False(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) assert.True(t, ok) // a sleep ensures that expiration times are different time.Sleep(20 * time.Millisecond) // user2 logins, user3 should be removed req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user2.Username, password+"2") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) assert.False(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) assert.True(t, ok) _, ok = 
dataprovider.GetCachedWebDAVUser(user4.Username) assert.True(t, ok) // a sleep ensures that expiration times are different time.Sleep(20 * time.Millisecond) // user3 logins, user4 should be removed req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user3.Username, password+"3") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) assert.False(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) assert.True(t, ok) // now remove user1 after an update user1.HomeDir += "_mod" err = dataprovider.UpdateUser(&user1, "", "", "") assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) assert.False(t, ok) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user4.Username, password+"4") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) // a sleep ensures that expiration times are different time.Sleep(20 * time.Millisecond) req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) assert.NoError(t, err) req.SetBasicAuth(user1.Username, password+"1") _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) assert.False(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) assert.True(t, ok) _, ok = 
dataprovider.GetCachedWebDAVUser(user3.Username) assert.True(t, ok) _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) assert.True(t, ok) err = dataprovider.DeleteUser(user1.Username, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteUser(user2.Username, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteUser(user3.Username, "", "", "") assert.NoError(t, err) err = dataprovider.DeleteUser(user4.Username, "", "", "") assert.NoError(t, err) err = os.RemoveAll(u.GetHomeDir()) assert.NoError(t, err) } func TestUserCacheIsolation(t *testing.T) { dataprovider.InitializeWebDAVUserCache(10) username := "webdav_internal_cache_test" password := "dav_pwd" u := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: username, Password: password, HomeDir: filepath.Join(os.TempDir(), username), Status: 1, ExpirationDate: 0, }, } u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} err := dataprovider.AddUser(&u, "", "", "") assert.NoError(t, err) user, err := dataprovider.UserExists(u.Username, "") assert.NoError(t, err) cachedUser := &dataprovider.CachedUser{ User: user, Expiration: time.Now().Add(24 * time.Hour), Password: password, LockSystem: webdav.NewMemLS(), } cachedUser.User.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("test secret") cachedUser.User.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("test key") err = cachedUser.User.FsConfig.S3Config.AccessSecret.Encrypt() assert.NoError(t, err) err = cachedUser.User.FsConfig.S3Config.SSECustomerKey.Encrypt() assert.NoError(t, err) dataprovider.CacheWebDAVUser(cachedUser) cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { _, err = cachedUser.User.GetFilesystem("") assert.NoError(t, err) // the filesystem is now cached } cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.True(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted()) err = 
cachedUser.User.FsConfig.S3Config.AccessSecret.Decrypt() assert.NoError(t, err) assert.True(t, cachedUser.User.FsConfig.S3Config.SSECustomerKey.IsEncrypted()) err = cachedUser.User.FsConfig.S3Config.SSECustomerKey.Decrypt() assert.NoError(t, err) cachedUser.User.FsConfig.Provider = sdk.S3FilesystemProvider _, err = cachedUser.User.GetFilesystem("") assert.Error(t, err, "we don't have to get the previously cached filesystem!") } cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { assert.Equal(t, sdk.LocalFilesystemProvider, cachedUser.User.FsConfig.Provider) assert.False(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted()) assert.False(t, cachedUser.User.FsConfig.S3Config.SSECustomerKey.IsEncrypted()) } err = dataprovider.DeleteUser(username, "", "", "") assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(username) assert.False(t, ok) } func TestRecoverer(t *testing.T) { c := &Configuration{ Bindings: []Binding{ { Port: 9000, }, }, } server := webDavServer{ config: c, binding: c.Bindings[0], } rr := httptest.NewRecorder() server.ServeHTTP(rr, nil) assert.Equal(t, http.StatusInternalServerError, rr.Code) } func TestMimeCache(t *testing.T) { cache := mimeCache{ maxSize: 0, mimeTypes: make(map[string]string), } cache.addMimeToCache(".zip", "application/zip") mtype := cache.getMimeFromCache(".zip") assert.Equal(t, "", mtype) cache.maxSize = 1 cache.addMimeToCache(".zip", "application/zip") mtype = cache.getMimeFromCache(".zip") assert.Equal(t, "application/zip", mtype) cache.addMimeToCache(".jpg", "image/jpeg") mtype = cache.getMimeFromCache(".jpg") assert.Equal(t, "", mtype) } func TestVerifyTLSConnection(t *testing.T) { oldCertMgr := certMgr caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") certPath := filepath.Join(os.TempDir(), "test.crt") keyPath := filepath.Join(os.TempDir(), "test.key") err := os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(certPath, 
[]byte(webDavCert), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(keyPath, []byte(webDavKey), os.ModePerm) assert.NoError(t, err) keyPairs := []common.TLSKeyPair{ { Cert: certPath, Key: keyPath, ID: common.DefaultTLSKeyPaidID, }, } certMgr, err = common.NewCertManager(keyPairs, "", "webdav_test") assert.NoError(t, err) certMgr.SetCARevocationLists([]string{caCrlPath}) err = certMgr.LoadCRLs() assert.NoError(t, err) crt, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) x509crt, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) server := webDavServer{} state := tls.ConnectionState{ PeerCertificates: []*x509.Certificate{x509crt}, } err = server.verifyTLSConnection(state) assert.Error(t, err) // no verified certification chain crt, err = tls.X509KeyPair([]byte(caCRT), []byte(caKey)) assert.NoError(t, err) x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crt, x509CAcrt}) err = server.verifyTLSConnection(state) assert.NoError(t, err) crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) assert.NoError(t, err) x509crtRevoked, err := x509.ParseCertificate(crt.Certificate[0]) assert.NoError(t, err) state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crtRevoked, x509CAcrt}) state.PeerCertificates = []*x509.Certificate{x509crtRevoked} err = server.verifyTLSConnection(state) assert.EqualError(t, err, common.ErrCrtRevoked.Error()) err = os.Remove(caCrlPath) assert.NoError(t, err) err = os.Remove(certPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) certMgr = oldCertMgr } func TestMisc(t *testing.T) { oldCertMgr := certMgr certMgr = nil err := ReloadCertificateMgr() assert.Nil(t, err) val := getConfigPath("", ".") assert.Empty(t, val) certMgr = oldCertMgr } func TestParseTime(t *testing.T) { res, err := parseTime("Sat, 4 Feb 2023 17:00:50 
GMT") require.NoError(t, err) require.Equal(t, int64(1675530050), res.Unix()) res, err = parseTime("Wed, 04 Nov 2020 13:25:51 GMT") require.NoError(t, err) require.Equal(t, int64(1604496351), res.Unix()) } func TestConfigsFromProvider(t *testing.T) { configDir := "." err := dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) c := Configuration{ Bindings: []Binding{ { Port: 1234, }, }, } err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) configs := dataprovider.Configs{ ACME: &dataprovider.ACMEConfigs{ Domain: "domain.com", Email: "info@domain.com", HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: 80}, Protocols: 7, }, } err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) util.CertsBasePath = "" // crt and key empty err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) util.CertsBasePath = filepath.Clean(os.TempDir()) // crt not found err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs := c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) crtPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".crt") err = os.WriteFile(crtPath, nil, 0666) assert.NoError(t, err) // key not found err = c.loadFromProvider() assert.NoError(t, err) assert.Empty(t, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) keyPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".key") err = os.WriteFile(keyPath, nil, 0666) assert.NoError(t, err) // acme cert used err = c.loadFromProvider() assert.NoError(t, err) assert.Equal(t, configs.ACME.Domain, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 1) assert.True(t, c.Bindings[0].EnableHTTPS) // protocols does not match configs.ACME.Protocols = 3 err = dataprovider.UpdateConfigs(&configs, "", "", "") assert.NoError(t, err) c.acmeDomain = "" err = c.loadFromProvider() assert.NoError(t, err) 
assert.Empty(t, c.acmeDomain) keyPairs = c.getKeyPairs(configDir) assert.Len(t, keyPairs, 0) err = os.Remove(crtPath) assert.NoError(t, err) err = os.Remove(keyPath) assert.NoError(t, err) util.CertsBasePath = "" err = dataprovider.UpdateConfigs(nil, "", "", "") assert.NoError(t, err) } func TestGetCacheExpirationTime(t *testing.T) { c := UsersCacheConfig{} assert.True(t, c.getExpirationTime().IsZero()) c.ExpirationTime = 1 assert.False(t, c.getExpirationTime().IsZero()) } func TestBindingGetAddress(t *testing.T) { tests := []struct { name string binding Binding want string }{ { name: "IP address with port", binding: Binding{Address: "127.0.0.1", Port: 8080}, want: "127.0.0.1:8080", }, { name: "Unix socket path (no port)", binding: Binding{Address: "/tmp/app.sock", Port: 0}, want: "/tmp/app.sock", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.binding.GetAddress(); got != tt.want { t.Errorf("GetAddress() = %v, want %v", got, tt.want) } }) } } func TestBindingIsValid(t *testing.T) { tests := []struct { name string binding Binding want bool }{ { name: "Valid: Positive port", binding: Binding{Address: "127.0.0.1", Port: 10080}, want: true, }, { name: "Valid: Absolute path on Unix (non-Windows)", binding: Binding{Address: "/var/run/app.sock", Port: 0}, // This test outcome is dynamic based on the OS want: runtime.GOOS != osWindows, }, { name: "Invalid: Port 0 and relative path", binding: Binding{Address: "relative/path", Port: 0}, want: false, }, { name: "Invalid: Empty address and port 0", binding: Binding{Address: "", Port: 0}, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.binding.IsValid(); got != tt.want { t.Errorf("IsValid() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: internal/webdavd/mimecache.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can 
redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package webdavd import "sync" type mimeCache struct { maxSize int sync.RWMutex mimeTypes map[string]string } var ( mimeTypeCache mimeCache customMimeTypeMapping map[string]string ) func (c *mimeCache) addMimeToCache(key, value string) { c.Lock() defer c.Unlock() if key == "" || value == "" { return } if len(c.mimeTypes) >= c.maxSize { return } c.mimeTypes[key] = value } func (c *mimeCache) getMimeFromCache(key string) string { c.RLock() defer c.RUnlock() if val, ok := c.mimeTypes[key]; ok { return val } return "" } ================================================ FILE: internal/webdavd/server.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package webdavd import ( "context" "crypto/tls" "crypto/x509" "errors" "fmt" "log" "net" "net/http" "path" "path/filepath" "runtime/debug" "slices" "strings" "time" "github.com/drakkan/webdav" "github.com/go-chi/chi/v5/middleware" "github.com/rs/cors" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" ) type webDavServer struct { config *Configuration binding Binding } func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error { handler := compressor.Handler(s) httpServer := &http.Server{ ReadHeaderTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second, MaxHeaderBytes: 1 << 16, // 64KB ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), } if s.config.Cors.Enabled { c := cors.New(cors.Options{ AllowedOrigins: util.RemoveDuplicates(s.config.Cors.AllowedOrigins, true), AllowedMethods: util.RemoveDuplicates(s.config.Cors.AllowedMethods, true), AllowedHeaders: util.RemoveDuplicates(s.config.Cors.AllowedHeaders, true), ExposedHeaders: util.RemoveDuplicates(s.config.Cors.ExposedHeaders, true), MaxAge: s.config.Cors.MaxAge, AllowCredentials: s.config.Cors.AllowCredentials, OptionsPassthrough: s.config.Cors.OptionsPassthrough, OptionsSuccessStatus: s.config.Cors.OptionsSuccessStatus, AllowPrivateNetwork: s.config.Cors.AllowPrivateNetwork, }) handler = c.Handler(handler) } httpServer.Handler = handler if certMgr != nil && s.binding.EnableHTTPS { serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding) certID := common.DefaultTLSKeyPaidID if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" { certID = 
s.binding.GetAddress() } httpServer.TLSConfig = &tls.Config{ GetCertificate: certMgr.GetCertificateFunc(certID), MinVersion: util.GetTLSVersion(s.binding.MinTLSVersion), NextProtos: util.GetALPNProtocols(s.binding.Protocols), CipherSuites: util.GetTLSCiphersFromNames(s.binding.TLSCipherSuites), } logger.Debug(logSender, "", "configured TLS cipher suites for binding %q: %v, certID: %v", s.binding.GetAddress(), httpServer.TLSConfig.CipherSuites, certID) if s.binding.isMutualTLSEnabled() { httpServer.TLSConfig.ClientCAs = certMgr.GetRootCAs() httpServer.TLSConfig.VerifyConnection = s.verifyTLSConnection switch s.binding.ClientAuthType { case 1: httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert case 2: httpServer.TLSConfig.ClientAuth = tls.VerifyClientCertIfGiven } } return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, true, s.binding.listenerWrapper(), logSender) } s.binding.EnableHTTPS = false serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding) return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, false, s.binding.listenerWrapper(), logSender) } func (s *webDavServer) verifyTLSConnection(state tls.ConnectionState) error { if certMgr != nil { var clientCrt *x509.Certificate var clientCrtName string if len(state.PeerCertificates) > 0 { clientCrt = state.PeerCertificates[0] clientCrtName = clientCrt.Subject.String() } if len(state.VerifiedChains) == 0 { if s.binding.ClientAuthType == 2 { return nil } logger.Warn(logSender, "", "TLS connection cannot be verified: unable to get verification chain") return errors.New("TLS connection cannot be verified: unable to get verification chain") } for _, verifiedChain := range state.VerifiedChains { var caCrt *x509.Certificate if len(verifiedChain) > 0 { caCrt = verifiedChain[len(verifiedChain)-1] } if certMgr.IsRevoked(clientCrt, caCrt) { logger.Debug(logSender, "", "tls handshake error, client certificate %q has been revoked", clientCrtName) return 
common.ErrCrtRevoked } } } return nil } // returns true if we have to handle a HEAD response, for a directory, ourself func (s *webDavServer) checkRequestMethod(ctx context.Context, r *http.Request, connection *Connection) bool { // see RFC4918, section 9.4 if r.Method == http.MethodGet || r.Method == http.MethodHead { p := path.Clean(r.URL.Path) if s.binding.Prefix != "" { p = strings.TrimPrefix(p, s.binding.Prefix) } info, err := connection.Stat(ctx, p) if err == nil && info.IsDir() { if r.Method == http.MethodHead { return true } r.Method = "PROPFIND" if r.Header.Get("Depth") == "" { r.Header.Add("Depth", "1") } } } return false } // ServeHTTP implements the http.Handler interface func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { if r := recover(); r != nil { logger.Error(logSender, "", "panic in ServeHTTP: %q stack trace: %v", r, string(debug.Stack())) http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) } }() responseControllerDeadlines( http.NewResponseController(w), time.Now().Add(60*time.Second), time.Now().Add(60*time.Second), ) w.Header().Set("Server", version.GetServerVersion("/", false)) ipAddr := s.checkRemoteAddress(r) common.Connections.AddClientConnection(ipAddr) defer common.Connections.RemoveClientConnection(ipAddr) if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolWebDAV); err != nil { logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection not allowed from ip %q: %v", ipAddr, err) http.Error(w, err.Error(), http.StatusServiceUnavailable) return } if common.IsBanned(ipAddr, common.ProtocolWebDAV) { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) return } delay, err := common.LimitRate(common.ProtocolWebDAV, ipAddr) if err != nil { delay += 499999999 * time.Nanosecond w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds())) w.Header().Set("X-Retry-In", delay.String()) http.Error(w, err.Error(), 
http.StatusTooManyRequests) return } if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolWebDAV); err != nil { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) return } user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr) if err != nil { if !s.binding.DisableWWWAuthHeader { w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s WebDAV\"", version.GetServerVersion("_", false))) } http.Error(w, fmt.Sprintf("Authentication error: %v", err), http.StatusUnauthorized) return } connectionID, err := s.validateUser(&user, r, loginMethod) if err != nil { // remove the cached user, we have not yet validated its filesystem dataprovider.RemoveCachedWebDAVUser(user.Username) updateLoginMetrics(&user, ipAddr, loginMethod, err, r) http.Error(w, err.Error(), http.StatusForbidden) return } if !isCached { err = user.CheckFsRoot(connectionID) } else { _, err = user.GetFilesystemForPath("/", connectionID) } if err != nil { errClose := user.CloseFs() logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) updateLoginMetrics(&user, ipAddr, loginMethod, common.ErrInternalFailure, r) http.Error(w, err.Error(), http.StatusInternalServerError) return } baseConn := common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { errClose := user.CloseFs() logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose) updateLoginMetrics(&user, ipAddr, loginMethod, err, r) http.Error(w, err.Error(), http.StatusTooManyRequests) return } defer common.Connections.Remove(connection.GetID()) updateLoginMetrics(&user, ipAddr, loginMethod, err, r) ctx := context.WithValue(r.Context(), requestIDKey, connectionID) ctx = context.WithValue(ctx, requestStartKey, time.Now()) dataprovider.UpdateLastLogin(&user) 
if s.checkRequestMethod(ctx, r, connection) { w.Header().Set("Content-Type", "text/xml; charset=utf-8") w.WriteHeader(http.StatusMultiStatus) w.Write([]byte("")) //nolint:errcheck writeLog(r, http.StatusMultiStatus, nil) return } handler := webdav.Handler{ Prefix: s.binding.Prefix, FileSystem: connection, LockSystem: lockSystem, Logger: writeLog, } handler.ServeHTTP(w, r.WithContext(ctx)) } func (s *webDavServer) getCredentialsAndLoginMethod(r *http.Request) (string, string, string, *x509.Certificate, bool) { var tlsCert *x509.Certificate loginMethod := dataprovider.LoginMethodPassword username, password, ok := r.BasicAuth() if s.binding.isMutualTLSEnabled() && r.TLS != nil { if len(r.TLS.PeerCertificates) > 0 { tlsCert = r.TLS.PeerCertificates[0] if ok { loginMethod = dataprovider.LoginMethodTLSCertificateAndPwd } else { loginMethod = dataprovider.LoginMethodTLSCertificate username = tlsCert.Subject.CommonName password = "" } ok = true } } return username, password, loginMethod, tlsCert, ok } func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.User, bool, webdav.LockSystem, string, error) { var user dataprovider.User var err error username, password, loginMethod, tlsCert, ok := s.getCredentialsAndLoginMethod(r) if !ok { user.Username = username return user, false, nil, loginMethod, common.ErrNoCredentials } cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if ok { if cachedUser.IsExpired() { dataprovider.RemoveCachedWebDAVUser(username) } else { if !cachedUser.User.IsTLSVerificationEnabled() { // for backward compatibility with 2.0.x we only check the password tlsCert = nil loginMethod = dataprovider.LoginMethodPassword } cu, u, err := dataprovider.CheckCachedUserCredentials(cachedUser, password, ip, loginMethod, common.ProtocolWebDAV, tlsCert) if err == nil { if cu != nil { return cu.User, true, cu.LockSystem, loginMethod, nil } lockSystem := webdav.NewMemLS() cachedUser = &dataprovider.CachedUser{ User: *u, Password: 
password, LockSystem: lockSystem, Expiration: s.config.Cache.Users.getExpirationTime(), } dataprovider.CacheWebDAVUser(cachedUser) return cachedUser.User, false, cachedUser.LockSystem, loginMethod, nil } updateLoginMetrics(&cachedUser.User, ip, loginMethod, dataprovider.ErrInvalidCredentials, r) return user, false, nil, loginMethod, dataprovider.ErrInvalidCredentials } } user, loginMethod, err = dataprovider.CheckCompositeCredentials(username, password, ip, loginMethod, common.ProtocolWebDAV, tlsCert) if err != nil { user.Username = username updateLoginMetrics(&user, ip, loginMethod, err, r) return user, false, nil, loginMethod, dataprovider.ErrInvalidCredentials } lockSystem := webdav.NewMemLS() cachedUser = &dataprovider.CachedUser{ User: user, Password: password, LockSystem: lockSystem, Expiration: s.config.Cache.Users.getExpirationTime(), } dataprovider.CacheWebDAVUser(cachedUser) return user, false, lockSystem, loginMethod, nil } func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, loginMethod string) (string, error) { connID := xid.New().String() connectionID := fmt.Sprintf("%v_%v", common.ProtocolWebDAV, connID) if !filepath.IsAbs(user.HomeDir) { logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. 
Home dir must be an absolute path, login not allowed", user.Username, user.HomeDir) return connID, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir) } if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolWebDAV) { logger.Info(logSender, connectionID, "cannot login user %q, protocol DAV is not allowed", user.Username) return connID, fmt.Errorf("protocol DAV is not allowed for user %q", user.Username) } if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolWebDAV) { logger.Info(logSender, connectionID, "cannot login user %q, %v login method is not allowed", user.Username, loginMethod) return connID, fmt.Errorf("login method %v is not allowed for user %q", loginMethod, user.Username) } if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", user.Username, r.RemoteAddr) return connID, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, r.RemoteAddr) } return connID, nil } func (s *webDavServer) checkRemoteAddress(r *http.Request) string { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) var ip net.IP isUnixSocket := filepath.IsAbs(s.binding.Address) if !isUnixSocket { ip = net.ParseIP(ipAddr) } if isUnixSocket || ip != nil { for _, allow := range s.binding.allowHeadersFrom { if allow(ip) { parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth) if parsedIP != "" { ipAddr = parsedIP r.RemoteAddr = ipAddr } break } } } return ipAddr } func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) { if err := rc.SetReadDeadline(read); err != nil { logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err) } if err := rc.SetWriteDeadline(write); err != nil { logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err) } } func writeLog(r *http.Request, status int, err error) { scheme := "http" cipherSuite := "" if r.TLS 
!= nil { scheme = "https" cipherSuite = tls.CipherSuiteName(r.TLS.CipherSuite) } fields := map[string]any{ "remote_addr": r.RemoteAddr, "proto": r.Proto, "method": r.Method, "user_agent": r.UserAgent(), "uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI), "cipher_suite": cipherSuite, } if reqID, ok := r.Context().Value(requestIDKey).(string); ok { fields["request_id"] = reqID } if reqStart, ok := r.Context().Value(requestStartKey).(time.Time); ok { fields["elapsed_ms"] = time.Since(reqStart).Nanoseconds() / 1000000 } if depth := r.Header.Get("Depth"); depth != "" { fields["depth"] = depth } if contentLength := r.Header.Get("Content-Length"); contentLength != "" { fields["content_length"] = contentLength } if timeout := r.Header.Get("Timeout"); timeout != "" { fields["timeout"] = timeout } if status != 0 { fields["resp_status"] = status } var ev *zerolog.Event if status >= http.StatusInternalServerError { ev = logger.GetLogger().Error() } else if status >= http.StatusBadRequest { ev = logger.GetLogger().Warn() } else { ev = logger.GetLogger().Debug() } ev. Timestamp(). Str("sender", logSender). Fields(fields). Err(err). 
Send() } func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error, r *http.Request) { metric.AddLoginAttempt(loginMethod) if err == nil { logger.LoginLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, "", r.UserAgent(), r.TLS != nil, "") plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolWebDAV, user.Username, ip, "", nil) common.DelayLogin(nil) } else if err != common.ErrInternalFailure && err != common.ErrNoCredentials { logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, err.Error()) event := common.HostEventLoginFailed logEv := notifier.LogEventTypeLoginFailed if errors.Is(err, util.ErrNotFound) { event = common.HostEventUserNotFound logEv = notifier.LogEventTypeLoginNoUser } common.AddDefenderEvent(ip, common.ProtocolWebDAV, event) plugin.Handler.NotifyLogEvent(logEv, common.ProtocolWebDAV, user.Username, ip, "", err) if loginMethod != dataprovider.LoginMethodTLSCertificate { common.DelayLogin(err) } } metric.AddLoginResult(loginMethod, err) dataprovider.ExecutePostLoginHook(user, loginMethod, ip, common.ProtocolWebDAV, err) } ================================================ FILE: internal/webdavd/webdavd.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
// Package webdavd implements the WebDAV protocol package webdavd import ( "fmt" "net" "net/http" "os" "path/filepath" "runtime" "time" "github.com/go-chi/chi/v5/middleware" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) type ctxReqParams int const ( requestIDKey ctxReqParams = iota requestStartKey ) const ( logSender = "webdavd" ) var ( certMgr *common.CertManager serviceStatus ServiceStatus timeFormats = []string{ http.TimeFormat, "Mon, _2 Jan 2006 15:04:05 GMT", time.RFC850, time.ANSIC, } ) // ServiceStatus defines the service status type ServiceStatus struct { IsActive bool `json:"is_active"` Bindings []Binding `json:"bindings"` } // CorsConfig defines the CORS configuration type CorsConfig struct { AllowedOrigins []string `json:"allowed_origins" mapstructure:"allowed_origins"` AllowedMethods []string `json:"allowed_methods" mapstructure:"allowed_methods"` AllowedHeaders []string `json:"allowed_headers" mapstructure:"allowed_headers"` ExposedHeaders []string `json:"exposed_headers" mapstructure:"exposed_headers"` AllowCredentials bool `json:"allow_credentials" mapstructure:"allow_credentials"` Enabled bool `json:"enabled" mapstructure:"enabled"` MaxAge int `json:"max_age" mapstructure:"max_age"` OptionsPassthrough bool `json:"options_passthrough" mapstructure:"options_passthrough"` OptionsSuccessStatus int `json:"options_success_status" mapstructure:"options_success_status"` AllowPrivateNetwork bool `json:"allow_private_network" mapstructure:"allow_private_network"` } // CustomMimeMapping defines additional, user defined mime mappings type CustomMimeMapping struct { Ext string `json:"ext" mapstructure:"ext"` Mime string `json:"mime" mapstructure:"mime"` } // UsersCacheConfig defines the cache configuration for users type UsersCacheConfig struct { ExpirationTime int `json:"expiration_time" 
mapstructure:"expiration_time"` MaxSize int `json:"max_size" mapstructure:"max_size"` } func (c *UsersCacheConfig) getExpirationTime() time.Time { if c.ExpirationTime > 0 { return time.Now().Add(time.Duration(c.ExpirationTime) * time.Minute) } return time.Time{} } // MimeCacheConfig defines the cache configuration for mime types type MimeCacheConfig struct { Enabled bool `json:"enabled" mapstructure:"enabled"` MaxSize int `json:"max_size" mapstructure:"max_size"` CustomMappings []CustomMimeMapping `json:"custom_mappings" mapstructure:"custom_mappings"` } // Cache configuration type Cache struct { Users UsersCacheConfig `json:"users" mapstructure:"users"` MimeTypes MimeCacheConfig `json:"mime_types" mapstructure:"mime_types"` } // Binding defines the configuration for a network listener type Binding struct { // The address to listen on. A blank value means listen on all available network interfaces. Address string `json:"address" mapstructure:"address"` // The port used for serving requests Port int `json:"port" mapstructure:"port"` // you also need to provide a certificate for enabling HTTPS EnableHTTPS bool `json:"enable_https" mapstructure:"enable_https"` // Certificate and matching private key for this specific binding, if empty the global // ones will be used, if any CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` // Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2 MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"` // set to 1 to require client certificate authentication in addition to basic auth. // You need to define at least a certificate authority for this to work ClientAuthType int `json:"client_auth_type" mapstructure:"client_auth_type"` // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. 
// If CipherSuites is nil/empty, a default list of secure cipher suites // is used, with a preference order based on hardware performance. // Note that TLS 1.3 ciphersuites are not configurable. // The supported ciphersuites names are defined here: // // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 // // any invalid name will be silently ignored. // The order matters, the ciphers listed first will be the preferred ones. TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` // HTTP protocols to enable in preference order. Supported values: http/1.1, h2 Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"` // Prefix for WebDAV resources, if empty WebDAV resources will be available at the // root ("/") URI. If defined it must be an absolute URI. Prefix string `json:"prefix" mapstructure:"prefix"` // Defines whether to use the common proxy protocol configuration or the // binding-specific proxy header configuration. ProxyMode int `json:"proxy_mode" mapstructure:"proxy_mode"` // List of IP addresses and IP ranges allowed to set client IP proxy headers ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` // Allowed client IP proxy header such as "X-Forwarded-For", "X-Real-IP" ClientIPProxyHeader string `json:"client_ip_proxy_header" mapstructure:"client_ip_proxy_header"` // Some client IP headers such as "X-Forwarded-For" can contain multiple IP address, this setting // define the position to trust starting from the right. 
For example if we have: // "10.0.0.1,11.0.0.1,12.0.0.1,13.0.0.1" and the depth is 0, SFTPGo will use "13.0.0.1" // as client IP, if depth is 1, "12.0.0.1" will be used and so on ClientIPHeaderDepth int `json:"client_ip_header_depth" mapstructure:"client_ip_header_depth"` // Do not add the WWW-Authenticate header after an authentication error, // only the 401 status code will be sent DisableWWWAuthHeader bool `json:"disable_www_auth_header" mapstructure:"disable_www_auth_header"` allowHeadersFrom []func(net.IP) bool } func (b *Binding) parseAllowedProxy() error { if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 { // unix domain socket b.allowHeadersFrom = []func(net.IP) bool{func(_ net.IP) bool { return true }} return nil } allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed) if err != nil { return err } b.allowHeadersFrom = allowedFuncs return nil } func (b *Binding) isMutualTLSEnabled() bool { return b.ClientAuthType == 1 || b.ClientAuthType == 2 } // GetAddress returns the binding address func (b *Binding) GetAddress() string { if b.Port > 0 { return fmt.Sprintf("%s:%d", b.Address, b.Port) } return b.Address } // IsValid returns true if the binding is valid func (b *Binding) IsValid() bool { if b.Port > 0 { return true } if filepath.IsAbs(b.Address) && runtime.GOOS != "windows" { return true } return false } func (b *Binding) listenerWrapper() func(net.Listener) (net.Listener, error) { if b.ProxyMode == 1 { return common.Config.GetProxyListener } return nil } // Configuration defines the configuration for the WevDAV server type Configuration struct { // Addresses and ports to bind to Bindings []Binding `json:"bindings" mapstructure:"bindings"` // If files containing a certificate and matching private key for the server are provided you // can enable HTTPS connections for the configured bindings // Certificate and key files can be reloaded on demand sending a "SIGHUP" signal on Unix based systems and a // "paramchange" request to the running 
service on Windows. CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` // CACertificates defines the set of root certificate authorities to be used to verify client certificates. CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` // CARevocationLists defines a set a revocation lists, one for each root CA, to be used to check // if a client certificate has been revoked CARevocationLists []string `json:"ca_revocation_lists" mapstructure:"ca_revocation_lists"` // CORS configuration Cors CorsConfig `json:"cors" mapstructure:"cors"` // Cache configuration Cache Cache `json:"cache" mapstructure:"cache"` acmeDomain string } // GetStatus returns the server status func GetStatus() ServiceStatus { return serviceStatus } // ShouldBind returns true if there is at least a valid binding func (c *Configuration) ShouldBind() bool { for _, binding := range c.Bindings { if binding.IsValid() { return true } } return false } func (c *Configuration) getKeyPairs(configDir string) []common.TLSKeyPair { var keyPairs []common.TLSKeyPair for _, binding := range c.Bindings { certificateFile := getConfigPath(binding.CertificateFile, configDir) certificateKeyFile := getConfigPath(binding.CertificateKeyFile, configDir) if certificateFile != "" && certificateKeyFile != "" { keyPairs = append(keyPairs, common.TLSKeyPair{ Cert: certificateFile, Key: certificateKeyFile, ID: binding.GetAddress(), }) } } var certificateFile, certificateKeyFile string if c.acmeDomain != "" { certificateFile, certificateKeyFile = util.GetACMECertificateKeyPair(c.acmeDomain) } else { certificateFile = getConfigPath(c.CertificateFile, configDir) certificateKeyFile = getConfigPath(c.CertificateKeyFile, configDir) } if certificateFile != "" && certificateKeyFile != "" { keyPairs = append(keyPairs, common.TLSKeyPair{ Cert: certificateFile, Key: certificateKeyFile, ID: 
common.DefaultTLSKeyPaidID, }) } return keyPairs } func (c *Configuration) loadFromProvider() error { configs, err := dataprovider.GetConfigs() if err != nil { return fmt.Errorf("unable to load config from provider: %w", err) } configs.SetNilsToEmpty() if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolWebDAV) { return nil } crt, key := util.GetACMECertificateKeyPair(configs.ACME.Domain) if crt != "" && key != "" { if _, err := os.Stat(crt); err != nil { logger.Error(logSender, "", "unable to load acme cert file %q: %v", crt, err) return nil } if _, err := os.Stat(key); err != nil { logger.Error(logSender, "", "unable to load acme key file %q: %v", key, err) return nil } for idx := range c.Bindings { c.Bindings[idx].EnableHTTPS = true } c.acmeDomain = configs.ACME.Domain logger.Info(logSender, "", "acme domain set to %q", c.acmeDomain) return nil } return nil } // Initialize configures and starts the WebDAV server func (c *Configuration) Initialize(configDir string) error { if err := c.loadFromProvider(); err != nil { return err } logger.Info(logSender, "", "initializing WebDAV server with config %+v", *c) mimeTypeCache = mimeCache{ maxSize: c.Cache.MimeTypes.MaxSize, mimeTypes: make(map[string]string), } if !c.Cache.MimeTypes.Enabled { mimeTypeCache.maxSize = 0 } else { customMimeTypeMapping = make(map[string]string) for _, m := range c.Cache.MimeTypes.CustomMappings { if m.Mime != "" { logger.Debug(logSender, "", "adding custom mime mapping for extension %q, mime type %q", m.Ext, m.Mime) customMimeTypeMapping[m.Ext] = m.Mime } } } if !c.ShouldBind() { return common.ErrNoBinding } keyPairs := c.getKeyPairs(configDir) if len(keyPairs) > 0 { mgr, err := common.NewCertManager(keyPairs, configDir, logSender) if err != nil { return err } mgr.SetCACertificates(c.CACertificates) if err := mgr.LoadRootCAs(); err != nil { return err } mgr.SetCARevocationLists(c.CARevocationLists) if err := mgr.LoadCRLs(); err != nil { return err } certMgr = mgr } 
compressor := middleware.NewCompressor(5, "text/*") dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) serviceStatus = ServiceStatus{ Bindings: nil, } exitChannel := make(chan error, 1) for _, binding := range c.Bindings { if !binding.IsValid() { continue } if err := binding.parseAllowedProxy(); err != nil { return err } go func(binding Binding) { server := webDavServer{ config: c, binding: binding, } exitChannel <- server.listenAndServe(compressor) }(binding) } serviceStatus.IsActive = true return <-exitChannel } // ReloadCertificateMgr reloads the certificate manager func ReloadCertificateMgr() error { if certMgr != nil { return certMgr.Reload() } return nil } func getConfigPath(name, configDir string) string { if !util.IsFileInputValid(name) { return "" } if name != "" && !filepath.IsAbs(name) { return filepath.Join(configDir, name) } return name } func parseTime(text string) (t time.Time, err error) { for _, layout := range timeFormats { t, err = time.Parse(layout, text) if err == nil { return } } return } ================================================ FILE: internal/webdavd/webdavd_test.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
package webdavd_test import ( "bufio" "bytes" "crypto/rand" "crypto/tls" "encoding/json" "errors" "fmt" "io" "io/fs" "net" "net/http" "os" "os/exec" "path" "path/filepath" "regexp" "runtime" "strings" "sync" "testing" "time" "github.com/minio/sio" "github.com/pkg/sftp" "github.com/rs/zerolog" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" "github.com/stretchr/testify/assert" "github.com/studio-b12/gowebdav" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/internal/httpdtest" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/internal/webdavd" ) const ( logSender = "webavdTesting" webDavServerAddr = "localhost:9090" webDavTLSServerAddr = "localhost:9443" webDavServerPort = 9090 webDavTLSServerPort = 9443 sftpServerAddr = "127.0.0.1:9022" defaultUsername = "test_user_dav" defaultPassword = "test_password" osWindows = "windows" webDavCert = `-----BEGIN CERTIFICATE----- MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY 
/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI dV4vKmHUzwK/eIx+8Ay3neE= -----END CERTIFICATE-----` webDavKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= -----END EC PRIVATE KEY-----` caCRT = `-----BEGIN CERTIFICATE----- MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U 4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY /tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV 
L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w kK8u1YiiVenmoQ== -----END CERTIFICATE-----` caCRL = `-----BEGIN X509 CRL----- MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV +jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn w7bXJGfJcIMrrKs= -----END X509 CRL-----` client1Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo 
Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au MU3Bo0A= -----END CERTIFICATE-----` client1Key = `-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi 85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn 
RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== -----END RSA PRIVATE KEY-----` // client 2 crt is revoked client2Crt = `-----BEGIN CERTIFICATE----- MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB 4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv /ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S 
ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR E0+r2+4= -----END CERTIFICATE-----` client2Key = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu 4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl /owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz 3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D 8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl -----END RSA PRIVATE KEY-----` testFileName = "test_file_dav.dat" testDLFileName = "test_download_dav.dat" tlsClient1Username = "client1" tlsClient2Username = "client2" emptyPwdPlaceholder = "empty" ocMtimeHeader = "X-OC-Mtime" ) var ( 
configDir       = filepath.Join(".", "..", "..")
	allPerms        = []string{dataprovider.PermAny}
	homeBasePath    string
	hookCmdPath     string
	extAuthPath     string
	preLoginPath    string
	postConnectPath string
	preDownloadPath string
	preUploadPath   string
	logFilePath     string
	certPath        string
	keyPath         string
	caCrtPath       string
	caCRLPath       string
)

// TestMain prepares the shared environment for the WebDAV tests.
//
// It loads the configuration, writes the TLS fixtures (certificate, key, CA
// and CRL) to the temp dir, initializes the data provider, common, HTTP and
// KMS layers, and starts the WebDAV (plain and TLS), HTTP and SFTP servers.
// It waits until they are listening, runs the tests and finally removes the
// temporary files it created.
func TestMain(m *testing.M) {
	logFilePath = filepath.Join(configDir, "sftpgo_webdavd_test.log")
	logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel)
	os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1")
	os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1")
	os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin")
	os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password")
	os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__EXT", ".sftpgo")
	os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__MIME", "application/sftpgo")
	err := config.LoadConfig(configDir, "")
	if err != nil {
		logger.ErrorToConsole("error loading configuration: %v", err)
		os.Exit(1)
	}
	providerConf := config.GetProviderConf()
	logger.InfoToConsole("Starting WebDAVD tests, provider: %v", providerConf.Driver)

	commonConf := config.GetCommonConfig()
	commonConf.UploadMode = 2
	homeBasePath = os.TempDir()
	if runtime.GOOS != osWindows {
		commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete"}
		// Resolve the hook command before assigning it: assigning hookCmdPath
		// before the lookup would always leave the action hook empty.
		hookCmdPath, err = exec.LookPath("true")
		if err != nil {
			logger.Warn(logSender, "", "unable to get hook command: %v", err)
			logger.WarnToConsole("unable to get hook command: %v", err)
		} else {
			commonConf.Actions.Hook = hookCmdPath
		}
	}

	certPath = filepath.Join(os.TempDir(), "test_dav.crt")
	keyPath = filepath.Join(os.TempDir(), "test_dav.key")
	caCrtPath = filepath.Join(os.TempDir(), "test_dav_ca.crt")
	caCRLPath = filepath.Join(os.TempDir(), "test_dav_crl.crt")
	err = os.WriteFile(certPath, []byte(webDavCert), os.ModePerm)
	if err != nil {
		logger.ErrorToConsole("error writing WebDAV certificate: %v", err)
		os.Exit(1)
	}
	err = os.WriteFile(keyPath, []byte(webDavKey), os.ModePerm)
	if err != nil {
		logger.ErrorToConsole("error writing WebDAV private key: %v", err)
		os.Exit(1)
	}
	err = os.WriteFile(caCrtPath, []byte(caCRT), os.ModePerm)
	if err != nil {
		logger.ErrorToConsole("error writing WebDAV CA crt: %v", err)
		os.Exit(1)
	}
	err = os.WriteFile(caCRLPath, []byte(caCRL), os.ModePerm)
	if err != nil {
		logger.ErrorToConsole("error writing WebDAV CRL: %v", err)
		os.Exit(1)
	}

	err = dataprovider.Initialize(providerConf, configDir, true)
	if err != nil {
		logger.ErrorToConsole("error initializing data provider: %v", err)
		os.Exit(1)
	}

	err = common.Initialize(commonConf, 0)
	if err != nil {
		logger.WarnToConsole("error initializing common: %v", err)
		os.Exit(1)
	}

	httpConfig := config.GetHTTPConfig()
	httpConfig.Initialize(configDir) //nolint:errcheck

	kmsConfig := config.GetKMSConfig()
	err = kmsConfig.Initialize()
	if err != nil {
		logger.ErrorToConsole("error initializing kms: %v", err)
		os.Exit(1)
	}

	httpdConf := config.GetHTTPDConfig()
	httpdConf.Bindings[0].Port = 8078
	httpdtest.SetBaseURL("http://127.0.0.1:8078")

	// required to test sftpfs
	sftpdConf := config.GetSFTPDConfig()
	sftpdConf.Bindings = []sftpd.Binding{
		{
			Port: 9022,
		},
	}
	hostKeyPath := filepath.Join(os.TempDir(), "id_ecdsa")
	sftpdConf.HostKeys = []string{hostKeyPath}

	webDavConf := config.GetWebDAVDConfig()
	webDavConf.CACertificates = []string{caCrtPath}
	webDavConf.CARevocationLists = []string{caCRLPath}
	webDavConf.Bindings = []webdavd.Binding{
		{
			Port: webDavServerPort,
		},
		{
			Port:               webDavTLSServerPort,
			EnableHTTPS:        true,
			CertificateFile:    certPath,
			CertificateKeyFile: keyPath,
			ClientAuthType:     2,
		},
	}
	webDavConf.Cors = webdavd.CorsConfig{
		Enabled:        true,
		AllowedOrigins: []string{"*"},
		AllowedMethods: []string{
			http.MethodHead,
			http.MethodGet,
			http.MethodPost,
			http.MethodPut,
			http.MethodPatch,
			http.MethodDelete,
		},
		AllowedHeaders:   []string{"*"},
		AllowCredentials: true,
	}

	status := webdavd.GetStatus()
	if status.IsActive {
		logger.ErrorToConsole("webdav server is already active")
		os.Exit(1)
	}

	extAuthPath = filepath.Join(homeBasePath, "extauth.sh")
	preLoginPath = filepath.Join(homeBasePath, "prelogin.sh")
	postConnectPath = filepath.Join(homeBasePath, "postconnect.sh")
	preDownloadPath = filepath.Join(homeBasePath, "predownload.sh")
	preUploadPath = filepath.Join(homeBasePath, "preupload.sh")

	go func() {
		logger.Debug(logSender, "", "initializing WebDAV server with config %+v", webDavConf)
		if err := webDavConf.Initialize(configDir); err != nil {
			logger.ErrorToConsole("could not start WebDAV server: %v", err)
			os.Exit(1)
		}
	}()

	go func() {
		if err := httpdConf.Initialize(configDir, 0); err != nil {
			logger.ErrorToConsole("could not start HTTP server: %v", err)
			os.Exit(1)
		}
	}()

	go func() {
		logger.Debug(logSender, "", "initializing SFTP server with config %+v", sftpdConf)
		if err := sftpdConf.Initialize(configDir); err != nil {
			logger.ErrorToConsole("could not start SFTP server: %v", err)
			os.Exit(1)
		}
	}()

	waitTCPListening(webDavConf.Bindings[0].GetAddress())
	waitTCPListening(webDavConf.Bindings[1].GetAddress())
	waitTCPListening(httpdConf.Bindings[0].GetAddress())
	waitTCPListening(sftpdConf.Bindings[0].GetAddress())
	webdavd.ReloadCertificateMgr() //nolint:errcheck

	exitCode := m.Run()
	os.Remove(logFilePath)
	os.Remove(extAuthPath)
	os.Remove(preLoginPath)
	os.Remove(postConnectPath)
	os.Remove(preDownloadPath)
	os.Remove(preUploadPath)
	os.Remove(certPath)
	os.Remove(keyPath)
	os.Remove(caCrtPath)
	os.Remove(caCRLPath)
	os.Remove(hostKeyPath)
	os.Remove(hostKeyPath + ".pub")
	os.Exit(exitCode)
}

// TestInitialization exercises the failure paths of Configuration.Initialize.
// It covers invalid certificate paths, ports already in use by the servers
// started in TestMain, no valid binding, empty CA/CRL paths, an invalid proxy
// address and a closed data provider. The data provider is re-initialized at
// the end so that the following tests can use it.
func TestInitialization(t *testing.T) {
	cfg := webdavd.Configuration{
		Bindings: []webdavd.Binding{
			{
				Port:        1234,
				EnableHTTPS: true,
			},
			{
				Port: 0,
			},
		},
		CertificateFile:    "missing path",
		CertificateKeyFile: "bad path",
	}
	err := cfg.Initialize(configDir)
	assert.Error(t, err)

	cfg.Cache = config.GetWebDAVDConfig().Cache
	cfg.Bindings[0].Port = webDavServerPort
	cfg.CertificateFile = certPath
	cfg.CertificateKeyFile = keyPath
	err = cfg.Initialize(configDir)
	assert.Error(t, err)
	err = webdavd.ReloadCertificateMgr()
	assert.NoError(t, err)

	cfg.Bindings = []webdavd.Binding{
		{
			Port: 0,
		},
	}
	err = cfg.Initialize(configDir)
	assert.EqualError(t, err, common.ErrNoBinding.Error())

	cfg.CertificateFile = certPath
	cfg.CertificateKeyFile = keyPath
	cfg.CACertificates = []string{""}

	cfg.Bindings = []webdavd.Binding{
		{
			Port:           9022,
			ClientAuthType: 1,
			EnableHTTPS:    true,
		},
	}
	err = cfg.Initialize(configDir)
	assert.Error(t, err)

	cfg.CACertificates = nil
	cfg.CARevocationLists = []string{""}
	err = cfg.Initialize(configDir)
	assert.Error(t, err)

	cfg.CARevocationLists = nil
	err = cfg.Initialize(configDir)
	assert.Error(t, err)

	cfg.CertificateFile = certPath
	cfg.CertificateKeyFile = keyPath
	cfg.CACertificates = []string{caCrtPath}
	cfg.CARevocationLists = []string{caCRLPath}
	cfg.Bindings[0].ProxyAllowed = []string{"not valid"}
	err = cfg.Initialize(configDir)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is not a valid IP address")
	}
	cfg.Bindings[0].ProxyAllowed = nil
	err = cfg.Initialize(configDir)
	assert.Error(t, err)

	err = dataprovider.Close()
	assert.NoError(t, err)
	err = cfg.Initialize(configDir)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to load config from provider")
	}
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestBasicHandling runs the basic file operations for both a local user and
// an SFTP-backed user: upload, overwrite, download, rename, remove, mkdir,
// copy and remove-all. It checks quota usage and first upload/download
// tracking along the way.
func TestBasicHandling(t *testing.T) {
	u := getTestUser()
	u.QuotaSize = 6553600
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaSize = 6553600
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		client := getWebDavClient(user, true, nil)
		assert.NoError(t, checkBasicFunc(client))
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		expectedQuotaSize := testFileSize
		expectedQuotaFiles := 1
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)

		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), user.FirstUpload)
		assert.Equal(t, int64(0), user.FirstDownload)

		err = uploadFileWithRawClient(testFilePath, testFileName,
			user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.FirstUpload, int64(0))
		assert.Greater(t, user.FirstDownload, int64(0)) // webdav read the mime type
		// overwrite an existing file
		err = uploadFileWithRawClient(testFilePath, testFileName,
			user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		// wrong password
		err = uploadFileWithRawClient(testFilePath, testFileName,
			user.Username, defaultPassword+"1", true, testFileSize, client)
		assert.Error(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		assert.Greater(t, user.FirstUpload, int64(0))
		assert.Greater(t, user.FirstDownload, int64(0))
		err = client.Rename(testFileName, testFileName+"1", false)
		assert.NoError(t, err)
		_, err = client.Stat(testFileName)
		assert.Error(t, err)
		// the webdav client hides the error, so we check the quota instead
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		err = client.Remove(testFileName + "1")
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize)
		err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.Error(t, err)
		testDir := "testdir"
		err = client.Mkdir(testDir, os.ModePerm)
		assert.NoError(t, err)
		err = client.MkdirAll(path.Join(testDir, "sub", "sub"), os.ModePerm)
		assert.NoError(t, err)
		err = client.MkdirAll(path.Join(testDir, "sub1", "sub1"), os.ModePerm)
		assert.NoError(t, err)
		err = client.MkdirAll(path.Join(testDir, "sub2", "sub2"), os.ModePerm)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName+".txt"),
			user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName),
			user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		files, err := client.ReadDir(testDir)
		assert.NoError(t, err)
		// 3 sub directories + 2 uploaded files
		assert.Len(t, files, 5)
		err = client.Copy(testDir, testDir+"_copy", false) //nolint:goconst
		assert.NoError(t, err)
		err = client.RemoveAll(testDir)
		assert.NoError(t, err)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			// recreate the local user with a clean home dir before the next iteration
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
		1*time.Second, 100*time.Millisecond)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
	status := webdavd.GetStatus()
	assert.True(t, status.IsActive)
}

// TestBasicHandlingCryptFs runs the basic file operations against a user with
// an encrypted filesystem. It checks that quota usage accounts for the
// encrypted size while the listed sizes report the plain size.
func TestBasicHandlingCryptFs(t *testing.T) {
	u := getTestUserWithCryptFs()
	u.QuotaSize = 6553600
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))

	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	encryptedFileSize, err := getEncryptedFileSize(testFileSize)
	assert.NoError(t, err)
	expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize
	expectedQuotaFiles := user.UsedQuotaFiles + 1
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	// overwrite an existing file
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
	err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.NoError(t, err)
	assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
		1*time.Second, 100*time.Millisecond)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
	files, err := client.ReadDir("/")
	assert.NoError(t, err)
	if assert.Len(t, files, 1) {
		assert.Equal(t, testFileSize, files[0].Size())
	}
	err = client.Remove(testFileName)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize-encryptedFileSize, user.UsedQuotaSize)
	err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.Error(t, err)

	testDir := "testdir"
	err = client.Mkdir(testDir, os.ModePerm)
	assert.NoError(t, err)
	err = client.MkdirAll(path.Join(testDir, "sub", "sub"), os.ModePerm)
	assert.NoError(t, err)
	err = client.MkdirAll(path.Join(testDir, "sub1", "sub1"), os.ModePerm)
	assert.NoError(t, err)
	err = client.MkdirAll(path.Join(testDir, "sub2", "sub2"), os.ModePerm)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName+".txt"),
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName),
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	files, err = client.ReadDir(testDir)
	assert.NoError(t, err)
	assert.Len(t, files, 5)
	for _, f := range files {
		if strings.HasPrefix(f.Name(), testFileName) {
			assert.Equal(t, testFileSize, f.Size())
		} else {
			assert.True(t, f.IsDir())
		}
	}
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(localDownloadPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
		1*time.Second, 100*time.Millisecond)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestBufferedUser uploads and downloads a file using a local filesystem with
// read/write buffers. It does the same through a virtual folder backed by an
// encrypted filesystem that also uses buffers.
func TestBufferedUser(t *testing.T) {
	u := getTestUser()
	u.FsConfig.OSConfig = sdk.OSFsConfig{
		WriteBufferSize: 2,
		ReadBufferSize:  1,
	}
	vdirPath := "/crypted"
	mappedPath := filepath.Join(os.TempDir(), util.GenerateUniqueID())
	folderName := filepath.Base(mappedPath)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				OSFsConfig: sdk.OSFsConfig{
					WriteBufferSize: 3,
					ReadBufferSize:  2,
				},
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))

	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, path.Join(vdirPath, testFileName),
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
	err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.NoError(t, err)
	err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client)
	assert.NoError(t, err)

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(localDownloadPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestLoginEmptyPassword checks that a user created without a password cannot
// log in (the server answers 401).
func TestLoginEmptyPassword(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Password = emptyPwdPlaceholder

	client := getWebDavClient(user, false, nil)
	err = checkBasicFunc(client)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "401")
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestAnonymousUser checks that an anonymous user can log in without a real
// password but gets 403 on uploads and directory creation.
func TestAnonymousUser(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	u.Filters.IsAnonymous = true
	// AddUser reports an error here, yet the user is fetched successfully
	// below; presumably httpdtest's response check fails for anonymous users.
	_, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.Error(t, err)
	user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK)
	assert.NoError(t, err)

	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))

	user.Password = emptyPwdPlaceholder
	client = getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))

	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "403")
	}

	err = client.Mkdir("testdir", os.ModePerm)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "403")
	}

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLockAfterDelete locks a file and deletes it using the lock token. It
// then checks that a new LOCK on the same path creates a fresh lock (201),
// i.e. the lock was removed together with the object.
func TestLockAfterDelete(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	lockBody := ``
	req, err := http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody)))
	assert.NoError(t, err)
	req.SetBasicAuth(u.Username, u.Password)
	req.Header.Set("Timeout", "Second-3600")
	httpClient := httpclient.GetHTTPClient()
	resp, err := httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	response, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	// extract the lock token from the LOCK response body
	re := regexp.MustCompile(`\.*`)
	lockToken := string(re.Find(response))
	lockToken = strings.Replace(lockToken, "", "", 1)
	lockToken = strings.Replace(lockToken, "", "", 1)
	err = resp.Body.Close()
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodDelete, fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), nil)
	assert.NoError(t, err)
	req.Header.Set("If", fmt.Sprintf("(%v)", lockToken))
	req.SetBasicAuth(u.Username, u.Password)
	resp, err = httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusNoContent, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// if we try to lock again it must succeed, the lock must be deleted with the object
	req, err = http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody)))
	assert.NoError(t, err)
	req.SetBasicAuth(u.Username, u.Password)
	resp, err = httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusCreated, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMtimeHeader checks that the X-OC-Mtime header sets the modification
// time on upload, on overwrite and on MOVE. An unparsable value is ignored.
func TestMtimeHeader(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client, dataprovider.KeyValue{Key: ocMtimeHeader, Value: "1668879480"})
	assert.NoError(t, err)
	// check the modification time
	info, err := client.Stat(testFileName)
	if assert.NoError(t, err) {
		assert.Equal(t, time.Unix(1668879480, 0).UTC(), info.ModTime().UTC())
	}
	// test on overwrite
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client, dataprovider.KeyValue{Key: ocMtimeHeader, Value: "1667879480"})
	assert.NoError(t, err)
	info, err = client.Stat(testFileName)
	if assert.NoError(t, err) {
		assert.Equal(t, time.Unix(1667879480, 0).UTC(), info.ModTime().UTC())
	}
	// invalid time will be silently ignored and the time set to now
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client, dataprovider.KeyValue{Key: ocMtimeHeader, Value: "not unix time"})
	assert.NoError(t, err)
	info, err = client.Stat(testFileName)
	if assert.NoError(t, err) {
		assert.NotEqual(t, time.Unix(1667879480, 0).UTC(), info.ModTime().UTC())
	}

	req, err := http.NewRequest("MOVE", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), nil)
	assert.NoError(t, err)
	req.Header.Set("Overwrite", "T")
	req.Header.Set("Destination", path.Join("/", testFileName+"rename"))
	req.Header.Set(ocMtimeHeader, "1666779480")
	req.SetBasicAuth(u.Username, u.Password)
	httpClient := httpclient.GetHTTPClient()
	resp, err := httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusCreated, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// check the modification time
	info, err = client.Stat(testFileName + "rename")
	if assert.NoError(t, err) {
		assert.Equal(t, time.Unix(1666779480, 0).UTC(), info.ModTime().UTC())
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestRenameWithLock(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	testFilePath := filepath.Join(homeBasePath,
testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client) assert.NoError(t, err) lockBody := `` req, err := http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody))) assert.NoError(t, err) req.SetBasicAuth(u.Username, u.Password) httpClient := httpclient.GetHTTPClient() resp, err := httpClient.Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) response, err := io.ReadAll(resp.Body) assert.NoError(t, err) re := regexp.MustCompile(`\.*`) lockToken := string(re.Find(response)) lockToken = strings.Replace(lockToken, "", "", 1) lockToken = strings.Replace(lockToken, "", "", 1) err = resp.Body.Close() assert.NoError(t, err) // MOVE with a lock should succeeded req, err = http.NewRequest("MOVE", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), nil) assert.NoError(t, err) req.Header.Set("If", fmt.Sprintf("(%v)", lockToken)) req.Header.Set("Overwrite", "T") req.Header.Set("Destination", path.Join("/", testFileName+"1")) req.SetBasicAuth(u.Username, u.Password) resp, err = httpClient.Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusCreated, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestPropPatch(t *testing.T) { u := getTestUser() u.Username = u.Username + "1" localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser := getTestSFTPUser() sftpUser.FsConfig.SFTPConfig.Username = localUser.Username for _, u := range []dataprovider.User{getTestUser(), getTestUserWithCryptFs(), sftpUser} { user, _, err := httpdtest.AddUser(u, http.StatusCreated) 
assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client), sftpUser.Username) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client) assert.NoError(t, err) httpClient := httpclient.GetHTTPClient() propatchBody := `Wed, 04 Nov 2020 13:25:51 GMTSat, 05 Dec 2020 21:16:12 GMTWed, 04 Nov 2020 13:25:51 GMT00000000` req, err := http.NewRequest("PROPPATCH", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(propatchBody))) assert.NoError(t, err) req.SetBasicAuth(u.Username, u.Password) resp, err := httpClient.Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) info, err := client.Stat(testFileName) if assert.NoError(t, err) { expected, err := http.ParseTime("Wed, 04 Nov 2020 13:25:51 GMT") assert.NoError(t, err) assert.Equal(t, testFileSize, info.Size()) assert.Equal(t, expected.Format(http.TimeFormat), info.ModTime().Format(http.TimeFormat)) } // wrong date propatchBody = `Wed, 04 Nov 2020 13:25:51 GMTSat, 05 Dec 2020 21:16:12 GMTWid, 04 Nov 2020 13:25:51 GMT00000000` req, err = http.NewRequest("PROPPATCH", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(propatchBody))) assert.NoError(t, err) req.SetBasicAuth(u.Username, u.Password) resp, err = httpClient.Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) err = resp.Body.Close() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err 
= os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestLoginInvalidPwd(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) user.Password = "wrong" client = getWebDavClient(user, false, nil) assert.Error(t, checkBasicFunc(client)) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } func TestLoginNonExistentUser(t *testing.T) { user := getTestUser() client := getWebDavClient(user, true, nil) assert.Error(t, checkBasicFunc(client)) } func TestRateLimiter(t *testing.T) { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.RateLimitersConfig = []common.RateLimiterConfig{ { Average: 1, Period: 1000, Burst: 3, Type: 1, Protocols: []string{common.ProtocolWebDAV}, }, } err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) _, err = client.ReadDir(".") if assert.Error(t, err) { assert.Contains(t, err.Error(), "429") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } func TestDefender(t *testing.T) { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 cfg.DefenderConfig.ScoreValid = 1 err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) 
client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) user.Password = "wrong_pwd" client = getWebDavClient(user, false, nil) assert.Error(t, checkBasicFunc(client)) hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) assert.NoError(t, err) if assert.Len(t, hosts, 1) { host := hosts[0] assert.Empty(t, host.GetBanTime()) assert.Equal(t, 1, host.Score) } for i := 0; i < 2; i++ { client = getWebDavClient(user, false, nil) assert.Error(t, checkBasicFunc(client)) } user.Password = defaultPassword client = getWebDavClient(user, true, nil) err = checkBasicFunc(client) if assert.Error(t, err) { assert.Contains(t, err.Error(), "403") } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } func TestLoginExternalAuth(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) client := getWebDavClient(u, false, nil) assert.NoError(t, checkBasicFunc(client)) u.Username = defaultUsername + "1" client = getWebDavClient(u, false, nil) assert.Error(t, checkBasicFunc(client)) user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, defaultUsername, user.Username) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") 
assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthPasswordChange(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, defaultPassword), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) client := getWebDavClient(u, false, nil) assert.NoError(t, checkBasicFunc(client)) u.Username = defaultUsername + "1" client = getWebDavClient(u, false, nil) assert.Error(t, checkBasicFunc(client)) err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, defaultPassword+"1"), os.ModePerm) assert.NoError(t, err) client = getWebDavClient(u, false, nil) assert.Error(t, checkBasicFunc(client)) u.Password = defaultPassword + "1" client = getWebDavClient(u, false, nil) assert.NoError(t, checkBasicFunc(client)) user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.Equal(t, defaultUsername, user.Username) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(defaultUsername+"1", http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = 
os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthReturningAnonymousUser(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() u.Filters.IsAnonymous = true u.Filters.DeniedProtocols = []string{common.ProtocolSSH} u.Password = "" err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) client := getWebDavClient(u, false, nil) assert.NoError(t, checkBasicFunc(client)) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, u.Username, emptyPwdPlaceholder, false, testFileSize, client) if assert.Error(t, err) { assert.Contains(t, err.Error(), "403") } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.True(t, user.Filters.IsAnonymous) assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols) assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) u.Password = emptyPwdPlaceholder client = getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) err = 
uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client) if assert.Error(t, err) { assert.Contains(t, err.Error(), "403") } err = client.Mkdir("testdir", os.ModePerm) if assert.Error(t, err) { assert.Contains(t, err.Error(), "403") } err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestExternalAuthAnonymousGroupInheritance(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } g := dataprovider.Group{ BaseGroup: sdk.BaseGroup{ Name: "test_group", }, UserSettings: dataprovider.GroupUserSettings{ BaseGroupUserSettings: sdk.BaseGroupUserSettings{ Permissions: map[string][]string{ "/": allPerms, }, Filters: sdk.BaseUserFilters{ IsAnonymous: true, }, }, }, } u := getTestUser() u.Groups = []sdk.GroupMapping{ { Name: g.Name, Type: sdk.GroupTypePrimary, }, } err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) group, _, err := httpdtest.AddGroup(g, http.StatusCreated) assert.NoError(t, err) u.Password = emptyPwdPlaceholder client := getWebDavClient(u, false, nil) assert.NoError(t, checkBasicFunc(client)) err = client.Mkdir("tdir", os.ModePerm) if assert.Error(t, err) { assert.Contains(t, 
err.Error(), "403") } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) assert.False(t, user.Filters.IsAnonymous) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveGroup(group, http.StatusOK) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestPreLoginHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) providerConf.PreLoginHook = preLoginPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) assert.NoError(t, err) client := getWebDavClient(u, true, nil) assert.NoError(t, checkBasicFunc(client)) user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) // test login with an existing user client = getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) assert.NoError(t, err) // update the user to remove it from the cache user.FsConfig.Provider = sdk.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client = getWebDavClient(user, true, nil) 
assert.Error(t, checkBasicFunc(client)) // update the user to remove it from the cache user.FsConfig.Provider = sdk.LocalFilesystemProvider user.FsConfig.CryptConfig.Passphrase = nil user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user.Status = 0 err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) assert.NoError(t, err) client = getWebDavClient(user, true, nil) assert.Error(t, checkBasicFunc(client)) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(preLoginPath) assert.NoError(t, err) } func TestPreDownloadHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } oldExecuteOn := common.Config.Actions.ExecuteOn oldHook := common.Config.Actions.Hook common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} common.Config.Actions.Hook = preDownloadPath user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) err = os.Remove(localDownloadPath) 
assert.NoError(t, err) err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) assert.Error(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} common.Config.Actions.Hook = preDownloadPath common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook } func TestPreUploadHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } oldExecuteOn := common.Config.Actions.ExecuteOn oldHook := common.Config.Actions.Hook common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} common.Config.Actions.Hook = preUploadPath user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(preUploadPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) err = uploadFileWithRawClient(testFilePath, testFileName+"1", 
user.Username, defaultPassword, false, testFileSize, client) assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook } func TestPostConnectHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } common.Config.PostConnectHook = postConnectPath u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm) assert.NoError(t, err) client := getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) assert.Error(t, checkBasicFunc(client)) common.Config.PostConnectHook = "http://127.0.0.1:8078/healthz" assert.NoError(t, checkBasicFunc(client)) common.Config.PostConnectHook = "http://127.0.0.1:8078/notfound" assert.Error(t, checkBasicFunc(client)) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) common.Config.PostConnectHook = "" } func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) // now add a fake connection fs := vfs.NewOsFs("id", os.TempDir(), "", nil) connection := 
&webdavd.Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } err = common.Connections.Add(connection) assert.NoError(t, err) assert.Error(t, checkBasicFunc(client)) common.Connections.Remove(connection.GetID()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.MaxTotalConnections = oldValue } func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) // now add a fake connection addrs, err := net.LookupHost("localhost") assert.NoError(t, err) for _, addr := range addrs { common.Connections.AddClientConnection(addr) } assert.Error(t, checkBasicFunc(client)) for _, addr := range addrs { common.Connections.RemoveClientConnection(addr) } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.MaxPerHostConnections = oldValue } func TestMaxTransfers(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 2 assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 
50*time.Millisecond) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) conn, sftpClient, err := getSftpClient(user) assert.NoError(t, err) defer conn.Close() defer sftpClient.Close() f1, err := sftpClient.Create("file1") assert.NoError(t, err) f2, err := sftpClient.Create("file2") assert.NoError(t, err) _, err = f1.Write([]byte(" ")) assert.NoError(t, err) _, err = f2.Write([]byte(" ")) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client) assert.Error(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = f1.Close() assert.NoError(t, err) err = f2.Close() assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.MaxPerHostConnections = oldValue } func TestMustChangePasswordRequirement(t *testing.T) { u := getTestUser() u.Filters.RequirePasswordChange = true user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, false, nil) assert.Error(t, checkBasicFunc(client)) err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "") assert.NoError(t, err) client = getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestMaxSessions(t *testing.T) { u := 
getTestUser() u.MaxSessions = 1 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) // now add a fake connection fs := vfs.NewOsFs("id", os.TempDir(), "", nil) connection := &webdavd.Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } err = common.Connections.Add(connection) assert.NoError(t, err) assert.Error(t, checkBasicFunc(client)) common.Connections.Remove(connection.GetID()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestLoginWithIPilters(t *testing.T) { u := getTestUser() u.Filters.DeniedIP = []string{"192.167.0.0/24", "172.18.0.0/16"} u.Filters.AllowedIP = []string{"172.19.0.0/16"} user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, true, nil) assert.Error(t, checkBasicFunc(client)) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestDownloadErrors(t *testing.T) { u := getTestUser() u.QuotaFiles = 1 subDir1 := "sub1" subDir2 := "sub2" u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems} u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermDownload} // use an unknown mime to trigger content type detection u.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: "/sub2", AllowedPatterns: []string{}, DeniedPatterns: []string{"*.jpg", "*.zipp"}, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client := 
getWebDavClient(user, false, nil) testFilePath1 := filepath.Join(user.HomeDir, subDir1, "file.zipp") testFilePath2 := filepath.Join(user.HomeDir, subDir2, "file.zipp") testFilePath3 := filepath.Join(user.HomeDir, subDir2, "file.jpg") err = os.MkdirAll(filepath.Dir(testFilePath1), os.ModePerm) assert.NoError(t, err) err = os.MkdirAll(filepath.Dir(testFilePath2), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(testFilePath1, []byte("file1"), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(testFilePath2, []byte("file2"), os.ModePerm) assert.NoError(t, err) err = os.WriteFile(testFilePath3, []byte("file3"), os.ModePerm) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(path.Join("/", subDir1, "file.zipp"), localDownloadPath, 5, client) assert.Error(t, err) err = downloadFile(path.Join("/", subDir2, "file.zipp"), localDownloadPath, 5, client) assert.Error(t, err) err = downloadFile(path.Join("/", subDir2, "file.jpg"), localDownloadPath, 5, client) assert.Error(t, err) err = downloadFile(path.Join("missing.zip"), localDownloadPath, 5, client) assert.Error(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestUploadErrors(t *testing.T) { u := getTestUser() u.QuotaSize = 65535 subDir1 := "sub1" subDir2 := "sub2" // we need download permission to get size since PROPFIND will open the file u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems, dataprovider.PermDownload} u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermDownload} u.Filters.FilePatterns = []sdk.PatternsFilter{ { Path: "/sub2", AllowedPatterns: []string{}, DeniedPatterns: []string{"*.zip"}, }, } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, 
err) client := getWebDavClient(user, true, nil) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := user.QuotaSize err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = client.Mkdir(subDir1, os.ModePerm) assert.NoError(t, err) err = client.Mkdir(subDir2, os.ModePerm) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(subDir1, testFileName), user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName+".zip"), user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName), user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = client.Rename(path.Join(subDir2, testFileName), path.Join(subDir1, testFileName), false) assert.Error(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName), user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) err = uploadFileWithRawClient(testFilePath, subDir1, user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) // overquota err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) err = client.Remove(path.Join(subDir2, testFileName)) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client) assert.Error(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestDeniedLoginMethod(t *testing.T) { u := getTestUser() 
u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, true, nil)
	assert.Error(t, checkBasicFunc(client))

	user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodKeyAndKeyboardInt}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client = getWebDavClient(user, true, nil)
	assert.NoError(t, checkBasicFunc(client))

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestDeniedProtocols checks that WebDAV access is refused while the WebDAV
// protocol is denied for the user, and allowed once only SSH and FTP are denied.
func TestDeniedProtocols(t *testing.T) {
	u := getTestUser()
	u.Filters.DeniedProtocols = []string{common.ProtocolWebDAV}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	assert.Error(t, checkBasicFunc(client))

	user.Filters.DeniedProtocols = []string{common.ProtocolSSH, common.ProtocolFTP}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client = getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestQuotaLimits checks WebDAV uploads against the files-count and size quotas,
// for both a local user and an SFTP-backed user. It covers:
//   - a rejected upload once QuotaFiles (1) is reached;
//   - a rejected upload when QuotaSize is smaller than the file;
//   - an upload aborted when it exceeds the remaining space;
//   - overwrites that shrink, or try to grow, an existing file.
func TestQuotaLimits(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 1
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		testFileSize := int64(65536)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		testFileSize1 := int64(131072)
		testFileName1 := "test_file1.dat"
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		testFileSize2 := int64(32768)
		testFileName2 := "test_file2.dat"
		testFilePath2 := filepath.Join(homeBasePath, testFileName2)
		err = createTestFile(testFilePath2, testFileSize2)
		assert.NoError(t, err)
		client := getWebDavClient(user, false, nil)
		// test quota files
		err = uploadFileWithRawClient(testFilePath, testFileName+".quota", user.Username, defaultPassword, false, //nolint:goconst
			testFileSize, client)
		if !assert.NoError(t, err, "username: %v", user.Username) {
			// diagnostics to help debugging an unexpected failure
			info, err := os.Stat(testFilePath)
			if assert.NoError(t, err) {
				fmt.Printf("local file size: %v\n", info.Size())
			}
			printLatestLogs(20)
		}
		err = uploadFileWithRawClient(testFilePath, testFileName+".quota1", user.Username, defaultPassword, false, testFileSize, client)
		assert.Error(t, err, "username: %v", user.Username)
		err = client.Rename(testFileName+".quota", testFileName, false)
		assert.NoError(t, err)
		files, err := client.ReadDir("/")
		assert.NoError(t, err)
		assert.Len(t, files, 1)
		// test quota size
		user.QuotaSize = testFileSize - 1
		user.QuotaFiles = 0
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath, testFileName+".quota", user.Username, defaultPassword, false, testFileSize, client)
		assert.Error(t, err)
		err = client.Rename(testFileName, testFileName+".quota", false)
		assert.NoError(t, err)
		// now test quota limits while uploading the current file, we have 1 byte remaining
		user.QuotaSize = testFileSize + 1
		user.QuotaFiles = 0
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, false, testFileSize1, client)
		assert.Error(t, err)
		// the aborted upload must not leave a file behind
		_, err = client.Stat(testFileName1)
		assert.Error(t, err)
		err = client.Rename(testFileName+".quota", testFileName, false)
		assert.NoError(t, err)
		// overwriting an existing file will work if the resulting size is less than or equal to the current one
		err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath2, testFileName, user.Username, defaultPassword, false, testFileSize2, client)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath1, testFileName, user.Username, defaultPassword, false, testFileSize1, client)
		assert.Error(t, err)
		err = uploadFileWithRawClient(testFilePath2, testFileName, user.Username, defaultPassword, false, testFileSize2, client)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
		err = os.Remove(testFilePath2)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			// recreate the local user without quota limits before the SFTP user's
			// iteration; NOTE(review): presumably because the SFTP user's backend
			// logs in as this user — confirm
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			user.QuotaFiles = 0
			user.QuotaSize = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestTransferQuotaLimits checks the upload/download data transfer limits (both
// set to 1): the first 550000-byte upload and download succeed, and later
// transfers are rejected either while running or before starting.
// NOTE(review): assumes the DataTransfer limits are expressed in MB — confirm.
func TestTransferQuotaLimits(t *testing.T) {
	u := getTestUser()
	u.DownloadDataTransfer = 1
	u.UploadDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(550000)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	err =
downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.NoError(t, err)
	// error while download is active
	err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.Error(t, err)
	// error before starting the download
	err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.Error(t, err)
	// error while upload is active
	err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client)
	assert.Error(t, err)
	// error before starting the upload
	err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, false, testFileSize, client)
	assert.Error(t, err)
	err = os.Remove(localDownloadPath)
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUploadMaxSize checks the per-file upload size limit (MaxUploadFileSize) for
// a local and an SFTP-backed user. This includes overwriting an existing file with
// content larger than the limit.
func TestUploadMaxSize(t *testing.T) {
	testFileSize := int64(65535)
	u := getTestUser()
	u.Filters.MaxUploadFileSize = testFileSize + 1
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.Filters.MaxUploadFileSize = testFileSize + 1
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		testFileSize1 := int64(131072)
		testFileName1 := "test_file_dav1.dat"
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		client := getWebDavClient(user, false, nil)
		err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, false, testFileSize1, client)
		assert.Error(t, err)
		err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword,
			false, testFileSize, client)
		assert.NoError(t, err)
		// now test overwrite an existing file with a size bigger than the allowed one
		err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, false, testFileSize1, client)
		assert.Error(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			// recreate the local user with a much larger upload limit before the
			// SFTP user's iteration runs
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Filters.MaxUploadFileSize = 65536000
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestClientClose checks that closing every active connection aborts in-progress
// uploads and downloads, for a local and an SFTP-backed user. Transfers are
// throttled (Upload/DownloadBandwidth = 64) so they are still running when the
// connections are closed. Afterwards no connections or transfers must remain.
//
// Each transfer goroutine keeps its error in a local variable, so it never writes
// to the test-scope err that the main goroutine also uses.
func TestClientClose(t *testing.T) {
	u := getTestUser()
	u.UploadBandwidth = 64
	u.DownloadBandwidth = 64
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.UploadBandwidth = 64
	u.DownloadBandwidth = 64
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		testFileSize := int64(1048576)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		client := getWebDavClient(user, true, nil)
		assert.NoError(t, checkBasicFunc(client))

		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			defer wg.Done()
			uploadErr := uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client)
			assert.Error(t, uploadErr)
		}()

		// wait until the upload is registered as an active transfer
		assert.Eventually(t, func() bool {
			for _, stat := range common.Connections.GetStats("") {
				if len(stat.Transfers) > 0 {
					return true
				}
			}
			return false
		}, 1*time.Second, 50*time.Millisecond)

		for _, stat := range common.Connections.GetStats("") {
			common.Connections.Close(stat.ConnectionID, "")
		}
		wg.Wait()
		// for the sftp user a stat is done after the failed upload and
		// this triggers a new connection
		for _, stat := range common.Connections.GetStats("") {
			common.Connections.Close(stat.ConnectionID, "")
		}
		assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
			1*time.Second, 100*time.Millisecond)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)

		testFilePath = filepath.Join(user.HomeDir, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		wg.Add(1)
		go func() {
			defer wg.Done()
			downloadErr := downloadFile(testFileName, localDownloadPath, testFileSize, client)
			assert.Error(t, downloadErr)
		}()

		// wait until the download is registered as an active transfer
		assert.Eventually(t, func() bool {
			for _, stat := range common.Connections.GetStats("") {
				if len(stat.Transfers) > 0 {
					return true
				}
			}
			return false
		}, 1*time.Second, 50*time.Millisecond)

		for _, stat := range common.Connections.GetStats("") {
			common.Connections.Close(stat.ConnectionID, "")
		}
		wg.Wait()
		assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
			1*time.Second, 100*time.Millisecond)
		assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginWithDatabaseCredentials checks that GCS credentials supplied as a plain
// secret come back from the API in SecretBox status, with a payload and without
// additional data or key. It also checks that a WebDAV client can still connect.
func TestLoginWithDatabaseCredentials(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "test"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account", "private_key": " ",
"client_email": "example@iam.gserviceaccount.com" }`)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus())
	assert.NotEmpty(t, user.FsConfig.GCSConfig.Credentials.GetPayload())
	assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData())
	assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey())

	client := getWebDavClient(user, false, nil)
	err = client.Connect()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginInvalidFs checks that basic WebDAV operations fail when the user's GCS
// credentials are not valid JSON.
func TestLoginInvalidFs(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = "test"
	u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client := getWebDavClient(user, true, nil)
	assert.Error(t, checkBasicFunc(client))

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSFTPBuffered exercises upload, overwrite, download and Range requests for an
// SFTP-backed user configured with SFTPConfig.BufferSize = 2, and checks the used
// quota. NOTE(review): BufferSize presumably enables buffered SFTP transfers — confirm.
func TestSFTPBuffered(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 1000
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	u.FsConfig.SFTPConfig.BufferSize = 2
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client := getWebDavClient(sftpUser, true, nil)
	assert.NoError(t, checkBasicFunc(client))
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	expectedQuotaSize := testFileSize
	expectedQuotaFiles := 1
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, true, testFileSize, client)
	assert.NoError(t, err)
// overwrite an existing file err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) user, _, err := httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) fileContent := []byte("test file contents") err = os.WriteFile(testFilePath, fileContent, os.ModePerm) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, true, int64(len(fileContent)), client) assert.NoError(t, err) remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName) req, err := http.NewRequest(http.MethodGet, remotePath, nil) assert.NoError(t, err) httpClient := httpclient.GetHTTPClient() req.SetBasicAuth(user.Username, defaultPassword) req.Header.Set("Range", "bytes=5-") resp, err := httpClient.Do(req) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusPartialContent, resp.StatusCode) bodyBytes, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, "file contents", string(bodyBytes)) } req.Header.Set("Range", "bytes=5-8") resp, err = httpClient.Do(req) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusPartialContent, resp.StatusCode) bodyBytes, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, "file", string(bodyBytes)) } err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) err = os.RemoveAll(sftpUser.GetHomeDir()) 
assert.NoError(t, err) } func TestBytesRangeRequests(t *testing.T) { u := getTestUser() u.Username = u.Username + "1" localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) sftpUser := getTestSFTPUser() sftpUser.FsConfig.SFTPConfig.Username = localUser.Username for _, u := range []dataprovider.User{getTestUser(), getTestUserWithCryptFs(), sftpUser} { user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) testFileName := "test_file.txt" testFilePath := filepath.Join(homeBasePath, testFileName) fileContent := []byte("test file contents") err = os.WriteFile(testFilePath, fileContent, os.ModePerm) assert.NoError(t, err) client := getWebDavClient(user, true, nil) err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, int64(len(fileContent)), client) assert.NoError(t, err) remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName) req, err := http.NewRequest(http.MethodGet, remotePath, nil) if assert.NoError(t, err) { httpClient := httpclient.GetHTTPClient() req.SetBasicAuth(user.Username, defaultPassword) req.Header.Set("Range", "bytes=5-") resp, err := httpClient.Do(req) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusPartialContent, resp.StatusCode) bodyBytes, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, "file contents", string(bodyBytes)) } req.Header.Set("Range", "bytes=5-8") resp, err = httpClient.Do(req) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusPartialContent, resp.StatusCode) bodyBytes, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, "file", string(bodyBytes)) } } // seek on a missing file remotePath = fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName+"_missing") req, err = http.NewRequest(http.MethodGet, remotePath, nil) if assert.NoError(t, err) { httpClient := httpclient.GetHTTPClient() req.SetBasicAuth(user.Username, 
defaultPassword) req.Header.Set("Range", "bytes=5-") resp, err := httpClient.Do(req) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusNotFound, resp.StatusCode) } } err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) } func TestContentTypeGET(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(64) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) client := getWebDavClient(user, true, nil) err = uploadFileWithRawClient(testFilePath, testFileName+".sftpgo", user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName+".sftpgo") req, err := http.NewRequest(http.MethodGet, remotePath, nil) if assert.NoError(t, err) { httpClient := httpclient.GetHTTPClient() req.SetBasicAuth(user.Username, defaultPassword) resp, err := httpClient.Do(req) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, "application/sftpgo", resp.Header.Get("Content-Type")) } } err = os.Remove(testFilePath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestHEAD(t *testing.T) { u := getTestUser() user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) rootPath := fmt.Sprintf("http://%v", webDavServerAddr) httpClient := httpclient.GetHTTPClient() req, err := http.NewRequest(http.MethodHead, rootPath, nil) if assert.NoError(t, err) { 
req.SetBasicAuth(u.Username, u.Password)
		resp, err := httpClient.Do(req)
		if assert.NoError(t, err) {
			assert.Equal(t, http.StatusMultiStatus, resp.StatusCode)
			assert.Equal(t, "text/xml; charset=utf-8", resp.Header.Get("Content-Type"))
			resp.Body.Close()
		}
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestGETAsPROPFIND checks GET requests on directories, which are answered with a
// 207 Multi-Status listing. It covers:
//   - a directory without list permission (/sub1) still answers 207;
//   - a child of /sub1 that cannot be stat'ed answers 403;
//   - ReadDir on /sub1 is empty until list/download permissions are granted;
//   - a "Depth: infinity" request is rejected with 403.
func TestGETAsPROPFIND(t *testing.T) {
	u := getTestUser()
	subDir1 := "/sub1"
	// /sub1 allows upload and directory creation but no listing
	u.Permissions[subDir1] = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	rootPath := fmt.Sprintf("http://%v/", webDavServerAddr)
	httpClient := httpclient.GetHTTPClient()
	req, err := http.NewRequest(http.MethodGet, rootPath, nil)
	if assert.NoError(t, err) {
		req.SetBasicAuth(u.Username, u.Password)
		resp, err := httpClient.Do(req)
		if assert.NoError(t, err) {
			assert.Equal(t, http.StatusMultiStatus, resp.StatusCode)
			resp.Body.Close()
		}
	}
	client := getWebDavClient(user, false, nil)
	err = client.MkdirAll(path.Join(subDir1, "sub", "sub1"), os.ModePerm)
	assert.NoError(t, err)
	subPath := fmt.Sprintf("http://%v/%v", webDavServerAddr, subDir1)
	req, err = http.NewRequest(http.MethodGet, subPath, nil)
	if assert.NoError(t, err) {
		req.SetBasicAuth(u.Username, u.Password)
		resp, err := httpClient.Do(req)
		if assert.NoError(t, err) {
			// before the performance patch we have a 500 here, now we have 207 but an empty list
			//assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
			assert.Equal(t, http.StatusMultiStatus, resp.StatusCode)
			resp.Body.Close()
		}
	}
	// we cannot stat the sub at all
	subPath1 := fmt.Sprintf("http://%v/%v", webDavServerAddr, path.Join(subDir1, "sub"))
	req, err = http.NewRequest(http.MethodGet, subPath1, nil)
	if assert.NoError(t, err) {
		req.SetBasicAuth(u.Username, u.Password)
		resp, err := httpClient.Do(req)
		if assert.NoError(t, err) {
			// here the stat will fail, so the request will not be changed in propfind
			assert.Equal(t, http.StatusForbidden, resp.StatusCode)
			resp.Body.Close()
		}
	}

	// we have no permission, we get an empty list
	files, err := client.ReadDir(subDir1)
	assert.NoError(t, err)
	assert.Len(t, files, 0)
	// if we grant the permissions the files are listed
	user.Permissions[subDir1] = []string{dataprovider.PermDownload, dataprovider.PermListItems}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	files, err = client.ReadDir(subDir1)
	assert.NoError(t, err)
	assert.Len(t, files, 1)
	// PROPFIND with infinity depth is forbidden
	req, err = http.NewRequest(http.MethodGet, rootPath, nil)
	if assert.NoError(t, err) {
		req.SetBasicAuth(u.Username, u.Password)
		req.Header.Set("Depth", "infinity")
		resp, err := httpClient.Do(req)
		if assert.NoError(t, err) {
			assert.Equal(t, http.StatusForbidden, resp.StatusCode)
			resp.Body.Close()
		}
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestStat checks that stat on a file inside /subdir fails once list permission
// is removed from /subdir, while stat on a root file still succeeds.
func TestStat(t *testing.T) {
	u := getTestUser()
	u.Permissions["/subdir"] = []string{dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermDownload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, true, nil)
	subDir := "subdir"
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = client.Mkdir(subDir, os.ModePerm)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, true, testFileSize, client)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, path.Join("/", subDir, testFileName), user.Username, defaultPassword, true, testFileSize, client)
	assert.NoError(t, err)
	// drop PermListItems from /subdir
	user.Permissions["/subdir"] = []string{dataprovider.PermUpload, dataprovider.PermDownload}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = client.Stat(testFileName)
	assert.NoError(t, err)
	_, err = client.Stat(path.Join("/", subDir, testFileName))
	assert.Error(t, err)

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUploadOverwriteVfolder uploads and then overwrites a file inside a virtual
// folder mapped with QuotaSize/QuotaFiles = -1. After each upload it asserts that
// the folder's used quota stays at zero and the user's used quota counts the file
// exactly once.
func TestUploadOverwriteVfolder(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1000
	vdir := "/vdir"
	mappedPath := filepath.Join(os.TempDir(), "mappedDir")
	folderName := filepath.Base(mappedPath)
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdir,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	err = os.MkdirAll(mappedPath, os.ModePerm)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	files, err := client.ReadDir(".")
	assert.NoError(t, err)
	vdirFound := false
	for _, info := range files {
		if info.Name() == path.Base(vdir) {
			vdirFound = true
			break
		}
	}
	assert.True(t, vdirFound)
	info, err := client.Stat(vdir)
	if assert.NoError(t, err) {
		assert.Equal(t, path.Base(vdir), info.Name())
	}

	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username, defaultPassword, true, testFileSize, client)
	assert.NoError(t, err)

	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), folder.UsedQuotaSize)
	assert.Equal(t, 0, folder.UsedQuotaFiles)

	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, testFileSize,
user.UsedQuotaSize)
	assert.Equal(t, 1, user.UsedQuotaFiles)

	// overwrite the same file: quotas must not change
	err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username, defaultPassword, true, testFileSize, client)
	assert.NoError(t, err)

	folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), folder.UsedQuotaSize)
	assert.Equal(t, 0, folder.UsedQuotaFiles)

	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, testFileSize, user.UsedQuotaSize)
	assert.Equal(t, 1, user.UsedQuotaFiles)

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestOsErrors checks directory listings when a virtual folder's mapped path is
// removed from disk: the virtual folder is still listed next to real files.
func TestOsErrors(t *testing.T) {
	u := getTestUser()
	vdir := "/vdir"
	mappedPath := filepath.Join(os.TempDir(), "mappedDir")
	folderName := filepath.Base(mappedPath)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdir,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, false, nil)
	files, err := client.ReadDir(".")
	assert.NoError(t, err)
	assert.Len(t, files, 1)
	info, err := client.Stat(vdir)
	assert.NoError(t, err)
	assert.True(t, info.IsDir())
	// now remove the folder mapped to vdir. It still appears in directory listing:
	// virtual folders are automatically added
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
	files, err = client.ReadDir(".")
	assert.NoError(t, err)
	assert.Len(t, files, 1)
	err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName), 32768)
	assert.NoError(t, err)
	files, err = client.ReadDir(".")
	assert.NoError(t, err)
	if assert.Len(t, files, 2) {
		var names []string
		for _, info := range files {
			names = append(names, info.Name())
		}
		assert.Contains(t, names, testFileName)
		assert.Contains(t, names, "vdir")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestMiscCommands exercises WebDAV copy, rename (with and without overwrite) and
// recursive delete for a local and an SFTP-backed user. After each group of
// operations it checks the user's used quota (files and size).
func TestMiscCommands(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 100
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 100
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		dir := "testDir"
		client := getWebDavClient(user, true, nil)
		err = client.MkdirAll(path.Join(dir, "sub1", "sub2"), os.ModePerm)
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// one file at each of the three directory levels
		err = uploadFileWithRawClient(testFilePath, path.Join(dir, testFileName), user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath, path.Join(dir, "sub1", testFileName), user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		err = uploadFileWithRawClient(testFilePath, path.Join(dir, "sub1", "sub2", testFileName), user.Username, defaultPassword, true, testFileSize, client)
		assert.NoError(t, err)
		err = client.Copy(dir, dir+"_copy", false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, 6*testFileSize, user.UsedQuotaSize)
		err = client.Copy(dir, dir+"_copy1", false) //nolint:goconst
		assert.NoError(t, err)
		// copying onto an existing destination requires overwrite
		err = client.Copy(dir+"_copy", dir+"_copy1", false)
		assert.Error(t, err)
		err = client.Copy(dir+"_copy", dir+"_copy1", true)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 9, user.UsedQuotaFiles)
		assert.Equal(t, 9*testFileSize, user.UsedQuotaSize)
		err = client.Rename(dir+"_copy1", dir+"_copy2", false)
		assert.NoError(t, err)
		err = client.Remove(path.Join(dir+"_copy", testFileName))
		assert.NoError(t, err)
		err = client.Rename(dir+"_copy2", dir+"_copy", true)
		assert.NoError(t, err)
		err = client.Copy(dir+"_copy", dir+"_copy1", false)
		assert.NoError(t, err)
		err = client.RemoveAll(dir + "_copy1")
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, 6*testFileSize, user.UsedQuotaSize)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			// recreate the local user without a files quota before the SFTP
			// user's iteration runs
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			user.QuotaFiles = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestClientCertificateAuthRevokedCert checks that connecting with the client2
// certificate fails with a TLS-level error ("bad certificate" or "broken pipe").
// NOTE(review): presumably client2's certificate is in the server's revocation list — confirm.
func TestClientCertificateAuthRevokedCert(t *testing.T) {
	u := getTestUser()
	u.Username = tlsClient2Username
u.Filters.TLSUsername = sdk.TLSUsernameCN
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	tlsConfig := &tls.Config{
		ServerName:         "localhost",
		InsecureSkipVerify: true, // use this for tests only
		MinVersion:         tls.VersionTLS12,
	}
	tlsCert, err := tls.X509KeyPair([]byte(client2Crt), []byte(client2Key))
	assert.NoError(t, err)
	tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert)
	client := getWebDavClient(user, true, tlsConfig)
	err = checkBasicFunc(client)
	if assert.Error(t, err) {
		if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "broken pipe") {
			t.Errorf("unexpected error: %v", err)
		}
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestClientCertificateAuth covers mutual-TLS login for tlsClient1.
//   - Without a TLS username mapping the request is unauthorized.
//   - Mapping the certificate CN to the username allows login.
//   - Pinning the certificate in TLSCerts (no CN mapping) allows login.
//   - Allowing only TLS certificate + password also succeeds.
func TestClientCertificateAuth(t *testing.T) {
	u := getTestUser()
	u.Username = tlsClient1Username
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificateAndPwd}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	tlsConfig := &tls.Config{
		ServerName:         "localhost",
		InsecureSkipVerify: true, // use this for tests only
		MinVersion:         tls.VersionTLS12,
	}
	tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key))
	assert.NoError(t, err)
	tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert)
	// TLS username is not enabled, mutual TLS should fail
	resp, err := getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v/", webDavTLSServerAddr))
	if assert.NoError(t, err) {
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		assert.NoError(t, err)
		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, string(body))
	}

	user.Filters.TLSUsername = sdk.TLSUsernameCN
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client := getWebDavClient(user, true, tlsConfig)
	err = checkBasicFunc(client)
	assert.NoError(t, err)

	// no CN mapping, but the exact certificate is configured for the user
	user.Filters.TLSUsername = sdk.TLSUsernameNone
	user.Filters.TLSCerts = []string{client1Crt}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client = getWebDavClient(user, true, tlsConfig)
	err = checkBasicFunc(client)
	assert.NoError(t, err)

	// only TLS certificate + password remains allowed
	user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	client = getWebDavClient(user, true, tlsConfig)
	err = checkBasicFunc(client)
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestWrongClientCertificate checks two rejection cases.
//   - A client certificate whose CN (client1) matches no existing user is rejected.
//   - Presenting client1's certificate with tlsClient2's username and password is
//     rejected with "invalid credentials".
func TestWrongClientCertificate(t *testing.T) {
	u := getTestUser()
	u.Username = tlsClient2Username
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificateAndPwd}
	u.Filters.TLSUsername = sdk.TLSUsernameCN
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	tlsConfig := &tls.Config{
		ServerName:         "localhost",
		InsecureSkipVerify: true, // use this for tests only
		MinVersion:         tls.VersionTLS12,
	}
	tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key))
	assert.NoError(t, err)
	tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert)

	// the certificate common name is client1 and it does not exist
	resp, err := getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v/", webDavTLSServerAddr))
	if assert.NoError(t, err) {
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		assert.NoError(t, err)
		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, string(body))
	}

	user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)

	// now create client1
	u = getTestUser()
	u.Username = tlsClient1Username
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword,
		dataprovider.LoginMethodTLSCertificate}
	u.Filters.TLSUsername = sdk.TLSUsernameCN
	user1, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	// client1's certificate with tlsClient2's credentials must be refused
	resp, err = getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v:%v@%v/", tlsClient2Username, defaultPassword,
		webDavTLSServerAddr))
	if assert.NoError(t, err) {
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		assert.NoError(t, err)
		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, string(body))
		assert.Contains(t, string(body), "invalid credentials")
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
}

func TestClientCertificateAuthCachedUser(t *testing.T) {
	u := getTestUser()
	u.Username = tlsClient1Username
	u.Filters.TLSUsername = sdk.TLSUsernameCN
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificateAndPwd}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	tlsConfig := &tls.Config{
		ServerName:         "localhost",
		InsecureSkipVerify: true, // use this for tests only
		MinVersion:         tls.VersionTLS12,
	}
	tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key))
	assert.NoError(t, err)
	tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert)
	client := getWebDavClient(user, true, tlsConfig)
	err = checkBasicFunc(client)
	assert.NoError(t, err)
	// the user is now cached without a password, try a simple password login with and without TLS
	client = getWebDavClient(user, true, nil)
	err = checkBasicFunc(client)
	assert.NoError(t, err)

	client = getWebDavClient(user, false, nil)
	err = checkBasicFunc(client)
	assert.NoError(t, err)

	// and now with a wrong password
	user.Password = "wrong"
	client = getWebDavClient(user, false, nil)
	err = checkBasicFunc(client)
	assert.Error(t, err)

	// allow cert+password only
user.Password = "" user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificate} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client = getWebDavClient(user, true, tlsConfig) err = checkBasicFunc(client) assert.NoError(t, err) // the user is now cached client = getWebDavClient(user, true, tlsConfig) err = checkBasicFunc(client) assert.NoError(t, err) // password auth should work too client = getWebDavClient(user, false, nil) err = checkBasicFunc(client) assert.NoError(t, err) client = getWebDavClient(user, true, nil) err = checkBasicFunc(client) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) } func TestExternatAuthWithClientCert(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() u.Username = tlsClient1Username u.Filters.TLSUsername = sdk.TLSUsernameCN u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodPassword} err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, ""), os.ModePerm) assert.NoError(t, err) providerConf.ExternalAuthHook = extAuthPath providerConf.ExternalAuthScope = 0 err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) client := getWebDavClient(u, true, tlsConfig) assert.NoError(t, checkBasicFunc(client)) resp, err := getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v:%v@%v/", 
tlsClient2Username, defaultPassword, webDavTLSServerAddr)) if assert.NoError(t, err) { defer resp.Body.Close() body, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, string(body)) assert.Contains(t, string(body), "invalid credentials") } user, _, err := httpdtest.GetUserByUsername(tlsClient1Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, tlsClient1Username, user.Username) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(extAuthPath) assert.NoError(t, err) } func TestPreLoginHookWithClientCert(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } u := getTestUser() u.Username = tlsClient1Username u.Filters.TLSUsername = sdk.TLSUsernameCN u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodPassword} err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) assert.NoError(t, err) providerConf.PreLoginHook = preLoginPath err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) _, _, err = httpdtest.GetUserByUsername(tlsClient1Username, http.StatusNotFound) assert.NoError(t, err) tlsConfig := &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) assert.NoError(t, err) tlsConfig.Certificates = append(tlsConfig.Certificates, 
tlsCert) client := getWebDavClient(u, true, tlsConfig) assert.NoError(t, checkBasicFunc(client)) user, _, err := httpdtest.GetUserByUsername(tlsClient1Username, http.StatusOK) assert.NoError(t, err) // test login with an existing user client = getWebDavClient(user, true, tlsConfig) assert.NoError(t, checkBasicFunc(client)) err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) assert.NoError(t, err) // update the user to remove it from the cache user.Password = defaultPassword user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client = getWebDavClient(user, true, tlsConfig) assert.Error(t, checkBasicFunc(client)) // update the user to remove it from the cache user.Password = defaultPassword user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user.Status = 0 err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) assert.NoError(t, err) client = getWebDavClient(user, true, tlsConfig) assert.Error(t, checkBasicFunc(client)) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) err = dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf = config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) err = os.Remove(preLoginPath) assert.NoError(t, err) } func TestSFTPLoopVirtualFolders(t *testing.T) { user1 := getTestUser() user2 := getTestUser() user1.Username += "1" user2.Username += "2" // user1 is a local account with a virtual SFTP folder to user2 // user2 has user1 as SFTP fs folderName := "sftp" user1.VirtualFolders = append(user1.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: "/vdir", }) user2.FsConfig.Provider = sdk.SFTPFilesystemProvider user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ 
BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: user1.Username, }, Password: kms.NewPlainSecret(defaultPassword), } f := vfs.BaseVirtualFolder{ Name: folderName, FsConfig: vfs.Filesystem{ Provider: sdk.SFTPFilesystemProvider, SFTPConfig: vfs.SFTPFsConfig{ BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ Endpoint: sftpServerAddr, Username: user2.Username, }, Password: kms.NewPlainSecret(defaultPassword), }, }, } _, _, err := httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated) assert.NoError(t, err, string(resp)) user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated) assert.NoError(t, err, string(resp)) client := getWebDavClient(user1, true, nil) testDir := "tdir" err = client.Mkdir(testDir, os.ModePerm) assert.NoError(t, err) contents, err := client.ReadDir("/") assert.NoError(t, err) if assert.Len(t, contents, 2) { expected := 0 for _, info := range contents { switch info.Name() { case testDir, "vdir": assert.True(t, info.IsDir()) expected++ default: t.Errorf("unexpected file/dir %q", info.Name()) } } assert.Equal(t, expected, 2) } _, err = httpdtest.RemoveUser(user1, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user1.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveUser(user2, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user2.GetHomeDir()) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) } func TestNestedVirtualFolders(t *testing.T) { u := getTestUser() localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) u = getTestSFTPUser() mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") folderNameCrypt := filepath.Base(mappedPathCrypt) vdirCryptPath := "/vdir/crypt" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderNameCrypt, }, VirtualPath: 
vdirCryptPath, }) mappedPath := filepath.Join(os.TempDir(), "local") folderName := filepath.Base(mappedPath) vdirPath := "/vdir/local" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderName, }, VirtualPath: vdirPath, }) mappedPathNested := filepath.Join(os.TempDir(), "nested") folderNameNested := filepath.Base(mappedPathNested) vdirNestedPath := "/vdir/crypt/nested" u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ Name: folderNameNested, }, VirtualPath: vdirNestedPath, QuotaFiles: -1, QuotaSize: -1, }) f1 := vfs.BaseVirtualFolder{ Name: folderNameCrypt, FsConfig: vfs.Filesystem{ Provider: sdk.CryptedFilesystemProvider, CryptConfig: vfs.CryptFsConfig{ Passphrase: kms.NewPlainSecret(defaultPassword), }, }, MappedPath: mappedPathCrypt, } _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) assert.NoError(t, err) f2 := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, } _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) assert.NoError(t, err) f3 := vfs.BaseVirtualFolder{ Name: folderNameNested, MappedPath: mappedPathNested, } _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) assert.NoError(t, err) sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(sftpUser, true, nil) assert.NoError(t, checkBasicFunc(client)) testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, path.Join("/vdir", testFileName), sftpUser.Username, 
defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(vdirPath, testFileName), sftpUser.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(vdirCryptPath, testFileName), sftpUser.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) err = uploadFileWithRawClient(testFilePath, path.Join(vdirNestedPath, testFileName), sftpUser.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) assert.NoError(t, err) _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(mappedPathCrypt) assert.NoError(t, err) err = os.RemoveAll(mappedPath) assert.NoError(t, err) err = os.RemoveAll(mappedPathNested) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 
1*time.Second, 100*time.Millisecond) assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func checkBasicFunc(client *gowebdav.Client) error { err := client.Connect() if err != nil { return err } _, err = client.ReadDir("/") return err } func checkFileSize(remoteDestPath string, expectedSize int64, client *gowebdav.Client) error { info, err := client.Stat(remoteDestPath) if err != nil { return err } if info.Size() != expectedSize { return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", info.Size(), expectedSize) } return nil } func uploadFileWithRawClient(localSourcePath string, remoteDestPath string, username, password string, useTLS bool, expectedSize int64, client *gowebdav.Client, headers ...dataprovider.KeyValue, ) error { srcFile, err := os.Open(localSourcePath) if err != nil { return err } defer srcFile.Close() var tlsConfig *tls.Config rootPath := fmt.Sprintf("http://%v/", webDavServerAddr) if useTLS { rootPath = fmt.Sprintf("https://%v/", webDavTLSServerAddr) tlsConfig = &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } } req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%v%v", rootPath, remoteDestPath), srcFile) if err != nil { return err } req.SetBasicAuth(username, password) for _, kv := range headers { req.Header.Set(kv.Key, kv.Value) } httpClient := &http.Client{Timeout: 10 * time.Second} if tlsConfig != nil { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = tlsConfig httpClient.Transport = customTransport } defer httpClient.CloseIdleConnections() resp, err := httpClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusCreated { return fmt.Errorf("unexpected status code: %v", resp.StatusCode) } if expectedSize > 0 { return checkFileSize(remoteDestPath, expectedSize, client) } return nil } // This method is buggy. 
I have to find time to better investigate and eventually report the issue upstream. // For now we upload using the uploadFileWithRawClient method /*func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *gowebdav.Client) error { srcFile, err := os.Open(localSourcePath) if err != nil { return err } defer srcFile.Close() err = client.WriteStream(remoteDestPath, srcFile, os.ModePerm) if err != nil { return err } if expectedSize > 0 { return checkFileSize(remoteDestPath, expectedSize, client) } return nil }*/ func downloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *gowebdav.Client) error { downloadDest, err := os.Create(localDestPath) if err != nil { return err } defer downloadDest.Close() reader, err := client.ReadStream(remoteSourcePath) if err != nil { return err } defer reader.Close() written, err := io.Copy(downloadDest, reader) if err != nil { return err } if written != expectedSize { return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", written, expectedSize) } return nil } func getTLSHTTPClient(tlsConfig *tls.Config) *http.Client { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = tlsConfig return &http.Client{ Timeout: 5 * time.Second, Transport: customTransport, } } func getWebDavClient(user dataprovider.User, useTLS bool, tlsConfig *tls.Config) *gowebdav.Client { rootPath := fmt.Sprintf("http://%v/", webDavServerAddr) if useTLS { rootPath = fmt.Sprintf("https://%v/", webDavTLSServerAddr) if tlsConfig == nil { tlsConfig = &tls.Config{ ServerName: "localhost", InsecureSkipVerify: true, // use this for tests only MinVersion: tls.VersionTLS12, } } } pwd := defaultPassword if user.Password != "" { if user.Password == emptyPwdPlaceholder { pwd = "" } else { pwd = user.Password } } client := gowebdav.NewClient(rootPath, user.Username, pwd) client.SetTimeout(10 * time.Second) if tlsConfig != nil { customTransport := 
http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = tlsConfig client.SetTransport(customTransport) } return client } func waitTCPListening(address string) { for { conn, err := net.Dial("tcp", address) if err != nil { logger.WarnToConsole("tcp server %v not listening: %v", address, err) time.Sleep(100 * time.Millisecond) continue } logger.InfoToConsole("tcp server %v now listening", address) conn.Close() break } } func getTestUser() dataprovider.User { user := dataprovider.User{ BaseUser: sdk.BaseUser{ Username: defaultUsername, Password: defaultPassword, HomeDir: filepath.Join(homeBasePath, defaultUsername), Status: 1, ExpirationDate: 0, }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = allPerms return user } func getTestSFTPUser() dataprovider.User { u := getTestUser() u.Username = u.Username + "_sftp" u.FsConfig.Provider = sdk.SFTPFilesystemProvider u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr u.FsConfig.SFTPConfig.Username = defaultUsername u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) return u } func getTestUserWithCryptFs() dataprovider.User { user := getTestUser() user.FsConfig.Provider = sdk.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase") return user } func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { var sftpClient *sftp.Client config := &ssh.ClientConfig{ User: user.Username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 5 * time.Second, } if user.Password != "" { config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} } else { config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} } conn, err := ssh.Dial("tcp", sftpServerAddr, config) if err != nil { return conn, sftpClient, err } sftpClient, err = sftp.NewClient(conn) if err != nil { conn.Close() } return conn, sftpClient, err } func getEncryptedFileSize(size int64) (int64, error) { encSize, err := 
sio.EncryptedSize(uint64(size)) return int64(encSize) + 33, err } func getExtAuthScriptContent(user dataprovider.User, password string) []byte { extAuthContent := []byte("#!/bin/sh\n\n") if password != "" { extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%s\" -a \"$SFTPGO_AUTHD_PASSWORD\" = \"%s\"; then\n", user.Username, password))...) } else { extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%s\"; then\n", user.Username))...) } u, _ := json.Marshal(user) extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("echo '%s'\n", string(u)))...) extAuthContent = append(extAuthContent, []byte("else\n")...) extAuthContent = append(extAuthContent, []byte("echo '{\"username\":\"\"}'\n")...) extAuthContent = append(extAuthContent, []byte("fi\n")...) return extAuthContent } func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte { content := []byte("#!/bin/sh\n\n") if nonJSONResponse { content = append(content, []byte("echo 'text response'\n")...) return content } if len(user.Username) > 0 { u, _ := json.Marshal(user) content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) } return content } func getExitCodeScriptContent(exitCode int) []byte { content := []byte("#!/bin/sh\n\n") content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) 
return content } func createTestFile(path string, size int64) error { baseDir := filepath.Dir(path) if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { err = os.MkdirAll(baseDir, os.ModePerm) if err != nil { return err } } content := make([]byte, size) _, err := rand.Read(content) if err != nil { return err } err = os.WriteFile(path, content, os.ModePerm) if err != nil { return err } fi, err := os.Stat(path) if err != nil { return err } if fi.Size() != size { return fmt.Errorf("unexpected size %v, expected %v", fi.Size(), size) } return nil } func printLatestLogs(maxNumberOfLines int) { var lines []string f, err := os.Open(logFilePath) if err != nil { return } defer f.Close() scanner := bufio.NewScanner(f) for scanner.Scan() { lines = append(lines, scanner.Text()+"\r\n") for len(lines) > maxNumberOfLines { lines = lines[1:] } } if scanner.Err() != nil { logger.WarnToConsole("Unable to print latest logs: %v", scanner.Err()) return } for _, line := range lines { logger.DebugToConsole("%s", line) } } ================================================ FILE: main.go ================================================ // Copyright (C) 2019 Nicola Murino // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published // by the Free Software Foundation, version 3. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . // Fully featured and highly configurable SFTP server with optional // FTP/S and WebDAV support. 
// For more details about features, installation, configuration and usage
// please refer to the README inside the source tree:
// https://github.com/drakkan/sftpgo/blob/main/README.md
package main // import "github.com/drakkan/sftpgo"

import (
	"github.com/drakkan/sftpgo/v2/internal/cmd"
)

// main is the program entry point. It only hands control to cmd.Execute from
// the internal cmd package (presumably the CLI command tree; see internal/cmd).
func main() {
	cmd.Execute()
}

================================================
FILE: openapi/httpfs.yaml
================================================
openapi: 3.0.3 tags: - name: fs info: title: SFTPGo HTTPFs description: | SFTPGo can use custom storage backend implementations compliant with the API defined here. HTTPFs is a work in progress and makes no API stability promises. version: 0.1.0 license: name: AGPL-3.0-only url: 'https://www.gnu.org/licenses/agpl-3.0.en.html' servers: - url: /v1 security: - ApiKeyAuth: [] - BasicAuth: [] paths: /stat/{name}: parameters: - name: name in: path description: object name required: true schema: type: string get: tags: - fs summary: Describes the named object operationId: stat responses: 200: description: successful operation content: application/json: schema: $ref: '#/components/schemas/FileInfo' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /open/{name}: parameters: - name: name in: path description: object name required: true schema: type: string - name: offset in: query description: 'offset, in bytes, from the start.
If not specified 0 must be assumed' required: false schema: type: integer format: int64 get: tags: - fs summary: Opens the named file for reading operationId: open responses: '200': description: successful operation content: '*/*': schema: type: string format: binary 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /create/{name}: parameters: - name: name in: path description: object name required: true schema: type: string - name: flags in: query description: 'flags to use for opening the file, if omitted O_RDWR|O_CREATE|O_TRUNC must be assumed. Supported flags: https://pkg.go.dev/os#pkg-constants' required: false schema: type: integer format: int32 - name: checks in: query description: 'If set to `1`, the parent directory must exist before creating the file' required: false schema: type: integer format: int32 post: tags: - fs summary: Creates or opens the named file for writing operationId: create requestBody: content: '*/*': schema: type: string format: binary required: true responses: 201: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /rename/{name}: parameters: - name: name in: path description: object name required: true schema: type: string - name: target in: query description: target name required: true schema: type: string patch: tags: - fs summary: Renames (moves) source to target operationId: rename responses: 200: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: 
$ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /remove/{name}: parameters: - name: name in: path description: object name required: true schema: type: string delete: tags: - fs summary: Removes the named file or (empty) directory. operationId: delete responses: 200: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /mkdir/{name}: parameters: - name: name in: path description: object name required: true schema: type: string post: tags: - fs summary: Creates a new directory with the specified name operationId: mkdir responses: 200: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /chmod/{name}: parameters: - name: name in: path description: object name required: true schema: type: string - name: mode in: query required: true schema: type: integer patch: tags: - fs summary: Changes the mode of the named file operationId: chmod responses: 200: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' 
/chtimes/{name}: parameters: - name: name in: path description: object name required: true schema: type: string - name: access_time in: query required: true schema: type: string format: date-time - name: modification_time in: query required: true schema: type: string format: date-time patch: tags: - fs summary: Changes the access and modification time of the named file operationId: chtimes responses: 200: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /truncate/{name}: parameters: - name: name in: path description: object name required: true schema: type: string - name: size in: query required: true description: 'new file size in bytes' schema: type: integer format: int64 patch: tags: - fs summary: Changes the size of the named file operationId: truncate responses: 200: $ref: '#/components/responses/OKResponse' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /readdir/{name}: parameters: - name: name in: path description: object name required: true schema: type: string get: tags: - fs summary: Reads the named directory and returns the contents operationId: readdir responses: 200: description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/FileInfo' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: 
'#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /dirsize/{name}: parameters: - name: name in: path description: object name required: true schema: type: string get: tags: - fs summary: Returns the number of files and the size for the named directory including any sub-directory operationId: dirsize responses: 200: description: successful operation content: application/json: schema: type: object properties: files: type: integer description: 'Total number of files' size: type: integer format: int64 description: 'Total size of files' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /mimetype/{name}: parameters: - name: name in: path description: object name required: true schema: type: string get: tags: - fs summary: Returns the mime type for the named file operationId: mimetype responses: 200: description: successful operation content: application/json: schema: type: object properties: mime: type: string 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: '#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' /statvfs/{name}: parameters: - name: name in: path description: object name required: true schema: type: string get: tags: - fs summary: Returns the VFS stats for the specified path operationId: statvfs responses: 200: description: successful operation content: application/json: schema: $ref: '#/components/schemas/StatVFS' 401: $ref: '#/components/responses/Unauthorized' 403: $ref: '#/components/responses/Forbidden' 404: $ref: '#/components/responses/NotFound' 500: $ref: 
'#/components/responses/InternalServerError' 501: $ref: '#/components/responses/NotImplemented' default: $ref: '#/components/responses/DefaultResponse' components: responses: OKResponse: description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' BadRequest: description: Bad Request content: application/json: schema: $ref: '#/components/schemas/ApiResponse' Unauthorized: description: Unauthorized content: application/json: schema: $ref: '#/components/schemas/ApiResponse' Forbidden: description: Forbidden content: application/json: schema: $ref: '#/components/schemas/ApiResponse' NotFound: description: Not Found content: application/json: schema: $ref: '#/components/schemas/ApiResponse' NotImplemented: description: Not Implemented content: application/json: schema: $ref: '#/components/schemas/ApiResponse' Conflict: description: Conflict content: application/json: schema: $ref: '#/components/schemas/ApiResponse' RequestEntityTooLarge: description: Request Entity Too Large, max allowed size exceeded content: application/json: schema: $ref: '#/components/schemas/ApiResponse' InternalServerError: description: Internal Server Error content: application/json: schema: $ref: '#/components/schemas/ApiResponse' DefaultResponse: description: Unexpected Error content: application/json: schema: $ref: '#/components/schemas/ApiResponse' schemas: ApiResponse: type: object properties: message: type: string description: 'message, can be empty' error: type: string description: error description if any FileInfo: type: object properties: name: type: string description: base name of the file size: type: integer format: int64 description: length in bytes for regular files; system-dependent for others mode: type: integer description: | File mode and permission bits. More details here: https://golang.org/pkg/io/fs/#FileMode. 
Let's see some examples: - for a directory mode&2147483648 != 0 - for a symlink mode&134217728 != 0 - for a regular file mode&2401763328 == 0 last_modified: type: string format: date-time StatVFS: type: object properties: bsize: type: integer description: file system block size frsize: type: integer description: fundamental fs block size blocks: type: integer description: number of blocks bfree: type: integer description: free blocks in file system bavail: type: integer description: free blocks for non-root files: type: integer description: total file inodes ffree: type: integer description: free file inodes favail: type: integer description: free file inodes for non-root fsid: type: integer description: file system id flag: type: integer description: bit mask of f_flag values namemax: type: integer description: maximum filename length securitySchemes: BasicAuth: type: http scheme: basic ApiKeyAuth: type: apiKey in: header name: X-API-KEY ================================================ FILE: openapi/openapi.yaml ================================================ openapi: 3.0.3 tags: - name: healthcheck - name: token - name: maintenance - name: admins - name: API keys - name: connections - name: IP Lists - name: defender - name: quota - name: folders - name: groups - name: roles - name: users - name: data retention - name: events - name: metadata - name: user APIs - name: public shares - name: event manager info: title: SFTPGo description: | SFTPGo allows you to securely share your files over SFTP and optionally over HTTP/S, FTP/S and WebDAV as well. Several storage backends are supported and they are configurable per-user, so you can serve a local directory for a user and an S3 bucket (or part of it) for another one. SFTPGo also supports virtual folders, a virtual folder can use any of the supported storage backends. 
So you can have, for example, a user with the S3 backend mapping a Google Cloud Storage bucket (or part of it) on a specified path and an encrypted local filesystem on another one. Virtual folders can be private or shared among multiple users; for shared virtual folders you can define different quota limits for each user. SFTPGo supports groups to simplify the administration of multiple accounts by letting you assign settings once to a group, instead of multiple times to each individual user. The SFTPGo WebClient allows end users to change their credentials, browse and manage their files in the browser and set up two-factor authentication which works with Authy, Google Authenticator and other compatible apps. From the WebClient each authorized user can also create HTTP/S links to externally share files and folders securely, by setting limits on the number of downloads/uploads, protecting the share with a password, limiting access by source IP address and setting an automatic expiration date. version: v2.7.0 contact: name: API support url: 'https://github.com/drakkan/sftpgo' license: name: AGPL-3.0-only url: 'https://www.gnu.org/licenses/agpl-3.0.en.html' servers: - url: /api/v2 security: - BearerAuth: [] - APIKeyAuth: [] paths: /healthz: get: security: [] servers: - url: / tags: - healthcheck summary: health check description: This endpoint can be used to check if the application is running and responding to requests operationId: healthz responses: '200': description: successful operation content: text/plain; charset=utf-8: schema: type: string example: ok /shares/{id}: parameters: - name: id in: path description: the share id required: true schema: type: string get: security: - BasicAuth: [] tags: - public shares summary: Download shared files and folders as a single zip file description: A zip file, containing the shared files and folders, will be generated on the fly and returned as response body. Only folders and regular files will be included in the zip.
The share must be defined with the read scope and the associated user must have list and download permissions operationId: get_share parameters: - in: query name: compress schema: type: boolean default: true required: false responses: '200': description: successful operation content: '*/*': schema: type: string format: binary '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: security: - BasicAuth: [] tags: - public shares summary: Upload one or more files to the shared path description: The share must be defined with the write scope and the associated user must have the upload permission operationId: upload_to_share requestBody: content: multipart/form-data: schema: type: object properties: filenames: type: array items: type: string format: binary minItems: 1 uniqueItems: true required: true responses: '201': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '413': $ref: '#/components/responses/RequestEntityTooLarge' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /shares/{id}/files: parameters: - name: id in: path description: the share id required: true schema: type: string get: security: - BasicAuth: [] tags: - public shares summary: Download a single file description: Returns the file contents as response body. The share must have exactly one path defined and it must be a directory for this to work operationId: download_share_file parameters: - in: query name: path required: true description: Path to the file to download. 
It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt" schema: type: string - in: query name: inline required: false description: 'If set, the response will not have the Content-Disposition header set to `attachment`' schema: type: string responses: '200': description: successful operation content: '*/*': schema: type: string format: binary '206': description: successful operation content: '*/*': schema: type: string format: binary '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /shares/{id}/dirs: parameters: - name: id in: path description: the share id required: true schema: type: string get: security: - BasicAuth: [] tags: - public shares summary: Read directory contents description: Returns the contents of the specified directory for the specified share. The share must have exactly one path defined and it must be a directory for this to work operationId: get_share_dir_contents parameters: - in: query name: path description: Path to the folder to read. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir". If empty or missing the user's start directory is assumed. 
If relative, the user's start directory is used as the base schema: type: string responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/DirEntry' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /shares/{id}/{fileName}: parameters: - name: id in: path description: the share id required: true schema: type: string - name: fileName in: path description: the name of the new file. It must be path encoded. Sub directories are not accepted required: true schema: type: string - name: X-SFTPGO-MTIME in: header schema: type: integer description: File modification time as unix timestamp in milliseconds post: security: - BasicAuth: [] tags: - public shares summary: Upload a single file to the shared path description: The share must be defined with the write scope and the associated user must have the upload/overwrite permissions operationId: upload_single_to_share requestBody: content: application/*: schema: type: string format: binary text/*: schema: type: string format: binary image/*: schema: type: string format: binary audio/*: schema: type: string format: binary video/*: schema: type: string format: binary required: true responses: '201': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '413': $ref: '#/components/responses/RequestEntityTooLarge' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /token: get: security: - BasicAuth: [] tags: - token summary: Get a new admin access token description: Returns an access 
token and its expiration operationId: get_token parameters: - in: header name: X-SFTPGO-OTP schema: type: string required: false description: 'If you have 2FA configured for the admin attempting to log in you need to set the authentication code using this header parameter' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/Token' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /logout: get: security: - BearerAuth: [] tags: - token summary: Invalidate an admin access token description: Allows to invalidate an admin token before its expiration operationId: logout responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/token: get: security: - BasicAuth: [] tags: - token summary: Get a new user access token description: Returns an access token and its expiration operationId: get_user_token parameters: - in: header name: X-SFTPGO-OTP schema: type: string required: false description: 'If you have 2FA configured, for the HTTP protocol, for the user attempting to log in you need to set the authentication code using this header parameter' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/Token' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/logout: get: security: - BearerAuth: [] tags: - token summary: Invalidate a user access 
token description: Allows to invalidate a client token before its expiration operationId: client_logout responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /version: get: tags: - maintenance summary: Get version details description: 'Returns version details such as the version number, build date, commit hash and enabled features' operationId: get_version responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/VersionInfo' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/changepwd: put: security: - BearerAuth: [] tags: - admins summary: Change admin password description: Changes the password for the logged in admin operationId: change_admin_password requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/PwdChange' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/profile: get: security: - BearerAuth: [] tags: - admins summary: Get admin profile description: 'Returns the profile for the logged in admin' operationId: get_admin_profile responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/AdminProfile' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: 
'#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: security: - BearerAuth: [] tags: - admins summary: Update admin profile description: 'Allows to update the profile for the logged in admin' operationId: update_admin_profile requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/AdminProfile' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/2fa/recoverycodes: get: security: - BearerAuth: [] tags: - admins summary: Get recovery codes description: 'Returns the recovery codes for the logged in admin. Recovery codes can be used if the admin loses access to their second factor auth device. Recovery codes are returned unencrypted' operationId: get_admin_recovery_codes responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/RecoveryCode' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: security: - BearerAuth: [] tags: - admins summary: Generate recovery codes description: 'Generates new recovery codes for the logged in admin. 
Generating new recovery codes you automatically invalidate old ones' operationId: generate_admin_recovery_codes responses: '200': description: successful operation content: application/json: schema: type: array items: type: string '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/totp/configs: get: security: - BearerAuth: [] tags: - admins summary: Get available TOTP configuration description: Returns the available TOTP configurations for the logged in admin operationId: get_admin_totp_configs responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/TOTPConfig' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/totp/generate: post: security: - BearerAuth: [] tags: - admins summary: Generate a new TOTP secret description: 'Generates a new TOTP secret, including the QR code as png, using the specified configuration for the logged in admin' operationId: generate_admin_totp_secret requestBody: required: true content: application/json: schema: type: object properties: config_name: type: string description: 'name of the configuration to use to generate the secret' responses: '200': description: successful operation content: application/json: schema: type: object properties: config_name: type: string issuer: type: string secret: type: string url: type: string qr_code: type: string format: byte description: 'QR code png encoded as BASE64' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: 
'#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/totp/validate: post: security: - BearerAuth: [] tags: - admins summary: Validate a one time authentication code description: 'Checks if the given authentication code can be validated using the specified secret and config name' operationId: validate_admin_totp_secret requestBody: required: true content: application/json: schema: type: object properties: config_name: type: string description: 'name of the configuration to use to validate the passcode' passcode: type: string description: 'passcode to validate' secret: type: string description: 'secret to use to validate the passcode' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Passcode successfully validated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admin/totp/save: post: security: - BearerAuth: [] tags: - admins summary: Save a TOTP config description: 'Saves the specified TOTP config for the logged in admin' operationId: save_admin_totp_config requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/AdminTOTPConfig' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: TOTP configuration saved '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /connections: get: tags: - connections summary: Get connections details description: Returns the active users and info 
about their current uploads/downloads operationId: get_connections responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/ConnectionStatus' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/connections/{connectionID}': delete: tags: - connections summary: Close connection description: Terminates an active connection operationId: close_connection parameters: - name: connectionID in: path description: ID of the connection to close required: true schema: type: string responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Connection closed '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /iplists/{type}: parameters: - name: type in: path description: IP list type required: true schema: $ref: '#/components/schemas/IPListType' get: tags: - IP Lists summary: Get IP list entries description: Returns an array with one or more IP list entries operationId: get_ip_list_entries parameters: - in: query name: filter schema: type: string description: restrict results to ipornet matching or starting with this filter - in: query name: from schema: type: string description: ipornet to start from required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Order entries by the ipornet field.
Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/IPListEntry' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - IP Lists summary: Add a new IP list entry description: Add an IP address or a CIDR network to a supported list operationId: add_ip_list_entry requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/IPListEntry' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Entry added '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /iplists/{type}/{ipornet}: parameters: - name: type in: path description: IP list type required: true schema: $ref: '#/components/schemas/IPListType' - name: ipornet in: path required: true schema: type: string get: tags: - IP Lists summary: Find entry by ipornet description: Returns the entry with the given ipornet if it exists. 
operationId: get_ip_list_by_ipornet responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/IPListEntry' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - IP Lists summary: Update IP list entry description: Updates an existing IP list entry operationId: update_ip_list_entry requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/IPListEntry' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Entry updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - IP Lists summary: Delete IP list entry description: Deletes an existing IP list entry operationId: delete_ip_list_entry responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Entry deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /defender/hosts: get: tags: - defender summary: Get hosts description: Returns hosts that are banned or for which some violations have been detected operationId: get_defender_hosts responses: '200': description: 
successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/DefenderEntry' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /defender/hosts/{id}: parameters: - name: id in: path description: host id required: true schema: type: string get: tags: - defender summary: Get host by id description: Returns the host with the given id, if it exists operationId: get_defender_host_by_id responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/DefenderEntry' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - defender summary: Removes a host from the defender lists description: Unbans the specified host or clears its violations operationId: delete_defender_host_by_id responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /retention/users/checks: get: tags: - data retention summary: Get retention checks description: Returns the active retention checks operationId: get_users_retention_checks responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/RetentionCheck' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: 
'#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/users/scans: get: tags: - quota summary: Get active user quota scans description: Returns the active user quota scans operationId: get_users_quota_scans responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/QuotaScan' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/users/{username}/scan: parameters: - name: username in: path description: the username required: true schema: type: string post: tags: - quota summary: Start a user quota scan description: Starts a new quota scan for the given user. A quota scan updates the number of files and their total size for the specified user and the virtual folders, if any, included in their quota operationId: start_user_quota_scan responses: '202': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Scan started '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '409': $ref: '#/components/responses/Conflict' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/users/{username}/usage: parameters: - name: username in: path description: the username required: true schema: type: string - in: query name: mode required: false description: the update mode specifies whether the given quota usage values should be added to or replace the current ones schema: type: string enum: - add - reset description: | Update type: * `add` - add the specified quota limits to
the current used ones * `reset` - reset the values to the specified ones. This is the default example: reset put: tags: - quota summary: Update disk quota usage limits description: Sets the current used quota limits for the given user operationId: user_quota_update_usage requestBody: required: true description: 'If used_quota_size and used_quota_files are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' content: application/json: schema: $ref: '#/components/schemas/QuotaUsage' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Quota updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '409': $ref: '#/components/responses/Conflict' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/users/{username}/transfer-usage: parameters: - name: username in: path description: the username required: true schema: type: string - in: query name: mode required: false description: the update mode specifies if the given quota usage values should be added or replace the current ones schema: type: string enum: - add - reset description: | Update type: * `add` - add the specified quota limits to the current used ones * `reset` - reset the values to the specified ones. 
This is the default example: reset put: tags: - quota summary: Update transfer quota usage limits description: Sets the current used transfer quota limits for the given user operationId: user_transfer_quota_update_usage requestBody: required: true description: 'If used_upload_data_transfer and used_download_data_transfer are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' content: application/json: schema: $ref: '#/components/schemas/TransferQuotaUsage' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Quota updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '409': $ref: '#/components/responses/Conflict' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/folders/scans: get: tags: - quota summary: Get active folder quota scans description: Returns the active folder quota scans operationId: get_folders_quota_scans responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/FolderQuotaScan' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/folders/{name}/scan: parameters: - name: name in: path description: folder name required: true schema: type: string post: tags: - quota summary: Start a folder quota scan description: Starts a new quota scan for the given folder. 
A quota scan updates the number of files and their total size for the specified folder operationId: start_folder_quota_scan responses: '202': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Scan started '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '409': $ref: '#/components/responses/Conflict' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /quotas/folders/{name}/usage: parameters: - name: name in: path description: folder name required: true schema: type: string - in: query name: mode required: false description: the update mode specifies if the given quota usage values should be added or replace the current ones schema: type: string enum: - add - reset description: | Update type: * `add` - add the specified quota limits to the current used ones * `reset` - reset the values to the specified ones. This is the default example: reset put: tags: - quota summary: Update folder quota usage limits description: Sets the current used quota limits for the given folder operationId: folder_quota_update_usage parameters: - in: query name: mode required: false description: the update mode specifies if the given quota usage values should be added or replace the current ones schema: type: string enum: - add - reset description: | Update type: * `add` - add the specified quota limits to the current used ones * `reset` - reset the values to the specified ones.
This is the default example: reset requestBody: required: true description: 'If used_quota_size and used_quota_files are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' content: application/json: schema: $ref: '#/components/schemas/QuotaUsage' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Quota updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '409': $ref: '#/components/responses/Conflict' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /folders: get: tags: - folders summary: Get folders description: Returns an array with one or more folders operationId: get_folders parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering folders by name. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/BaseVirtualFolder' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - folders summary: Add folder operationId: add_folder description: Adds a new folder. 
A quota scan is required to update the used files/size parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/BaseVirtualFolder' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/BaseVirtualFolder' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/folders/{name}': parameters: - name: name in: path description: folder name required: true schema: type: string get: tags: - folders summary: Find folders by name description: Returns the folder with the given name if it exists. operationId: get_folder_by_name parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/BaseVirtualFolder' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - folders summary: Update folder description: Updates an existing folder operationId: update_folder requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/BaseVirtualFolder' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Folder updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - folders summary: Delete folder description: Deletes an existing folder operationId: delete_folder responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Folder deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /groups: get: tags: - groups summary: Get groups description: Returns an array with one or more groups operationId: get_groups parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1
maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering groups by name. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/Group' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - groups summary: Add group operationId: add_group description: Adds a new group parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Group' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/Group' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/groups/{name}': parameters: - name: name in: path description: group name required: true schema: type: string get: tags: - groups summary: Find groups by name description: Returns the group with the given name if it exists. 
operationId: get_group_by_name parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/Group' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - groups summary: Update group description: Updates an existing group operationId: update_group requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Group' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Group updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - groups summary: Delete group description: Deletes an existing group operationId: delete_group responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Group deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: 
'#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /roles: get: tags: - roles summary: Get roles description: Returns an array with one or more roles operationId: get_roles parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering roles by name. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/Role' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - roles summary: Add role operationId: add_role description: Adds a new role requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Role' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/Role' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/roles/{name}': parameters: - name: name in: path description: role name required: true schema: type: string get: tags: - roles summary: Find roles by name description: Returns the role with the given name if it exists.
operationId: get_role_by_name responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/Role' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - roles summary: Update role description: Updates an existing role operationId: update_role requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Role' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Role updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - roles summary: Delete role description: Deletes an existing role operationId: delete_role responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Role deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /eventactions: get: tags: - event manager summary: Get event actions description: Returns an array with one or more event actions operationId: get_event_actons parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type:
integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering actions by name. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/BaseEventAction' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - event manager summary: Add event action operationId: add_event_action description: Adds a new event action parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.'
requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/BaseEventAction' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/BaseEventAction' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/eventactions/{name}': parameters: - name: name in: path description: action name required: true schema: type: string get: tags: - event manager summary: Find event actions by name description: Returns the event action with the given name if it exists. operationId: get_event_action_by_name parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/BaseEventAction' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - event manager summary: Update event action description: Updates an existing event action operationId: update_event_action requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/BaseEventAction' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Event action updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - event manager summary: Delete event action description: Deletes an existing event action operationId: delete_event_action responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Event action deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /eventrules: get: tags: - event manager summary: Get event rules description: Returns an array with one or more event rules operationId: get_event_rules parameters: - in: query name: offset schema: type: integer minimum: 0 
default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering rules by name. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/EventRule' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - event manager summary: Add event rule operationId: add_event_rule description: Adds a new event rule parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/EventRuleMinimal' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/EventRule' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/eventrules/{name}': parameters: - name: name in: path description: rule name required: true schema: type: string get: tags: - event manager summary: Find event rules by name description: Returns the event rule with the given name if it exists. operationId: get_event_rile_by_name parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/EventRule' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - event manager summary: Update event rule description: Updates an existing event rule operationId: update_event_rule requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/EventRuleMinimal' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Event rules updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - event manager summary: Delete event rule description: Deletes an existing event rule operationId: delete_event_rule responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Event rules deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/eventrules/run/{name}': parameters: - name: name in: path description: on-demand rule name required: true schema: type: string post: tags: - event manager summary: Run an on-demand event rule description: The rule's actions will run in 
background. SFTPGo will not monitor any concurrency and such. If you want to be notified at the end of the execution please add an appropriate action operationId: run_event_rule responses: '202': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Event rule started '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /events/fs: get: tags: - events summary: Get filesystem events description: 'Returns an array with one or more filesystem events applying the specified filters. This API is only available if you configure an "eventsearcher" plugin' operationId: get_fs_events parameters: - in: query name: start_timestamp schema: type: integer format: int64 minimum: 0 default: 0 required: false description: 'the event timestamp, unix timestamp in nanoseconds, must be greater than or equal to the specified one. 0 or missing means omit this filter' - in: query name: end_timestamp schema: type: integer format: int64 minimum: 0 default: 0 required: false description: 'the event timestamp, unix timestamp in nanoseconds, must be less than or equal to the specified one. 0 or missing means omit this filter' - in: query name: actions schema: type: array items: $ref: '#/components/schemas/FsEventAction' description: 'the event action must be included among those specified. Empty or missing means omit this filter. Actions must be specified comma separated' explode: false required: false - in: query name: username schema: type: string description: 'the event username must be the same as the one specified. 
Empty or missing means omit this filter' required: false - in: query name: ip schema: type: string description: 'the event IP must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: ssh_cmd schema: type: string description: 'the event SSH command must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: fs_provider schema: $ref: '#/components/schemas/FsProviders' description: 'the event filesystem provider must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: bucket schema: type: string description: 'the bucket must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: endpoint schema: type: string description: 'the endpoint must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: protocols schema: type: array items: $ref: '#/components/schemas/EventProtocols' description: 'the event protocol must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: statuses schema: type: array items: $ref: '#/components/schemas/FsEventStatus' description: 'the event status must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: instance_ids schema: type: array items: type: string description: 'the event instance id must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: from_id schema: type: string description: 'the event id to start from. This is useful for cursor based pagination. Empty or missing means omit this filter.' 
required: false - in: query name: role schema: type: string description: 'User role. Empty or missing means omit this filter. Ignored if the admin has a role' required: false - in: query name: csv_export schema: type: boolean default: false required: false description: 'If enabled, events are exported as a CSV file' - in: query name: limit schema: type: integer minimum: 1 maximum: 1000 default: 100 required: false description: 'The maximum number of items to return. Max value is 1000, default is 100' - in: query name: order required: false description: Ordering events by timestamp. Default DESC schema: type: string enum: - ASC - DESC example: DESC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/FsEvent' text/csv: schema: type: string '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /events/provider: get: tags: - events summary: Get provider events description: 'Returns an array with one or more provider events applying the specified filters. This API is only available if you configure an "eventsearcher" plugin' operationId: get_provider_events parameters: - in: query name: start_timestamp schema: type: integer format: int64 minimum: 0 default: 0 required: false description: 'the event timestamp, unix timestamp in nanoseconds, must be greater than or equal to the specified one. 0 or missing means omit this filter' - in: query name: end_timestamp schema: type: integer format: int64 minimum: 0 default: 0 required: false description: 'the event timestamp, unix timestamp in nanoseconds, must be less than or equal to the specified one. 
0 or missing means omit this filter' - in: query name: actions schema: type: array items: $ref: '#/components/schemas/ProviderEventAction' description: 'the event action must be included among those specified. Empty or missing means omit this filter. Actions must be specified comma separated' explode: false required: false - in: query name: username schema: type: string description: 'the event username must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: ip schema: type: string description: 'the event IP must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: object_name schema: type: string description: 'the event object name must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: object_types schema: type: array items: $ref: '#/components/schemas/ProviderEventObjectType' description: 'the event object type must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: instance_ids schema: type: array items: type: string description: 'the event instance id must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: from_id schema: type: string description: 'the event id to start from. This is useful for cursor based pagination. Empty or missing means omit this filter.' required: false - in: query name: role schema: type: string description: 'Admin role. Empty or missing means omit this filter. 
Ignored if the admin has a role' required: false - in: query name: csv_export schema: type: boolean default: false required: false description: 'If enabled, events are exported as a CSV file' - in: query name: omit_object_data schema: type: boolean default: false required: false description: 'If enabled, returned events will not contain the `object_data` field' - in: query name: limit schema: type: integer minimum: 1 maximum: 1000 default: 100 required: false description: 'The maximum number of items to return. Max value is 1000, default is 100' - in: query name: order required: false description: Ordering events by timestamp. Default DESC schema: type: string enum: - ASC - DESC example: DESC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/ProviderEvent' text/csv: schema: type: string '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /events/logs: get: tags: - events summary: Get log events description: 'Returns an array with one or more log events applying the specified filters. This API is only available if you configure an "eventsearcher" plugin' operationId: get_log_events parameters: - in: query name: start_timestamp schema: type: integer format: int64 minimum: 0 default: 0 required: false description: 'the event timestamp, unix timestamp in nanoseconds, must be greater than or equal to the specified one. 0 or missing means omit this filter' - in: query name: end_timestamp schema: type: integer format: int64 minimum: 0 default: 0 required: false description: 'the event timestamp, unix timestamp in nanoseconds, must be less than or equal to the specified one. 
0 or missing means omit this filter' - in: query name: events schema: type: array items: $ref: '#/components/schemas/LogEventType' description: 'the log events must be included among those specified. Empty or missing means omit this filter. Events must be specified comma separated' explode: false required: false - in: query name: username schema: type: string description: 'the event username must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: ip schema: type: string description: 'the event IP must be the same as the one specified. Empty or missing means omit this filter' required: false - in: query name: protocols schema: type: array items: $ref: '#/components/schemas/EventProtocols' description: 'the event protocol must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: instance_ids schema: type: array items: type: string description: 'the event instance id must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' explode: false required: false - in: query name: from_id schema: type: string description: 'the event id to start from. This is useful for cursor based pagination. Empty or missing means omit this filter.' required: false - in: query name: role schema: type: string description: 'User role. Empty or missing means omit this filter. Ignored if the admin has a role' required: false - in: query name: csv_export schema: type: boolean default: false required: false description: 'If enabled, events are exported as a CSV file' - in: query name: limit schema: type: integer minimum: 1 maximum: 1000 default: 100 required: false description: 'The maximum number of items to return. Max value is 1000, default is 100' - in: query name: order required: false description: Ordering events by timestamp. 
Default DESC schema: type: string enum: - ASC - DESC example: DESC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/LogEvent' text/csv: schema: type: string '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /apikeys: get: security: - BearerAuth: [] tags: - API keys summary: Get API keys description: Returns an array with one or more API keys. For security reasons hashed keys are omitted in the response operationId: get_api_keys parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering API keys by id. 
Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/APIKey' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: security: - BearerAuth: [] tags: - API keys summary: Add API key description: Adds a new API key operationId: add_api_key requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/APIKey' responses: '201': description: successful operation headers: X-Object-ID: schema: type: string description: ID for the newly created API key Location: schema: type: string description: URI to retrieve the details for the newly created API key content: application/json: schema: type: object properties: message: type: string example: 'API key created. This is the only time the API key is visible, please save it.' key: type: string description: 'generated API key' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/apikeys/{id}': parameters: - name: id in: path description: the key id required: true schema: type: string get: security: - BearerAuth: [] tags: - API keys summary: Find API key by id description: Returns the API key with the given id, if it exists.
For security reasons the hashed key is omitted in the response operationId: get_api_key_by_id responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/APIKey' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: security: - BearerAuth: [] tags: - API keys summary: Update API key description: Updates an existing API key. You cannot update the key itself, the creation date and the last use operationId: update_api_key requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/APIKey' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: API key updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: security: - BearerAuth: [] tags: - API keys summary: Delete API key description: Deletes an existing API key operationId: delete_api_key responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: API key deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /admins: get: tags: - admins summary: Get admins description: Returns an
array with one or more admins. For security reasons hashed passwords are omitted in the response operationId: get_admins parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering admins by username. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/Admin' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - admins summary: Add admin description: 'Adds a new admin. Recovery codes and TOTP configuration cannot be set using this API: each admin must use the specific APIs' operationId: add_admin requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Admin' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/Admin' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/admins/{username}': parameters: - name: username in: path description: the admin username required: true schema: type: string get: tags: - admins summary: Find admins by username description: 'Returns the admin with the given username, if it exists. 
For security reasons the hashed password is omitted in the response' operationId: get_admin_by_username responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/Admin' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - admins summary: Update admin description: 'Updates an existing admin. Recovery codes and TOTP configuration cannot be set/updated using this API: each admin must use the specific APIs. You are not allowed to update the admin impersonated using an API key' operationId: update_admin requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Admin' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Admin updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - admins summary: Delete admin description: Deletes an existing admin operationId: delete_admin responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Admin deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' 
'/admins/{username}/2fa/disable': parameters: - name: username in: path description: the admin username required: true schema: type: string put: tags: - admins summary: Disable second factor authentication description: 'Disables second factor authentication for the given admin. This API must be used if the admin loses access to their second factor auth device and has no recovery codes' operationId: disable_admin_2fa responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: 2FA disabled '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/admins/{username}/forgot-password': parameters: - name: username in: path description: the admin username required: true schema: type: string post: security: [] tags: - admins summary: Send a password reset code by email description: 'You must set up an SMTP server and the account must have a valid email address, in which case SFTPGo will send a code via email to reset the password. 
If the specified admin does not exist, the request will be silently ignored (a success response will be returned) to avoid disclosing existing admins' operationId: admin_forgot_password responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/admins/{username}/reset-password': parameters: - name: username in: path description: the admin username required: true schema: type: string post: security: [] tags: - admins summary: Reset the password description: 'Set a new password using the code received via email' operationId: admin_reset_password requestBody: content: application/json: schema: type: object properties: code: type: string password: type: string required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /users: get: tags: - users summary: Get users description: Returns an array with one or more users. For security reasons hashed passwords are omitted in the response operationId: get_users parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. 
Max value is 500, default is 100' - in: query name: order required: false description: Ordering users by username. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/User' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - users summary: Add user description: 'Adds a new user. Recovery codes and TOTP configuration cannot be set using this API: each user must use the specific APIs' operationId: add_user parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the hash of the password and the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/User' responses: '201': description: successful operation headers: Location: schema: type: string description: 'URI of the newly created object' content: application/json: schema: $ref: '#/components/schemas/User' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/users/{username}': parameters: - name: username in: path description: the username required: true schema: type: string get: tags: - users summary: Find users by username description: Returns the user with the given username if it exists.
For security reasons the hashed password is omitted in the response operationId: get_user_by_username parameters: - in: query name: confidential_data schema: type: integer description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the hash of the password and the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/User' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - users summary: Update user description: 'Updates an existing user and optionally disconnects it, if connected, to apply the new settings. The current password will be preserved if the password field is omitted in the request body. Recovery codes and TOTP configuration cannot be set/updated using this API: each user must use the specific APIs' operationId: update_user parameters: - in: query name: disconnect schema: type: integer enum: - 0 - 1 description: | Disconnect: * `0` The user will not be disconnected and it will continue to use the old configuration until connected. This is the default * `1` The user will be disconnected after a successful update. 
It must log in again and so it will be forced to use the new configuration requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/User' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: User updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - users summary: Delete user description: Deletes an existing user operationId: delete_user responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: User deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/users/{username}/2fa/disable': parameters: - name: username in: path description: the username required: true schema: type: string put: tags: - users summary: Disable second factor authentication description: 'Disables second factor authentication for the given user.
This API must be used if the user loses access to their second factor auth device and has no recovery codes' operationId: disable_user_2fa responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: 2FA disabled '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/users/{username}/forgot-password': parameters: - name: username in: path description: the username required: true schema: type: string post: security: [] tags: - users summary: Send a password reset code by email description: 'You must configure an SMTP server, the account must have a valid email address and must not have the "reset-password-disabled" restriction, in which case SFTPGo will send a code via email to reset the password. 
If the specified user does not exist, the request will be silently ignored (a success response will be returned) to avoid disclosing existing users' operationId: user_forgot_password responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/users/{username}/reset-password': parameters: - name: username in: path description: the username required: true schema: type: string post: security: [] tags: - users summary: Reset the password description: 'Set a new password using the code received via email' operationId: user_reset_password requestBody: content: application/json: schema: type: object properties: code: type: string password: type: string required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /status: get: tags: - maintenance summary: Get status description: Retrieves the status of the active services operationId: get_status responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ServicesStatus' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: 
'#/components/responses/DefaultResponse' /dumpdata: get: tags: - maintenance summary: Dump data description: 'Backs up data as data-provider-independent JSON. The backup can be saved in a local file on the server, to avoid exposing sensitive data over the network, or returned as response body. The output of dumpdata can be used as input for loaddata' operationId: dumpdata parameters: - in: query name: output-file schema: type: string description: Path for the file to write the JSON serialized data to. This path is relative to the configured "backups_path". If this file already exists it will be overwritten. To return the backup as response body set `output-data` to 1 instead. - in: query name: output-data schema: type: integer enum: - 0 - 1 description: | output data: * `0` or any other value != 1, the backup will be saved to a file on the server, `output-file` is required * `1` the backup will be returned as response body - in: query name: indent schema: type: integer enum: - 0 - 1 description: | indent: * `0` no indentation. This is the default * `1` format the output JSON - in: query name: scopes schema: type: array items: $ref: '#/components/schemas/DumpDataScopes' description: 'You can limit the dump contents to the specified scopes. Empty or missing means any supported scope.
Scopes must be specified comma separated' explode: false required: false responses: '200': description: successful operation content: application/json: schema: oneOf: - $ref: '#/components/schemas/ApiResponse' - $ref: '#/components/schemas/BackupData' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /loaddata: parameters: - in: query name: scan-quota schema: type: integer enum: - 0 - 1 - 2 description: | Quota scan: * `0` no quota scan is done, the imported users/folders will have used_quota_size and used_quota_files = 0 or the existing values if they already exist. This is the default * `1` scan quota * `2` scan quota if the user has quota restrictions required: false - in: query name: mode schema: type: integer enum: - 0 - 1 - 2 description: | Mode: * `0` New objects are added, existing ones are updated. This is the default * `1` New objects are added, existing ones are not modified * `2` New objects are added, existing ones are updated and connected users are disconnected and so forced to use the new configuration get: tags: - maintenance summary: Load data from path description: 'Restores SFTPGo data from a JSON backup file on the server. Objects will be restored one by one and the restore is stopped if an object cannot be added or updated, so a partial restore could happen' operationId: loaddata_from_file parameters: - in: query name: input-file schema: type: string required: true description: Path for the file to read the JSON serialized data from. This can be an absolute path or a path relative to the configured "backups_path".
The max allowed file size is 10MB responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Data restored '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - maintenance summary: Load data description: 'Restores SFTPGo data from a JSON backup. Objects will be restored one by one and the restore is stopped if an object cannot be added or updated, so a partial restore could happen' operationId: loaddata_from_request_body requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/BackupData' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Data restored '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/changepwd: put: security: - BearerAuth: [] tags: - user APIs summary: Change user password description: Changes the password for the logged in user operationId: change_user_password requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/PwdChange' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/profile: get: security: - BearerAuth: [] tags: - user APIs summary: Get user profile description:
'Returns the profile for the logged in user' operationId: get_user_profile responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/UserProfile' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: security: - BearerAuth: [] tags: - user APIs summary: Update user profile description: 'Updates the profile for the logged in user' operationId: update_user_profile requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/UserProfile' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/2fa/recoverycodes: get: security: - BearerAuth: [] tags: - user APIs summary: Get recovery codes description: 'Returns the recovery codes for the logged in user. Recovery codes can be used if the user loses access to their second factor auth device. Recovery codes are returned unencrypted' operationId: get_user_recovery_codes responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/RecoveryCode' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: security: - BearerAuth: [] tags: - user APIs summary: Generate recovery codes description: 'Generates new recovery codes for the logged in user.
Generating new recovery codes automatically invalidates the old ones' operationId: generate_user_recovery_codes responses: '200': description: successful operation content: application/json: schema: type: array items: type: string '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/totp/configs: get: security: - BearerAuth: [] tags: - user APIs summary: Get available TOTP configurations description: Returns the available TOTP configurations for the logged in user operationId: get_user_totp_configs responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/TOTPConfig' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/totp/generate: post: security: - BearerAuth: [] tags: - user APIs summary: Generate a new TOTP secret description: 'Generates a new TOTP secret, including the QR code as png, using the specified configuration for the logged in user' operationId: generate_user_totp_secret requestBody: required: true content: application/json: schema: type: object properties: config_name: type: string description: 'name of the configuration to use to generate the secret' responses: '200': description: successful operation content: application/json: schema: type: object properties: config_name: type: string issuer: type: string secret: type: string qr_code: type: string format: byte description: 'QR code png encoded as BASE64' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref:
'#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/totp/validate: post: security: - BearerAuth: [] tags: - user APIs summary: Validate a one time authentication code description: 'Checks if the given authentication code can be validated using the specified secret and config name' operationId: validate_user_totp_secret requestBody: required: true content: application/json: schema: type: object properties: config_name: type: string description: 'name of the configuration to use to validate the passcode' passcode: type: string description: 'passcode to validate' secret: type: string description: 'secret to use to validate the passcode' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Passcode successfully validated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/totp/save: post: security: - BearerAuth: [] tags: - user APIs summary: Save a TOTP config description: 'Saves the specified TOTP config for the logged in user' operationId: save_user_totp_config requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/UserTOTPConfig' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: TOTP configuration saved '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/shares: get: tags: - user APIs summary: List user shares description: Returns the shares for the logged in user
operationId: get_user_shares parameters: - in: query name: offset schema: type: integer minimum: 0 default: 0 required: false - in: query name: limit schema: type: integer minimum: 1 maximum: 500 default: 100 required: false description: 'The maximum number of items to return. Max value is 500, default is 100' - in: query name: order required: false description: Ordering shares by ID. Default ASC schema: type: string enum: - ASC - DESC example: ASC responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/Share' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - user APIs summary: Add a share operationId: add_share description: 'Adds a new share. The share id will be auto-generated' requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Share' responses: '201': description: successful operation headers: X-Object-ID: schema: type: string description: ID for the newly created share Location: schema: type: string description: URI to retrieve the details for the newly created share content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' '/user/shares/{id}': parameters: - name: id in: path description: the share id required: true schema: type: string get: tags: - user APIs summary: Get share by id description: Returns a share by id for the logged in user operationId: get_user_share_by_id responses: '200': description: successful operation content: application/json:
schema: $ref: '#/components/schemas/Share' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' put: tags: - user APIs summary: Update share description: 'Updates an existing share belonging to the logged in user' operationId: update_user_share requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/Share' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Share updated '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - user APIs summary: Delete share description: 'Deletes an existing share belonging to the logged in user' operationId: delete_user_share responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' example: message: Share deleted '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '404': $ref: '#/components/responses/NotFound' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/file-actions/copy: parameters: - in: query name: path description: Path to the file/folder to copy. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true - in: query name: target description: New name. 
It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true post: tags: - user APIs summary: 'Copy a file or a directory' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/file-actions/move: parameters: - in: query name: path description: Path to the file/folder to rename. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true - in: query name: target description: New name. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true post: tags: - user APIs summary: 'Move (rename) a file or a directory' responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/dirs: get: tags: - user APIs summary: Read directory contents description: Returns the contents of the specified directory for the logged in user operationId: get_user_dir_contents parameters: - in: query name: path description: Path to the folder to read. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir". If empty or missing the user's start directory is assumed. 
If relative, the user's start directory is used as the base schema: type: string responses: '200': description: successful operation content: application/json: schema: type: array items: $ref: '#/components/schemas/DirEntry' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - user APIs summary: Create a directory description: Create a directory for the logged in user operationId: create_user_dir parameters: - in: query name: path description: Path to the folder to create. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true - in: query name: mkdir_parents description: Create parent directories if they do not exist? schema: type: boolean required: false responses: '201': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' patch: tags: - user APIs deprecated: true summary: 'Rename a directory. Deprecated, use "file-actions/move"' description: Rename a directory for the logged in user. The rename is allowed for empty directory or for non empty local directories, with no virtual folders inside operationId: rename_user_dir parameters: - in: query name: path description: Path to the folder to rename. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true - in: query name: target description: New name. 
It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - user APIs summary: Delete a directory description: Delete a directory and any children it contains for the logged in user operationId: delete_user_dir parameters: - in: query name: path description: Path to the folder to delete. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" schema: type: string required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/files: get: tags: - user APIs summary: Download a single file description: Returns the file contents as response body operationId: download_user_file parameters: - in: query name: path required: true description: Path to the file to download. 
It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt" schema: type: string - in: query name: inline required: false description: 'If set, the response will not have the Content-Disposition header set to `attachment`' schema: type: string responses: '200': description: successful operation content: '*/*': schema: type: string format: binary '206': description: successful operation content: '*/*': schema: type: string format: binary '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' post: tags: - user APIs summary: Upload files description: Upload one or more files for the logged in user operationId: create_user_files parameters: - in: query name: path description: Parent directory for the uploaded files. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir". If empty or missing the root path is assumed. If a file with the same name already exists, it will be overwritten schema: type: string - in: query name: mkdir_parents description: Create parent directories if they do not exist? 
schema: type: boolean required: false requestBody: content: multipart/form-data: schema: type: object properties: filenames: type: array items: type: string format: binary minItems: 1 uniqueItems: true required: true responses: '201': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '413': $ref: '#/components/responses/RequestEntityTooLarge' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' patch: tags: - user APIs deprecated: true summary: Rename a file description: 'Rename a file for the logged in user. Deprecated, use "file-actions/move"' operationId: rename_user_file parameters: - in: query name: path description: Path to the file to rename. It must be URL encoded schema: type: string required: true - in: query name: target description: New name. It must be URL encoded schema: type: string required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' delete: tags: - user APIs summary: Delete a file description: Delete a file for the logged in user. operationId: delete_user_file parameters: - in: query name: path description: Path to the file to delete. 
It must be URL encoded schema: type: string required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/files/upload: post: tags: - user APIs summary: Upload a single file description: 'Upload a single file for the logged in user to an existing directory. This API does not use multipart/form-data and so no temporary files are created server side but only a single file can be uploaded as POST body' operationId: create_user_file parameters: - in: query name: path description: Full file path. It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt". The parent directory must exist unless mkdir_parents is true. If a file with the same name already exists, it will be overwritten schema: type: string required: true - in: query name: mkdir_parents description: Create parent directories if they do not exist?
schema: type: boolean required: false - in: header name: X-SFTPGO-MTIME schema: type: integer description: File modification time as unix timestamp in milliseconds requestBody: content: application/*: schema: type: string format: binary text/*: schema: type: string format: binary image/*: schema: type: string format: binary audio/*: schema: type: string format: binary video/*: schema: type: string format: binary required: true responses: '201': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '413': $ref: '#/components/responses/RequestEntityTooLarge' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/files/metadata: patch: tags: - user APIs summary: Set metadata for a file/directory description: 'Set supported metadata attributes for the specified file or directory' operationId: setprops_user_file parameters: - in: query name: path description: Full file/directory path. 
It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt" schema: type: string required: true requestBody: content: application/json: schema: type: object properties: modification_time: type: integer description: File modification time as unix timestamp in milliseconds required: true responses: '200': description: successful operation content: application/json: schema: $ref: '#/components/schemas/ApiResponse' '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '413': $ref: '#/components/responses/RequestEntityTooLarge' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' /user/streamzip: post: tags: - user APIs summary: Download multiple files and folders as a single zip file description: A zip file, containing the specified files and folders, will be generated on the fly and returned as response body. 
Only folders and regular files will be included in the zip operationId: streamzip requestBody: required: true content: application/json: schema: type: array items: type: string description: Absolute file or folder path responses: '200': description: successful operation content: 'application/zip': schema: type: string format: binary '400': $ref: '#/components/responses/BadRequest' '401': $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' '500': $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' components: responses: BadRequest: description: Bad Request content: application/json: schema: $ref: '#/components/schemas/ApiResponse' Unauthorized: description: Unauthorized content: application/json: schema: $ref: '#/components/schemas/ApiResponse' Forbidden: description: Forbidden content: application/json: schema: $ref: '#/components/schemas/ApiResponse' NotFound: description: Not Found content: application/json: schema: $ref: '#/components/schemas/ApiResponse' Conflict: description: Conflict content: application/json: schema: $ref: '#/components/schemas/ApiResponse' RequestEntityTooLarge: description: Request Entity Too Large, max allowed size exceeded content: application/json: schema: $ref: '#/components/schemas/ApiResponse' InternalServerError: description: Internal Server Error content: application/json: schema: $ref: '#/components/schemas/ApiResponse' DefaultResponse: description: Unexpected Error content: application/json: schema: $ref: '#/components/schemas/ApiResponse' schemas: Permission: type: string enum: - '*' - list - download - upload - overwrite - delete - delete_files - delete_dirs - rename - rename_files - rename_dirs - create_dirs - create_symlinks - chmod - chown - chtimes - copy description: | Permissions: * `*` - all permissions are granted * `list` - list items is allowed * `download` - download files is allowed * `upload` - upload files is allowed * 
`overwrite` - overwrite an existing file, while uploading, is allowed. The upload permission is required to allow file overwrite * `delete` - delete files or directories is allowed * `delete_files` - delete files is allowed * `delete_dirs` - delete directories is allowed * `rename` - rename files or directories is allowed * `rename_files` - rename files is allowed * `rename_dirs` - rename directories is allowed * `create_dirs` - create directories is allowed * `create_symlinks` - create links is allowed * `chmod` - changing file or directory permissions is allowed * `chown` - changing file or directory owner and group is allowed * `chtimes` - changing file or directory access and modification time is allowed * `copy` - copying files or directories is allowed AdminPermissions: type: string enum: - '*' - add_users - edit_users - del_users - view_users - view_conns - close_conns - view_status - manage_folders - manage_groups - quota_scans - manage_defender - view_defender - view_events - disable_mfa description: | Admin permissions: * `*` - super admin permissions are granted * `add_users` - add new users is allowed * `edit_users` - change existing users is allowed * `del_users` - remove users is allowed * `view_users` - list users is allowed * `view_conns` - list active connections is allowed * `close_conns` - close active connections is allowed * `view_status` - view the server status is allowed * `manage_folders` - manage folders is allowed * `manage_groups` - manage groups is allowed * `quota_scans` - view and start quota scans is allowed * `manage_defender` - remove IP from the dynamic blocklist is allowed * `view_defender` - list the dynamic blocklist is allowed * `view_events` - view and search filesystem and provider events is allowed * `disable_mfa` - disabling two-factor authentication for users and admins is allowed FsProviders: type: integer enum: - 0 - 1 - 2 - 3 - 4 - 5 - 6 description: | Filesystem providers: * `0` - Local filesystem * `1` - S3 Compatible Object
Storage * `2` - Google Cloud Storage * `3` - Azure Blob Storage * `4` - Local filesystem encrypted * `5` - SFTP * `6` - HTTP filesystem EventActionTypes: type: integer enum: - 1 - 2 - 3 - 4 - 5 - 6 - 7 - 8 - 9 - 11 - 12 - 13 - 14 - 15 description: | Supported event action types: * `1` - HTTP * `2` - Command * `3` - Email * `4` - Backup * `5` - User quota reset * `6` - Folder quota reset * `7` - Transfer quota reset * `8` - Data retention check * `9` - Filesystem * `11` - Password expiration check * `12` - User expiration check * `13` - Identity Provider account check * `14` - User inactivity check * `15` - Rotate log file FilesystemActionTypes: type: integer enum: - 1 - 2 - 3 - 4 - 5 - 6 description: | Supported filesystem action types: * `1` - Rename * `2` - Delete * `3` - Mkdir * `4` - Exist * `5` - Compress * `6` - Copy EventTriggerTypes: type: integer enum: - 1 - 2 - 3 - 4 - 5 - 6 - 7 description: | Supported event trigger types: * `1` - Filesystem event * `2` - Provider event * `3` - Schedule * `4` - IP blocked * `5` - Certificate renewal * `6` - On demand, like schedule but executed on demand * `7` - Identity provider login LoginMethods: type: string enum: - publickey - password - password-over-SSH - keyboard-interactive - publickey+password - publickey+keyboard-interactive - TLSCertificate - TLSCertificate+password description: | Available login methods.
To enable multi-step authentication you have to allow only multi-step login methods * `publickey` * `password`, password for all the supported protocols * `password-over-SSH`, password over SSH protocol (SSH/SFTP/SCP) * `keyboard-interactive` * `publickey+password` - multi-step auth: public key and password * `publickey+keyboard-interactive` - multi-step auth: public key and keyboard interactive * `TLSCertificate` * `TLSCertificate+password` - multi-step auth: TLS client certificate and password SupportedProtocols: type: string enum: - SSH - FTP - DAV - HTTP description: | Protocols: * `SSH` - includes both SFTP and SSH commands * `FTP` - plain FTP and FTPES/FTPS * `DAV` - WebDAV over HTTP/HTTPS * `HTTP` - WebClient/REST API MFAProtocols: type: string enum: - SSH - FTP - HTTP description: | Protocols: * `SSH` - includes both SFTP and SSH commands * `FTP` - plain FTP and FTPES/FTPS * `HTTP` - WebClient/REST API EventProtocols: type: string enum: - SSH - SFTP - SCP - FTP - DAV - HTTP - HTTPShare - DataRetention - EventAction - OIDC description: | Protocols: * `SSH` - SSH commands * `SFTP` - SFTP protocol * `SCP` - SCP protocol * `FTP` - plain FTP and FTPES/FTPS * `DAV` - WebDAV * `HTTP` - WebClient/REST API * `HTTPShare` - the event is generated in a public share * `DataRetention` - the event is generated by a data retention check * `EventAction` - the event is generated by an EventManager action * `OIDC` - OpenID Connect WebClientOptions: type: string enum: - publickey-change-disabled - tls-cert-change-disabled - write-disabled - mfa-disabled - password-change-disabled - api-key-auth-change-disabled - info-change-disabled - shares-disabled - password-reset-disabled - shares-without-password-disabled description: | Options: * `publickey-change-disabled` - changing SSH public keys is not allowed * `tls-cert-change-disabled` - changing TLS certificates is not allowed * `write-disabled` - upload, rename, delete are not allowed even if the user has permissions for these 
actions * `mfa-disabled` - enabling multi-factor authentication is not allowed. This option cannot be set if the user already has MFA enabled * `password-change-disabled` - changing password is not allowed * `api-key-auth-change-disabled` - enabling/disabling API key authentication is not allowed * `info-change-disabled` - changing info such as email and description is not allowed * `shares-disabled` - sharing files and directories with external users is not allowed * `password-reset-disabled` - resetting the password is not allowed * `shares-without-password-disabled` - creating shares without password protection is not allowed APIKeyScope: type: integer enum: - 1 - 2 description: | Options: * `1` - admin scope. The API key will be used to impersonate an SFTPGo admin * `2` - user scope. The API key will be used to impersonate an SFTPGo user ShareScope: type: integer enum: - 1 - 2 description: | Options: * `1` - read scope * `2` - write scope TOTPHMacAlgo: type: string enum: - sha1 - sha256 - sha512 description: 'Supported HMAC algorithms for Time-based one time passwords' UserType: type: string enum: - '' - LDAPUser - OSUser description: This is a hint for authentication plugins.
It is ignored when using SFTPGo internal authentication DumpDataScopes: type: string enum: - users - folders - groups - admins - api_keys - shares - actions - rules - roles - ip_lists - configs LogEventType: type: integer enum: - 1 - 2 - 3 - 4 - 5 description: > Event status: * `1` - Login failed * `2` - Login failed non-existent user * `3` - No login tried * `4` - Algorithm negotiation failed * `5` - Login succeeded FsEventStatus: type: integer enum: - 1 - 2 - 3 description: > Event status: * `1` - no error * `2` - generic error * `3` - quota exceeded error FsEventAction: type: string enum: - download - upload - first-upload - first-download - delete - rename - mkdir - rmdir - ssh_cmd ProviderEventAction: type: string enum: - add - update - delete ProviderEventObjectType: type: string enum: - user - folder - group - admin - api_key - share - event_action - event_rule - role SSHAuthentications: type: string enum: - publickey - password - keyboard-interactive - publickey+password - publickey+keyboard-interactive TLSVersions: type: integer enum: - 12 - 13 description: > TLS version: * `12` - TLS 1.2 * `13` - TLS 1.3 IPListType: type: integer enum: - 1 - 2 - 3 description: > IP List types: * `1` - allow list * `2` - defender * `3` - rate limiter safe list IPListMode: type: integer enum: - 1 - 2 description: > IP list modes * `1` - allow * `2` - deny, supported for defender list type only TOTPConfig: type: object properties: name: type: string issuer: type: string algo: $ref: '#/components/schemas/TOTPHMacAlgo' RecoveryCode: type: object properties: secret: $ref: '#/components/schemas/Secret' used: type: boolean description: 'Recovery codes to use if the user loses access to their second factor auth device. 
Each code can only be used once, you should use these codes to login and disable or reset 2FA for your account' BaseTOTPConfig: type: object properties: enabled: type: boolean config_name: type: string description: 'This name must be defined within the "totp" section of the SFTPGo configuration file. You will be unable to save a user/admin referencing a missing config_name' secret: $ref: '#/components/schemas/Secret' AdminTOTPConfig: allOf: - $ref: '#/components/schemas/BaseTOTPConfig' UserTOTPConfig: allOf: - $ref: '#/components/schemas/BaseTOTPConfig' - type: object properties: protocols: type: array items: $ref: '#/components/schemas/MFAProtocols' description: 'TOTP will be required for the specified protocols. SSH protocol (SFTP/SCP/SSH commands) will ask for the TOTP passcode if the client uses keyboard interactive authentication. FTP has no standard way to support two factor authentication, if you enable the FTP support, you have to add the TOTP passcode after the password. For example if your password is "password" and your one time passcode is "123456" you have to use "password123456" as password. WebDAV is not supported since each single request must be authenticated and a passcode cannot be reused.' PatternsFilter: type: object properties: path: type: string description: 'virtual path as seen by users, if no other specific filter is defined, the filter applies for sub directories too. For example if filters are defined for the paths "/" and "/sub" then the filters for "/" are applied for any file outside the "/sub" directory' allowed_patterns: type: array items: type: string description: 'list of, case insensitive, allowed shell like patterns. 
Allowed patterns are evaluated before the denied ones' example: - '*.jpg' - a*b?.png denied_patterns: type: array items: type: string description: 'list of, case insensitive, denied shell like patterns' example: - '*.zip' deny_policy: type: integer enum: - 0 - 1 description: | Policies for denied patterns * `0` - default policy. Denied files/directories matching the filters are visible in directory listing but cannot be uploaded/downloaded/overwritten/renamed * `1` - deny policy hide. This policy applies the same restrictions as the default one and denied files/directories matching the filters will also be hidden in directory listing. This mode may cause performance issues for large directories HooksFilter: type: object properties: external_auth_disabled: type: boolean example: false description: If true, the external auth hook, if defined, will not be executed pre_login_disabled: type: boolean example: false description: If true, the pre-login hook, if defined, will not be executed check_password_disabled: type: boolean example: false description: If true, the check password hook, if defined, will not be executed description: User specific hook overrides BandwidthLimit: type: object properties: sources: type: array items: type: string description: 'Source networks in CIDR notation as defined in RFC 4632 and RFC 4291 for example `192.0.2.0/24` or `2001:db8::/32`. 
The limit applies if the defined networks contain the client IP' upload_bandwidth: type: integer format: int32 description: 'Maximum upload bandwidth as KB/s, 0 means unlimited' download_bandwidth: type: integer format: int32 description: 'Maximum download bandwidth as KB/s, 0 means unlimited' TimePeriod: type: object properties: day_of_week: type: integer enum: - 0 - 1 - 2 - 3 - 4 - 5 - 6 description: Day of week, 0 Sunday, 6 Saturday from: type: string description: Start time in HH:MM format to: type: string description: End time in HH:MM format BaseUserFilters: type: object properties: allowed_ip: type: array items: type: string description: 'only clients connecting from these IP/Mask are allowed. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"' example: - 192.0.2.0/24 - '2001:db8::/32' denied_ip: type: array items: type: string description: clients connecting from these IP/Mask are not allowed. Denied rules are evaluated before allowed ones example: - 172.16.0.0/16 denied_login_methods: type: array items: $ref: '#/components/schemas/LoginMethods' description: if null or empty any available login method is allowed denied_protocols: type: array items: $ref: '#/components/schemas/SupportedProtocols' description: if null or empty any available protocol is allowed file_patterns: type: array items: $ref: '#/components/schemas/PatternsFilter' description: 'filters based on shell like file patterns. These restrictions do not apply to files listing for performance reasons, so a denied file cannot be downloaded/overwritten/renamed but it will still be in the list of files. Please note that these restrictions can be easily bypassed' max_upload_file_size: type: integer format: int64 description: 'maximum allowed size, as bytes, for a single file upload. The upload will be aborted if/when the size of the file being sent exceeds this limit. 
0 means unlimited' tls_username: type: string description: 'defines the TLS certificate field to use as username. For FTP clients it must match the name provided using the "USER" command. For WebDAV, if no username is provided, the CN will be used as username. For WebDAV clients it must match the implicit or provided username. Ignored if mutual TLS is disabled. Currently the only supported value is `CommonName`' hooks: $ref: '#/components/schemas/HooksFilter' disable_fs_checks: type: boolean example: false description: Disable checks for existence and automatic creation of home directory and virtual folders. SFTPGo requires that the user's home directory, virtual folder root, and intermediate paths to virtual folders exist to work properly. If you already know that the required directories exist, disabling these checks will speed up login. You could, for example, disable these checks after the first login web_client: type: array items: $ref: '#/components/schemas/WebClientOptions' description: WebClient/user REST API related configuration options allow_api_key_auth: type: boolean description: 'API key authentication allows to impersonate this user with an API key' user_type: $ref: '#/components/schemas/UserType' bandwidth_limits: type: array items: $ref: '#/components/schemas/BandwidthLimit' external_auth_cache_time: type: integer description: 'Defines the cache time, in seconds, for users authenticated using an external auth hook. 0 means no cache' start_directory: type: string description: 'Specifies an alternate starting directory. If not set, the default is "/". This option is supported for SFTP/SCP, FTP and HTTP (WebClient/REST API) protocols. Relative paths will use this directory as base.' 
two_factor_protocols: type: array items: $ref: '#/components/schemas/MFAProtocols' description: 'Defines protocols that require two factor authentication' ftp_security: type: integer enum: - 0 - 1 description: 'Set to `1` to require TLS for both data and control connections. This setting is useful if you want to allow both encrypted and plain text FTP sessions globally and then you want to require encrypted sessions on a per-user basis. It has no effect if TLS is already required for all users in the configuration file.' is_anonymous: type: boolean description: 'If enabled the user can login with any password or no password at all. Anonymous users are supported for FTP and WebDAV protocols and permissions will be automatically set to "list" and "download" (read only)' default_shares_expiration: type: integer description: 'Defines the default expiration for newly created shares as number of days. 0 means no expiration' max_shares_expiration: type: integer description: 'Defines the maximum allowed expiration, as a number of days, when a user creates or updates a share. 0 means no expiration' password_expiration: type: integer description: 'The password expires after the defined number of days. 0 means no expiration' password_strength: type: integer description: 'Defines the minimum password strength. 0 means disabled, any password will be accepted.
Values in the 50-70 range are suggested for common use cases' access_time: type: array items: $ref: '#/components/schemas/TimePeriod' description: Additional user options UserFilters: allOf: - $ref: '#/components/schemas/BaseUserFilters' - type: object properties: require_password_change: type: boolean description: 'User must change password from WebClient/REST API at next login' totp_config: $ref: '#/components/schemas/UserTOTPConfig' recovery_codes: type: array items: $ref: '#/components/schemas/RecoveryCode' tls_certs: type: array items: type: string additional_emails: type: array items: type: string format: email Secret: type: object properties: status: type: string enum: - Plain - AES-256-GCM - Secretbox - GCP - AWS - VaultTransit - AzureKeyVault - Redacted description: 'Set to "Plain" to add or update an existing secret, set to "Redacted" to preserve the existing value' payload: type: string key: type: string additional_data: type: string mode: type: integer description: 1 means encrypted using a master key description: The secret is encrypted before saving, so to set a new secret you must provide a payload and set the status to "Plain". The encryption key and additional data will be generated automatically. If you set the status to "Redacted" the existing secret will be preserved S3Config: type: object properties: bucket: type: string minLength: 1 region: type: string minLength: 1 access_key: type: string access_secret: $ref: '#/components/schemas/Secret' sse_customer_key: $ref: '#/components/schemas/Secret' role_arn: type: string description: 'Optional IAM Role ARN to assume' session_token: type: string description: 'Optional Session token that is a part of temporary security credentials provisioned by AWS STS' endpoint: type: string description: optional endpoint storage_class: type: string acl: type: string description: 'The canned ACL to apply to uploaded objects. Leave empty to use the default ACL. 
For more information and available ACLs, see here: https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl' upload_part_size: type: integer description: 'the buffer size (in MB) to use for multipart uploads. The minimum allowed part size is 5MB, and if this value is set to zero, the default value (5MB) for the AWS SDK will be used. The minimum allowed value is 5.' upload_concurrency: type: integer description: 'the number of parts to upload in parallel. If this value is set to zero, the default value (5) will be used' upload_part_max_time: type: integer description: 'the maximum time allowed, in seconds, to upload a single chunk (the chunk size is defined via "upload_part_size"). 0 means no timeout' download_part_size: type: integer description: 'the buffer size (in MB) to use for multipart downloads. The minimum allowed part size is 5MB, and if this value is set to zero, the default value (5MB) for the AWS SDK will be used. The minimum allowed value is 5. Ignored for partial downloads' download_concurrency: type: integer description: 'the number of parts to download in parallel. If this value is set to zero, the default value (5) will be used. Ignored for partial downloads' download_part_max_time: type: integer description: 'the maximum time allowed, in seconds, to download a single chunk (the chunk size is defined via "download_part_size"). 0 means no timeout. Ignored for partial downloads.' force_path_style: type: boolean description: 'Set this to "true" to force the request to use path-style addressing, i.e., "http://s3.amazonaws.com/BUCKET/KEY". By default, the S3 client will use virtual hosted bucket addressing when possible ("http://BUCKET.s3.amazonaws.com/KEY")' key_prefix: type: string description: 'key_prefix is similar to a chroot directory for a local filesystem. If specified the user will only see contents that starts with this prefix and so you can restrict access to a specific virtual folder. 
The prefix, if not empty, must not start with "/" and must end with "/". If empty the whole bucket contents will be available' example: folder/subfolder/ description: S3 Compatible Object Storage configuration details GCSConfig: type: object properties: bucket: type: string minLength: 1 credentials: $ref: '#/components/schemas/Secret' automatic_credentials: type: integer enum: - 0 - 1 description: | Automatic credentials: * `0` - disabled, explicit credentials, using a JSON credentials file, must be provided. This is the default value if the field is null * `1` - enabled, we try to use the Application Default Credentials (ADC) strategy to find your application's credentials storage_class: type: string acl: type: string description: 'The ACL to apply to uploaded objects. Leave empty to use the default ACL. For more information and available ACLs, refer to the JSON API here: https://cloud.google.com/storage/docs/access-control/lists#predefined-acl' key_prefix: type: string description: 'key_prefix is similar to a chroot directory for a local filesystem. If specified the user will only see contents that starts with this prefix and so you can restrict access to a specific virtual folder. The prefix, if not empty, must not start with "/" and must end with "/". If empty the whole bucket contents will be available' example: folder/subfolder/ upload_part_size: type: integer description: 'The buffer size (in MB) to use for multipart uploads. The default value is 16MB. 0 means use the default' upload_part_max_time: type: integer description: 'The maximum time allowed, in seconds, to upload a single chunk. The default value is 32. 0 means use the default' description: 'Google Cloud Storage configuration details. The "credentials" field must be populated only when adding/updating a user. 
It will always be omitted, since it contains sensitive data, when you search/get users' AzureBlobFsConfig: type: object properties: container: type: string account_name: type: string description: 'Storage Account Name, leave blank to use SAS URL' account_key: $ref: '#/components/schemas/Secret' sas_url: $ref: '#/components/schemas/Secret' endpoint: type: string description: 'optional endpoint. Default is "blob.core.windows.net". If you use the emulator the endpoint must include the protocol, for example "http://127.0.0.1:10000"' upload_part_size: type: integer description: 'the buffer size (in MB) to use for multipart uploads. If this value is set to zero, the default value (5MB) will be used.' upload_concurrency: type: integer description: 'the number of parts to upload in parallel. If this value is set to zero, the default value (5) will be used' download_part_size: type: integer description: 'the buffer size (in MB) to use for multipart downloads. If this value is set to zero, the default value (5MB) will be used.' download_concurrency: type: integer description: 'the number of parts to download in parallel. If this value is set to zero, the default value (5) will be used' access_tier: type: string enum: - '' - Archive - Hot - Cool key_prefix: type: string description: 'key_prefix is similar to a chroot directory for a local filesystem. If specified the user will only see contents that start with this prefix and so you can restrict access to a specific virtual folder. The prefix, if not empty, must not start with "/" and must end with "/". If empty the whole container contents will be available' example: folder/subfolder/ use_emulator: type: boolean description: Azure Blob Storage configuration details OSFsConfig: type: object properties: read_buffer_size: type: integer minimum: 0 maximum: 10 description: "The read buffer size, as MB, to use for downloads. 0 means no buffering, that's fine in most use cases."
write_buffer_size: type: integer minimum: 0 maximum: 10 description: "The write buffer size, as MB, to use for uploads. 0 means no buffering, that's fine in most use cases." CryptFsConfig: type: object properties: passphrase: $ref: '#/components/schemas/Secret' read_buffer_size: type: integer minimum: 0 maximum: 10 description: "The read buffer size, as MB, to use for downloads. 0 means no buffering, that's fine in most use cases." write_buffer_size: type: integer minimum: 0 maximum: 10 description: "The write buffer size, as MB, to use for uploads. 0 means no buffering, that's fine in most use cases." description: Crypt filesystem configuration details SFTPFsConfig: type: object properties: endpoint: type: string description: 'remote SFTP endpoint as host:port' username: type: string description: you can specify a password or private key or both. In the latter case the private key will be tried first. password: $ref: '#/components/schemas/Secret' private_key: $ref: '#/components/schemas/Secret' key_passphrase: $ref: '#/components/schemas/Secret' fingerprints: type: array items: type: string description: 'SHA256 fingerprints to use for host key verification. If you don''t provide any fingerprint the remote host key will not be verified, this is a security risk' prefix: type: string description: Specifying a prefix you can restrict all operations to a given path within the remote SFTP server. disable_concurrent_reads: type: boolean description: Concurrent reads are safe to use and disabling them will degrade performance. Some servers automatically delete files once they are downloaded. Using concurrent reads is problematic with such servers. buffer_size: type: integer minimum: 0 maximum: 16 example: 2 description: The size of the buffer (in MB) to use for transfers. 
By enabling buffering, the reads and writes, from/to the remote SFTP server, are split in multiple concurrent requests and this allows data to be transferred at a faster rate, over high latency networks, by overlapping round-trip times. With buffering enabled, resuming uploads is not supported and a file cannot be opened for both reading and writing at the same time. 0 means disabled. equality_check_mode: type: integer enum: - 0 - 1 description: | Defines how to check if this config points to the same server as another config. If different configs point to the same server the renaming between the fs configs is allowed: * `0` username and endpoint must match. This is the default * `1` only the endpoint must match HTTPFsConfig: type: object properties: endpoint: type: string description: 'HTTP/S endpoint URL. SFTPGo will use this URL as base, for example for the `stat` API, SFTPGo will add `/stat/{name}`' username: type: string password: $ref: '#/components/schemas/Secret' api_key: $ref: '#/components/schemas/Secret' skip_tls_verify: type: boolean equality_check_mode: type: integer enum: - 0 - 1 description: | Defines how to check if this config points to the same server as another config. If different configs point to the same server the renaming between the fs configs is allowed: * `0` username and endpoint must match. 
This is the default * `1` only the endpoint must match FilesystemConfig: type: object properties: provider: $ref: '#/components/schemas/FsProviders' osconfig: $ref: '#/components/schemas/OSFsConfig' s3config: $ref: '#/components/schemas/S3Config' gcsconfig: $ref: '#/components/schemas/GCSConfig' azblobconfig: $ref: '#/components/schemas/AzureBlobFsConfig' cryptconfig: $ref: '#/components/schemas/CryptFsConfig' sftpconfig: $ref: '#/components/schemas/SFTPFsConfig' httpconfig: $ref: '#/components/schemas/HTTPFsConfig' description: Storage filesystem details BaseVirtualFolder: type: object properties: id: type: integer format: int32 minimum: 1 name: type: string description: unique name for this virtual folder mapped_path: type: string description: absolute filesystem path to use as virtual folder description: type: string description: optional description used_quota_size: type: integer format: int64 used_quota_files: type: integer format: int32 last_quota_update: type: integer format: int64 description: Last quota update as unix timestamp in milliseconds users: type: array items: type: string description: list of usernames associated with this virtual folder filesystem: $ref: '#/components/schemas/FilesystemConfig' description: 'Defines the filesystem for the virtual folder and the used quota limits. The same folder can be shared among multiple users and each user can have different quota limits or a different virtual path.' VirtualFolder: allOf: - $ref: '#/components/schemas/BaseVirtualFolder' - type: object properties: virtual_path: type: string quota_size: type: integer format: int64 description: 'Quota as size in bytes. 0 means unlimited, -1 means included in user quota. Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed' quota_files: type: integer format: int32 description: 'Quota as number of files. 0 means unlimited, -1 means included in user quota.
Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed' required: - virtual_path description: 'A virtual folder is a mapping between a SFTPGo virtual path and a filesystem path outside the user home directory. The specified paths must be absolute and the virtual path cannot be "/", it must be a sub directory. The parent directory for the specified virtual path must exist. SFTPGo will try to automatically create any missing parent directory for the configured virtual folders at user login.' User: type: object properties: id: type: integer format: int32 minimum: 1 status: type: integer enum: - 0 - 1 description: | status: * `0` user is disabled, login is not allowed * `1` user is enabled username: type: string description: username is unique email: type: string format: email description: type: string description: 'optional description, for example the user full name' expiration_date: type: integer format: int64 description: expiration date as unix timestamp in milliseconds. An expired account cannot login. 0 means no expiration password: type: string format: password description: If the password has no known hashing algo prefix it will be stored, by default, using bcrypt, argon2id is supported too. You can send a password hashed as bcrypt ($2a$ prefix), argon2id, pbkdf2 or unix crypt and it will be stored as is. For security reasons this field is omitted when you search/get users public_keys: type: array items: type: string example: ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUWwDwEWhTbF0MqAsp/oXK1HR2cElhM8oo1uVmL3ZeDKDiTm4ljMr92wfTgIGDqIoxmVqgYIkAOAhuykAVWBzc= user@host description: Public keys in OpenSSH format. has_password: type: boolean description: Indicates whether the password is set home_dir: type: string description: path to the user home directory. The user cannot upload or download files outside this directory. 
SFTPGo tries to automatically create this folder if missing. Must be an absolute path virtual_folders: type: array items: $ref: '#/components/schemas/VirtualFolder' description: mapping between virtual SFTPGo paths and virtual folders uid: type: integer format: int32 minimum: 0 maximum: 2147483647 description: 'if you run SFTPGo as root user, the created files and directories will be assigned to this uid. 0 means no change, the owner will be the user that runs SFTPGo. Ignored on windows' gid: type: integer format: int32 minimum: 0 maximum: 2147483647 description: 'if you run SFTPGo as root user, the created files and directories will be assigned to this gid. 0 means no change, the group will be the one of the user that runs SFTPGo. Ignored on windows' max_sessions: type: integer format: int32 description: Limit the sessions that a user can open. 0 means unlimited quota_size: type: integer format: int64 description: Quota as size in bytes. 0 means unlimited. Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed quota_files: type: integer format: int32 description: Quota as number of files. 0 means unlimited. Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed permissions: type: object additionalProperties: type: array items: $ref: '#/components/schemas/Permission' minItems: 1 minProperties: 1 description: 'hash map with directory as key and an array of permissions as value. 
Directories must be absolute paths, permissions for root directory ("/") are required' example: /: - '*' /somedir: - list - download used_quota_size: type: integer format: int64 used_quota_files: type: integer format: int32 last_quota_update: type: integer format: int64 description: Last quota update as unix timestamp in milliseconds upload_bandwidth: type: integer description: 'Maximum upload bandwidth as KB/s, 0 means unlimited' download_bandwidth: type: integer description: 'Maximum download bandwidth as KB/s, 0 means unlimited' upload_data_transfer: type: integer description: 'Maximum data transfer allowed for uploads as MB. 0 means no limit' download_data_transfer: type: integer description: 'Maximum data transfer allowed for downloads as MB. 0 means no limit' total_data_transfer: type: integer description: 'Maximum total data transfer as MB. 0 means unlimited. You can set a total data transfer instead of the individual values for uploads and downloads' used_upload_data_transfer: type: integer description: 'Uploaded size, as bytes, since the last reset' used_download_data_transfer: type: integer description: 'Downloaded size, as bytes, since the last reset' created_at: type: integer format: int64 description: 'creation time as unix timestamp in milliseconds. It will be 0 for users created before v2.2.0' updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds last_login: type: integer format: int64 description: Last user login as unix timestamp in milliseconds. 
It is saved at most once every 10 minutes first_download: type: integer format: int64 description: first download time as unix timestamp in milliseconds first_upload: type: integer format: int64 description: first upload time as unix timestamp in milliseconds last_password_change: type: integer format: int64 description: last password change time as unix timestamp in milliseconds filters: $ref: '#/components/schemas/UserFilters' filesystem: $ref: '#/components/schemas/FilesystemConfig' additional_info: type: string description: Free form text field for external systems groups: type: array items: $ref: '#/components/schemas/GroupMapping' oidc_custom_fields: type: object additionalProperties: true description: 'This field is passed to the pre-login hook if custom OIDC token fields have been configured. Field values can be of any type (this is a free form object) and depend on the type of the configured OIDC token fields' role: type: string AdminPreferences: type: object properties: hide_user_page_sections: type: integer description: 'Allow to hide some sections from the user page. These are not security settings and are not enforced server side in any way. They are only intended to simplify the user page in the WebAdmin UI. 1 means hide groups section, 2 means hide filesystem section, "users_base_dir" must be set in the config file otherwise this setting is ignored, 4 means hide virtual folders section, 8 means hide profile section, 16 means hide ACLs section, 32 means hide disk and bandwidth quota limits section, 64 means hide advanced settings section. The settings can be combined' default_users_expiration: type: integer description: 'Defines the default expiration for newly created users as number of days. 0 means no expiration' AdminFilters: type: object properties: allow_list: type: array items: type: string description: 'only clients connecting from these IP/Mask are allowed. 
IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"' example: - 192.0.2.0/24 - '2001:db8::/32' allow_api_key_auth: type: boolean description: 'API key auth allows to impersonate this administrator with an API key' require_two_factor: type: boolean require_password_change: type: boolean totp_config: $ref: '#/components/schemas/AdminTOTPConfig' recovery_codes: type: array items: $ref: '#/components/schemas/RecoveryCode' preferences: $ref: '#/components/schemas/AdminPreferences' Admin: type: object properties: id: type: integer format: int32 minimum: 1 status: type: integer enum: - 0 - 1 description: | status: * `0` admin is disabled, login is not allowed * `1` admin is enabled username: type: string description: username is unique description: type: string description: 'optional description, for example the admin full name' password: type: string format: password description: Admin password. For security reasons this field is omitted when you search/get admins email: type: string format: email permissions: type: array items: $ref: '#/components/schemas/AdminPermissions' filters: $ref: '#/components/schemas/AdminFilters' additional_info: type: string description: Free form text field groups: type: array items: $ref: '#/components/schemas/AdminGroupMapping' description: 'Groups automatically selected for new users created by this admin. The admin will still be able to choose different groups. These settings are only used for this admin UI and they will be ignored in REST API/hooks.' created_at: type: integer format: int64 description: 'creation time as unix timestamp in milliseconds. It will be 0 for admins created before v2.2.0' updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds last_login: type: integer format: int64 description: Last admin login as unix timestamp in milliseconds.
It is saved at most once every 10 minutes role: type: string description: 'If set the admin can only administer users with the same role. Role admins cannot have the "*" permission' AdminProfile: type: object properties: email: type: string format: email description: type: string allow_api_key_auth: type: boolean description: 'If enabled, you can impersonate this admin, in REST API, using an API key. If disabled admin credentials are required for impersonation' UserProfile: type: object properties: email: type: string format: email description: type: string allow_api_key_auth: type: boolean description: 'If enabled, you can impersonate this user, in REST API, using an API key. If disabled user credentials are required for impersonation' public_keys: type: array items: type: string example: ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUWwDwEWhTbF0MqAsp/oXK1HR2cElhM8oo1uVmL3ZeDKDiTm4ljMr92wfTgIGDqIoxmVqgYIkAOAhuykAVWBzc= user@host description: Public keys in OpenSSH format APIKey: type: object properties: id: type: string description: unique key identifier name: type: string description: User friendly key name key: type: string format: password description: We store the hash of the key. This is just like a password. For security reasons this field is omitted when you search/get API keys scope: $ref: '#/components/schemas/APIKeyScope' created_at: type: integer format: int64 description: creation time as unix timestamp in milliseconds updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds last_use_at: type: integer format: int64 description: last use time as unix timestamp in milliseconds. It is saved at most once every 10 minutes expires_at: type: integer format: int64 description: expiration time as unix timestamp in milliseconds description: type: string description: optional description user: type: string description: username associated with this API key. 
If empty and the scope is "user scope" the key can impersonate any user admin: type: string description: admin associated with this API key. If empty and the scope is "admin scope" the key can impersonate any admin QuotaUsage: type: object properties: used_quota_size: type: integer format: int64 used_quota_files: type: integer format: int32 TransferQuotaUsage: type: object properties: used_upload_data_transfer: type: integer format: int64 description: 'The value must be specified as bytes' used_download_data_transfer: type: integer format: int64 description: 'The value must be specified as bytes' Transfer: type: object properties: operation_type: type: string enum: - upload - download description: | Operations: * `upload` * `download` path: type: string description: file path for the upload/download start_time: type: integer format: int64 description: start time as unix timestamp in milliseconds size: type: integer format: int64 description: bytes transferred ConnectionStatus: type: object properties: username: type: string description: connected username connection_id: type: string description: unique connection identifier client_version: type: string description: client version remote_address: type: string description: Remote address for the connected client connection_time: type: integer format: int64 description: connection time as unix timestamp in milliseconds command: type: string description: Last SSH/FTP command or WebDAV method last_activity: type: integer format: int64 description: last client activity as unix timestamp in milliseconds protocol: type: string enum: - SFTP - SCP - SSH - FTP - DAV active_transfers: type: array items: $ref: '#/components/schemas/Transfer' node: type: string description: 'Node identifier, omitted for single node installations' FolderRetention: type: object properties: path: type: string description: 'virtual directory path as seen by users, if no other specific retention is defined, the retention applies for sub directories 
too. For example if retention is defined for the paths "/" and "/sub" then the retention for "/" is applied for any file outside the "/sub" directory' example: '/' retention: type: integer description: retention time in hours. All the files with a modification time older than the defined value will be deleted. 0 means exclude this path example: 24 delete_empty_dirs: type: boolean description: if enabled, empty directories will be deleted RetentionCheck: type: object properties: username: type: string description: username to which the retention check refers folders: type: array items: $ref: '#/components/schemas/FolderRetention' start_time: type: integer format: int64 description: check start time as unix timestamp in milliseconds QuotaScan: type: object properties: username: type: string description: username to which the quota scan refers start_time: type: integer format: int64 description: scan start time as unix timestamp in milliseconds FolderQuotaScan: type: object properties: name: type: string description: folder name to which the quota scan refers start_time: type: integer format: int64 description: scan start time as unix timestamp in milliseconds DefenderEntry: type: object properties: id: type: string ip: type: string score: type: integer description: the score increases whenever a violation is detected, such as an attempt to log in using an incorrect password or invalid username. If the score exceeds the configured threshold, the IP is banned. Omitted for banned IPs ban_time: type: string format: date-time description: date time until the IP is banned. For already banned hosts, the ban time is increased each time a new violation is detected. 
Omitted if the IP is not banned SSHHostKey: type: object properties: path: type: string fingerprint: type: string algorithms: type: array items: type: string SSHBinding: type: object properties: address: type: string description: TCP address the server listens on port: type: integer description: the port used for serving requests apply_proxy_config: type: boolean description: 'apply the proxy configuration, if any' WebDAVBinding: type: object properties: address: type: string description: TCP address the server listens on port: type: integer description: the port used for serving requests enable_https: type: boolean min_tls_version: $ref: '#/components/schemas/TLSVersions' client_auth_type: type: integer description: 1 means that client certificate authentication is required in addition to HTTP basic authentication tls_cipher_suites: type: array items: type: string description: 'List of supported cipher suites for TLS version 1.2. If empty a default list of secure cipher suites is used, with a preference order based on hardware performance' prefix: type: string description: 'Prefix for WebDAV resources, if empty WebDAV resources will be available at the `/` URI' proxy_allowed: type: array items: type: string description: 'List of IP addresses and IP ranges allowed to set proxy headers' PassiveIPOverride: type: object properties: networks: type: array items: type: string ip: type: string FTPDBinding: type: object properties: address: type: string description: TCP address the server listens on port: type: integer description: the port used for serving requests apply_proxy_config: type: boolean description: 'apply the proxy configuration, if any' tls_mode: type: integer enum: - 0 - 1 - 2 description: | TLS mode: * `0` - clear or explicit TLS * `1` - explicit TLS required * `2` - implicit TLS min_tls_version: $ref: '#/components/schemas/TLSVersions' force_passive_ip: type: string description: External IP address for passive connections passive_ip_overrides: type: array
items: $ref: '#/components/schemas/PassiveIPOverride' client_auth_type: type: integer description: 1 means that client certificate authentication is required in addition to FTP authentication tls_cipher_suites: type: array items: type: string description: 'List of supported cipher suites for TLS version 1.2. If empty a default list of secure cipher suites is used, with a preference order based on hardware performance' passive_connections_security: type: integer enum: - 0 - 1 description: | Passive connections security: * `0` - require matching peer IP addresses of control and data connection * `1` - disable any checks active_connections_security: type: integer enum: - 0 - 1 description: | Active connections security: * `0` - require matching peer IP addresses of control and data connection * `1` - disable any checks ignore_ascii_transfer_type: type: integer enum: - 0 - 1 description: | Ignore client requests to perform ASCII translations: * `0` - ASCII translations are enabled * `1` - ASCII translations are silently ignored debug: type: boolean description: 'If enabled any FTP command will be logged' SSHServiceStatus: type: object properties: is_active: type: boolean bindings: type: array items: $ref: '#/components/schemas/SSHBinding' nullable: true host_keys: type: array items: $ref: '#/components/schemas/SSHHostKey' nullable: true ssh_commands: type: array items: type: string authentications: type: array items: $ref: '#/components/schemas/SSHAuthentications' public_key_algorithms: type: array items: type: string macs: type: array items: type: string kex_algorithms: type: array items: type: string ciphers: type: array items: type: string FTPPassivePortRange: type: object properties: start: type: integer end: type: integer FTPServiceStatus: type: object properties: is_active: type: boolean bindings: type: array items: $ref: '#/components/schemas/FTPDBinding' nullable: true passive_port_range: $ref: '#/components/schemas/FTPPassivePortRange' WebDAVServiceStatus:
type: object properties: is_active: type: boolean bindings: type: array items: $ref: '#/components/schemas/WebDAVBinding' nullable: true DataProviderStatus: type: object properties: is_active: type: boolean driver: type: string error: type: string MFAStatus: type: object properties: is_active: type: boolean totp_configs: type: array items: $ref: '#/components/schemas/TOTPConfig' ServicesStatus: type: object properties: ssh: $ref: '#/components/schemas/SSHServiceStatus' ftp: $ref: '#/components/schemas/FTPServiceStatus' webdav: $ref: '#/components/schemas/WebDAVServiceStatus' data_provider: $ref: '#/components/schemas/DataProviderStatus' defender: type: object properties: is_active: type: boolean mfa: $ref: '#/components/schemas/MFAStatus' allow_list: type: object properties: is_active: type: boolean rate_limiters: type: object properties: is_active: type: boolean protocols: type: array items: type: string example: SSH Share: type: object properties: id: type: string description: auto-generated unique share identifier name: type: string description: type: string description: optional description scope: $ref: '#/components/schemas/ShareScope' paths: type: array items: type: string description: 'paths to files or directories, for share scope write this array must contain exactly one directory. Paths will not be validated on save so you can also create them after creating the share' example: - '/dir1' - '/dir2/file.txt' - '/dir3/subdir' username: type: string created_at: type: integer format: int64 description: 'creation time as unix timestamp in milliseconds' updated_at: type: integer format: int64 description: 'last update time as unix timestamp in milliseconds' last_use_at: type: integer format: int64 description: last use time as unix timestamp in milliseconds expires_at: type: integer format: int64 description: 'optional share expiration, as unix timestamp in milliseconds. 
0 means no expiration' password: type: string description: 'optional password to protect the share. The special value "[**redacted**]" means that a password has been set, you can use this value if you want to preserve the current password when you update a share' max_tokens: type: integer description: 'maximum allowed access tokens. 0 means no limit' used_tokens: type: integer allow_from: type: array items: type: string description: 'Limit the share availability to these IP/Mask. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32". An empty list means no restrictions' example: - 192.0.2.0/24 - '2001:db8::/32' GroupUserSettings: type: object properties: home_dir: type: string max_sessions: type: integer format: int32 quota_size: type: integer format: int64 quota_files: type: integer format: int32 permissions: type: object additionalProperties: type: array items: $ref: '#/components/schemas/Permission' minItems: 1 minProperties: 1 description: 'hash map with directory as key and an array of permissions as value. Directories must be absolute paths, permissions for root directory ("/") are required' example: /: - '*' /somedir: - list - download upload_bandwidth: type: integer description: 'Maximum upload bandwidth as KB/s' download_bandwidth: type: integer description: 'Maximum download bandwidth as KB/s' upload_data_transfer: type: integer description: 'Maximum data transfer allowed for uploads as MB' download_data_transfer: type: integer description: 'Maximum data transfer allowed for downloads as MB' total_data_transfer: type: integer description: 'Maximum total data transfer as MB' expires_in: type: integer description: 'Account expiration in number of days from creation. 
0 means no expiration' filters: $ref: '#/components/schemas/BaseUserFilters' filesystem: $ref: '#/components/schemas/FilesystemConfig' Role: type: object properties: id: type: integer format: int32 minimum: 1 name: type: string description: name is unique description: type: string description: 'optional description' created_at: type: integer format: int64 description: creation time as unix timestamp in milliseconds updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds users: type: array items: type: string description: list of usernames associated with this role admins: type: array items: type: string description: list of admin usernames associated with this role Group: type: object properties: id: type: integer format: int32 minimum: 1 name: type: string description: name is unique description: type: string description: 'optional description' created_at: type: integer format: int64 description: creation time as unix timestamp in milliseconds updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds user_settings: $ref: '#/components/schemas/GroupUserSettings' virtual_folders: type: array items: $ref: '#/components/schemas/VirtualFolder' description: mapping between virtual SFTPGo paths and folders users: type: array items: type: string description: list of usernames associated with this group admins: type: array items: type: string description: list of admin usernames associated with this group GroupMapping: type: object properties: name: type: string description: group name type: enum: - 1 - 2 - 3 description: | Group type: * `1` - Primary group * `2` - Secondary group * `3` - Membership only, no settings are inherited from this group type AdminGroupMappingOptions: type: object properties: add_to_users_as: enum: - 0 - 1 - 2 description: | Add to new users as: * `0` - the admin's group will be added as membership group for new users * `1` - the admin's group will be
added as primary group for new users * `2` - the admin's group will be added as secondary group for new users AdminGroupMapping: type: object properties: name: type: string description: group name options: $ref: '#/components/schemas/AdminGroupMappingOptions' BackupData: type: object properties: users: type: array items: $ref: '#/components/schemas/User' folders: type: array items: $ref: '#/components/schemas/BaseVirtualFolder' groups: type: array items: $ref: '#/components/schemas/Group' admins: type: array items: $ref: '#/components/schemas/Admin' api_keys: type: array items: $ref: '#/components/schemas/APIKey' shares: type: array items: $ref: '#/components/schemas/Share' event_actions: type: array items: $ref: '#/components/schemas/EventAction' event_rules: type: array items: $ref: '#/components/schemas/EventRule' roles: type: array items: $ref: '#/components/schemas/Role' version: type: integer PwdChange: type: object properties: current_password: type: string new_password: type: string DirEntry: type: object properties: name: type: string description: name of the file (or subdirectory) described by the entry. This name is the final element of the path (the base name), not the entire path size: type: integer format: int64 description: file size, omitted for folders and non regular files mode: type: integer description: | File mode and permission bits. More details here: https://golang.org/pkg/io/fs/#FileMode. 
Let's see some examples: - for a directory mode&2147483648 != 0 - for a symlink mode&134217728 != 0 - for a regular file mode&2401763328 == 0 last_modified: type: string format: date-time FsEvent: type: object properties: id: type: string timestamp: type: integer format: int64 description: 'unix timestamp in nanoseconds' action: $ref: '#/components/schemas/FsEventAction' username: type: string fs_path: type: string fs_target_path: type: string virtual_path: type: string virtual_target_path: type: string ssh_cmd: type: string file_size: type: integer format: int64 elapsed: type: integer format: int64 description: elapsed time as milliseconds status: $ref: '#/components/schemas/FsEventStatus' protocol: $ref: '#/components/schemas/EventProtocols' ip: type: string session_id: type: string fs_provider: $ref: '#/components/schemas/FsProviders' bucket: type: string endpoint: type: string open_flags: type: string role: type: string instance_id: type: string ProviderEvent: type: object properties: id: type: string timestamp: type: integer format: int64 description: 'unix timestamp in nanoseconds' action: $ref: '#/components/schemas/ProviderEventAction' username: type: string ip: type: string object_type: $ref: '#/components/schemas/ProviderEventObjectType' object_name: type: string object_data: type: string format: byte description: 'base64 of the JSON serialized object with sensitive fields removed' role: type: string instance_id: type: string LogEvent: type: object properties: id: type: string timestamp: type: integer format: int64 description: 'unix timestamp in nanoseconds' event: $ref: '#/components/schemas/LogEventType' protocol: $ref: '#/components/schemas/EventProtocols' username: type: string ip: type: string message: type: string role: type: string instance_id: type: string KeyValue: type: object properties: key: type: string value: type: string RenameConfig: allOf: - $ref: '#/components/schemas/KeyValue' - type: object properties: update_modtime: type: boolean 
description: 'Update modification time. This setting is not recursive and only applies to storage providers that support changing modification times' HTTPPart: type: object properties: name: type: string headers: type: array items: $ref: '#/components/schemas/KeyValue' description: 'Additional headers. Content-Disposition header is automatically set. Content-Type header is automatically detected for files to attach' filepath: type: string description: 'path to the file to be sent as an attachment' body: type: string EventActionHTTPConfig: type: object properties: endpoint: type: string description: HTTP endpoint example: https://example.com username: type: string password: $ref: '#/components/schemas/Secret' headers: type: array items: $ref: '#/components/schemas/KeyValue' description: headers to add timeout: type: integer minimum: 1 maximum: 180 description: 'Ignored for multipart requests with files as attachments' skip_tls_verify: type: boolean description: 'if enabled the HTTP client accepts any TLS certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.' method: type: string enum: - GET - POST - PUT - DELETE query_parameters: type: array items: $ref: '#/components/schemas/KeyValue' body: type: string description: HTTP POST/PUT body parts: type: array items: $ref: '#/components/schemas/HTTPPart' description: 'Multipart requests allow combining one or more sets of data into a single body. For each part, you can set a file path or a body as text. Placeholders are supported in file path, body, and header values.'
EventActionCommandConfig: type: object properties: cmd: type: string description: absolute path to the command to execute args: type: array items: type: string description: 'command line arguments' timeout: type: integer minimum: 1 maximum: 120 env_vars: type: array items: $ref: '#/components/schemas/KeyValue' EventActionEmailConfig: type: object properties: recipients: type: array items: type: string bcc: type: array items: type: string subject: type: string body: type: string content_type: type: integer enum: - 0 - 1 description: | Content type: * `0` text/plain * `1` text/html attachments: type: array items: type: string description: 'list of file paths to attach. The total size is limited to 10 MB' EventActionDataRetentionConfig: type: object properties: folders: type: array items: $ref: '#/components/schemas/FolderRetention' EventActionFsCompress: type: object properties: name: type: string description: 'Full path to the (zip) archive to create. The parent dir must exist' paths: type: array items: type: string description: 'paths to add to the archive' EventActionFilesystemConfig: type: object properties: type: $ref: '#/components/schemas/FilesystemActionTypes' renames: type: array items: $ref: '#/components/schemas/RenameConfig' mkdirs: type: array items: type: string deletes: type: array items: type: string exist: type: array items: type: string copy: type: array items: $ref: '#/components/schemas/KeyValue' compress: $ref: '#/components/schemas/EventActionFsCompress' EventActionPasswordExpiration: type: object properties: threshold: type: integer description: 'An email notification will be generated for users whose password expires in a number of days less than or equal to this threshold' EventActionUserInactivity: type: object properties: disable_threshold: type: integer description: 'Inactivity threshold, in days, before disabling the account' delete_threshold: type: integer description: 'Inactivity threshold, in days, before deleting the account'
EventActionIDPAccountCheck: type: object properties: mode: type: integer enum: - 0 - 1 description: | Account check mode: * `0` Create or update the account * `1` Create the account if it doesn't exist template_user: type: string description: 'SFTPGo user template in JSON format' template_admin: type: string description: 'SFTPGo admin template in JSON format' BaseEventActionOptions: type: object properties: http_config: $ref: '#/components/schemas/EventActionHTTPConfig' cmd_config: $ref: '#/components/schemas/EventActionCommandConfig' email_config: $ref: '#/components/schemas/EventActionEmailConfig' retention_config: $ref: '#/components/schemas/EventActionDataRetentionConfig' fs_config: $ref: '#/components/schemas/EventActionFilesystemConfig' pwd_expiration_config: $ref: '#/components/schemas/EventActionPasswordExpiration' user_inactivity_config: $ref: '#/components/schemas/EventActionUserInactivity' idp_config: $ref: '#/components/schemas/EventActionIDPAccountCheck' BaseEventAction: type: object properties: id: type: integer format: int32 minimum: 1 name: type: string description: unique name description: type: string description: optional description type: $ref: '#/components/schemas/EventActionTypes' options: $ref: '#/components/schemas/BaseEventActionOptions' rules: type: array items: type: string description: list of event rules names associated with this action EventActionOptions: type: object properties: is_failure_action: type: boolean stop_on_failure: type: boolean execute_sync: type: boolean EventAction: allOf: - $ref: '#/components/schemas/BaseEventAction' - type: object properties: order: type: integer description: execution order relation_options: $ref: '#/components/schemas/EventActionOptions' EventActionMinimal: type: object properties: name: type: string order: type: integer description: execution order relation_options: $ref: '#/components/schemas/EventActionOptions' ConditionPattern: type: object properties: pattern: type: string inverse_match: 
type: boolean ConditionOptions: type: object properties: names: type: array items: $ref: '#/components/schemas/ConditionPattern' group_names: type: array items: $ref: '#/components/schemas/ConditionPattern' role_names: type: array items: $ref: '#/components/schemas/ConditionPattern' fs_paths: type: array items: $ref: '#/components/schemas/ConditionPattern' protocols: type: array items: type: string enum: - SFTP - SCP - SSH - FTP - DAV - HTTP - HTTPShare - OIDC provider_objects: type: array items: type: string enum: - user - group - admin - api_key - share - event_action - event_rule min_size: type: integer format: int64 max_size: type: integer format: int64 event_statuses: type: array items: type: integer enum: - 1 - 2 - 3 description: | Event status: - `1` OK - `2` Failed - `3` Quota exceeded concurrent_execution: type: boolean description: allow concurrent execution from multiple nodes Schedule: type: object properties: hour: type: string day_of_week: type: string day_of_month: type: string month: type: string EventConditions: type: object properties: fs_events: type: array items: type: string enum: - upload - download - delete - rename - mkdir - rmdir - copy - ssh_cmd - pre-upload - pre-download - pre-delete - first-upload - first-download provider_events: type: array items: type: string enum: - add - update - delete schedules: type: array items: $ref: '#/components/schemas/Schedule' idp_login_event: type: integer enum: - 0 - 1 - 2 description: | IDP login events: - `0` any login event - `1` user login event - `2` admin login event options: $ref: '#/components/schemas/ConditionOptions' BaseEventRule: type: object properties: id: type: integer format: int32 minimum: 1 name: type: string description: unique name status: type: integer enum: - 0 - 1 description: | status: * `0` disabled * `1` enabled description: type: string description: optional description created_at: type: integer format: int64 description: creation time as unix timestamp in milliseconds 
updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds trigger: $ref: '#/components/schemas/EventTriggerTypes' conditions: $ref: '#/components/schemas/EventConditions' EventRule: allOf: - $ref: '#/components/schemas/BaseEventRule' - type: object properties: actions: type: array items: $ref: '#/components/schemas/EventAction' EventRuleMinimal: allOf: - $ref: '#/components/schemas/BaseEventRule' - type: object properties: actions: type: array items: $ref: '#/components/schemas/EventActionMinimal' IPListEntry: type: object properties: ipornet: type: string description: IP address or network in CIDR format, for example `192.168.1.2/32`, `192.168.0.0/24`, `2001:db8::/32` description: type: string description: optional description type: $ref: '#/components/schemas/IPListType' mode: $ref: '#/components/schemas/IPListMode' protocols: type: integer description: Defines the protocols the entry applies to. `0` means all the supported protocols, 1 SSH, 2 FTP, 4 WebDAV, 8 HTTP. Protocols can be combined, for example 3 means SSH and FTP created_at: type: integer format: int64 description: creation time as unix timestamp in milliseconds updated_at: type: integer format: int64 description: last update time as unix timestamp in milliseconds ApiResponse: type: object properties: message: type: string description: 'message, can be empty' error: type: string description: error description if any VersionInfo: type: object properties: version: type: string build_date: type: string commit_hash: type: string features: type: array items: type: string description: 'Features for the current build. Available features are `portable`, `bolt`, `mysql`, `sqlite`, `pgsql`, `s3`, `gcs`, `azblob`, `metrics`, `unixcrypt`.
If a feature is available it has a `+` prefix, otherwise a `-` prefix' Token: type: object properties: access_token: type: string expires_at: type: string format: date-time securitySchemes: BasicAuth: type: http scheme: basic BearerAuth: type: http scheme: bearer bearerFormat: JWT APIKeyAuth: type: apiKey in: header name: X-SFTPGO-API-KEY description: 'API key to use for authentication. API key authentication is intrinsically less secure than using a short lived JWT token. You should prefer API key authentication only for machine-to-machine communications in trusted environments. If no admin/user is associated to the provided key you need to add ".username" at the end of the key. For example if your API key is "6ajKLwswLccVBGpZGv596G.ySAXc8vtp9hMiwAuaLtzof" and you want to impersonate the admin with username "myadmin" you have to use "6ajKLwswLccVBGpZGv596G.ySAXc8vtp9hMiwAuaLtzof.myadmin" as API key. When using API key authentication you cannot manage API keys, update the impersonated admin, change password or public keys for the impersonated user.' ================================================ FILE: openapi/swagger-ui/index.css ================================================ html { box-sizing: border-box; overflow: -moz-scrollbars-vertical; overflow-y: scroll; } *, *:before, *:after { box-sizing: inherit; } body { margin: 0; background: #fafafa; } ================================================ FILE: openapi/swagger-ui/index.html ================================================ Swagger UI
================================================ FILE: openapi/swagger-ui/swagger-initializer.js ================================================ window.onload = function() { // // the following lines will be replaced by docker/configurator, when it runs in a docker-container window.ui = SwaggerUIBundle({ url: "../openapi.yaml", dom_id: '#swagger-ui', deepLinking: true, presets: [ SwaggerUIBundle.presets.apis, SwaggerUIStandalonePreset ], plugins: [ SwaggerUIBundle.plugins.DownloadUrl ], layout: "StandaloneLayout" }); // }; ================================================ FILE: openapi/swagger-ui/swagger-ui-bundle.js ================================================ /*! For license information please see swagger-ui-bundle.js.LICENSE.txt */ !function webpackUniversalModuleDefinition(s,o){"object"==typeof exports&&"object"==typeof module?module.exports=o():"function"==typeof define&&define.amd?define([],o):"object"==typeof exports?exports.SwaggerUIBundle=o():s.SwaggerUIBundle=o()}(this,(()=>(()=>{var s={251:(s,o)=>{o.read=function(s,o,i,a,u){var _,w,x=8*u-a-1,C=(1<>1,L=-7,B=i?u-1:0,$=i?-1:1,U=s[o+B];for(B+=$,_=U&(1<<-L)-1,U>>=-L,L+=x;L>0;_=256*_+s[o+B],B+=$,L-=8);for(w=_&(1<<-L)-1,_>>=-L,L+=a;L>0;w=256*w+s[o+B],B+=$,L-=8);if(0===_)_=1-j;else{if(_===C)return w?NaN:1/0*(U?-1:1);w+=Math.pow(2,a),_-=j}return(U?-1:1)*w*Math.pow(2,_-a)},o.write=function(s,o,i,a,u,_){var w,x,C,j=8*_-u-1,L=(1<>1,$=23===u?Math.pow(2,-24)-Math.pow(2,-77):0,U=a?0:_-1,V=a?1:-1,z=o<0||0===o&&1/o<0?1:0;for(o=Math.abs(o),isNaN(o)||o===1/0?(x=isNaN(o)?1:0,w=L):(w=Math.floor(Math.log(o)/Math.LN2),o*(C=Math.pow(2,-w))<1&&(w--,C*=2),(o+=w+B>=1?$/C:$*Math.pow(2,1-B))*C>=2&&(w++,C/=2),w+B>=L?(x=0,w=L):w+B>=1?(x=(o*C-1)*Math.pow(2,u),w+=B):(x=o*Math.pow(2,B-1)*Math.pow(2,u),w=0));u>=8;s[i+U]=255&x,U+=V,x/=256,u-=8);for(w=w<0;s[i+U]=255&w,U+=V,w/=256,j-=8);s[i+U-V]|=128*z}},462:(s,o,i)=>{"use strict";var a=i(40975);s.exports=a},659:(s,o,i)=>{var 
a=i(51873),u=Object.prototype,_=u.hasOwnProperty,w=u.toString,x=a?a.toStringTag:void 0;s.exports=function getRawTag(s){var o=_.call(s,x),i=s[x];try{s[x]=void 0;var a=!0}catch(s){}var u=w.call(s);return a&&(o?s[x]=i:delete s[x]),u}},694:(s,o,i)=>{"use strict";i(91599);var a=i(37257);i(12560),s.exports=a},953:(s,o,i)=>{"use strict";s.exports=i(53375)},1733:s=>{var o=/[^\x00-\x2f\x3a-\x40\x5b-\x60\x7b-\x7f]+/g;s.exports=function asciiWords(s){return s.match(o)||[]}},1882:(s,o,i)=>{var a=i(72552),u=i(23805);s.exports=function isFunction(s){if(!u(s))return!1;var o=a(s);return"[object Function]"==o||"[object GeneratorFunction]"==o||"[object AsyncFunction]"==o||"[object Proxy]"==o}},1907:(s,o,i)=>{"use strict";var a=i(41505),u=Function.prototype,_=u.call,w=a&&u.bind.bind(_,_);s.exports=a?w:function(s){return function(){return _.apply(s,arguments)}}},2205:function(s,o,i){var a;a=void 0!==i.g?i.g:this,s.exports=function(s){if(s.CSS&&s.CSS.escape)return s.CSS.escape;var cssEscape=function(s){if(0==arguments.length)throw new TypeError("`CSS.escape` requires an argument.");for(var o,i=String(s),a=i.length,u=-1,_="",w=i.charCodeAt(0);++u=1&&o<=31||127==o||0==u&&o>=48&&o<=57||1==u&&o>=48&&o<=57&&45==w?"\\"+o.toString(16)+" ":0==u&&1==a&&45==o||!(o>=128||45==o||95==o||o>=48&&o<=57||o>=65&&o<=90||o>=97&&o<=122)?"\\"+i.charAt(u):i.charAt(u):_+="�";return _};return s.CSS||(s.CSS={}),s.CSS.escape=cssEscape,cssEscape}(a)},2209:(s,o,i)=>{"use strict";var a,u=i(9404),_=function productionTypeChecker(){invariant(!1,"ImmutablePropTypes type checking code is stripped in production.")};_.isRequired=_;var w=function getProductionTypeChecker(){return _};function getPropType(s){var o=typeof s;return Array.isArray(s)?"array":s instanceof RegExp?"object":s instanceof u.Iterable?"Immutable."+s.toSource().split(" ")[0]:o}function createChainableTypeChecker(s){function checkType(o,i,a,u,_,w){for(var x=arguments.length,C=Array(x>6?x-6:0),j=6;j>",null!=i[a]?s.apply(void 0,[i,a,u,_,w].concat(C)):o?new 
Error("Required "+_+" `"+w+"` was not specified in `"+u+"`."):void 0}var o=checkType.bind(null,!1);return o.isRequired=checkType.bind(null,!0),o}function createIterableSubclassTypeChecker(s,o){return function createImmutableTypeChecker(s,o){return createChainableTypeChecker((function validate(i,a,u,_,w){var x=i[a];if(!o(x)){var C=getPropType(x);return new Error("Invalid "+_+" `"+w+"` of type `"+C+"` supplied to `"+u+"`, expected `"+s+"`.")}return null}))}("Iterable."+s,(function(s){return u.Iterable.isIterable(s)&&o(s)}))}(a={listOf:w,mapOf:w,orderedMapOf:w,setOf:w,orderedSetOf:w,stackOf:w,iterableOf:w,recordOf:w,shape:w,contains:w,mapContains:w,orderedMapContains:w,list:_,map:_,orderedMap:_,set:_,orderedSet:_,stack:_,seq:_,record:_,iterable:_}).iterable.indexed=createIterableSubclassTypeChecker("Indexed",u.Iterable.isIndexed),a.iterable.keyed=createIterableSubclassTypeChecker("Keyed",u.Iterable.isKeyed),s.exports=a},2404:(s,o,i)=>{var a=i(60270);s.exports=function isEqual(s,o){return a(s,o)}},2523:s=>{s.exports=function baseFindIndex(s,o,i,a){for(var u=s.length,_=i+(a?1:-1);a?_--:++_{"use strict";var a=i(45951),u=Object.defineProperty;s.exports=function(s,o){try{u(a,s,{value:o,configurable:!0,writable:!0})}catch(i){a[s]=o}return o}},2694:(s,o,i)=>{"use strict";var a=i(6925);function emptyFunction(){}function emptyFunctionWithReset(){}emptyFunctionWithReset.resetWarningCache=emptyFunction,s.exports=function(){function shim(s,o,i,u,_,w){if(w!==a){var x=new Error("Calling PropTypes validators directly is not supported by the `prop-types` package. Use PropTypes.checkPropTypes() to call them. 
Read more at http://fb.me/use-check-prop-types");throw x.name="Invariant Violation",x}}function getShim(){return shim}shim.isRequired=shim;var s={array:shim,bigint:shim,bool:shim,func:shim,number:shim,object:shim,string:shim,symbol:shim,any:shim,arrayOf:getShim,element:shim,elementType:shim,instanceOf:getShim,node:shim,objectOf:getShim,oneOf:getShim,oneOfType:getShim,shape:getShim,exact:getShim,checkPropTypes:emptyFunctionWithReset,resetWarningCache:emptyFunction};return s.PropTypes=s,s}},2874:s=>{s.exports={}},2875:(s,o,i)=>{"use strict";var a=i(23045),u=i(80376);s.exports=Object.keys||function keys(s){return a(s,u)}},2955:(s,o,i)=>{"use strict";var a,u=i(65606);function _defineProperty(s,o,i){return(o=function _toPropertyKey(s){var o=function _toPrimitive(s,o){if("object"!=typeof s||null===s)return s;var i=s[Symbol.toPrimitive];if(void 0!==i){var a=i.call(s,o||"default");if("object"!=typeof a)return a;throw new TypeError("@@toPrimitive must return a primitive value.")}return("string"===o?String:Number)(s)}(s,"string");return"symbol"==typeof o?o:String(o)}(o))in s?Object.defineProperty(s,o,{value:i,enumerable:!0,configurable:!0,writable:!0}):s[o]=i,s}var _=i(86238),w=Symbol("lastResolve"),x=Symbol("lastReject"),C=Symbol("error"),j=Symbol("ended"),L=Symbol("lastPromise"),B=Symbol("handlePromise"),$=Symbol("stream");function createIterResult(s,o){return{value:s,done:o}}function readAndResolve(s){var o=s[w];if(null!==o){var i=s[$].read();null!==i&&(s[L]=null,s[w]=null,s[x]=null,o(createIterResult(i,!1)))}}function onReadable(s){u.nextTick(readAndResolve,s)}var U=Object.getPrototypeOf((function(){})),V=Object.setPrototypeOf((_defineProperty(a={get stream(){return this[$]},next:function next(){var s=this,o=this[C];if(null!==o)return Promise.reject(o);if(this[j])return Promise.resolve(createIterResult(void 0,!0));if(this[$].destroyed)return new Promise((function(o,i){u.nextTick((function(){s[C]?i(s[C]):o(createIterResult(void 0,!0))}))}));var i,a=this[L];if(a)i=new 
Promise(function wrapForNext(s,o){return function(i,a){s.then((function(){o[j]?i(createIterResult(void 0,!0)):o[B](i,a)}),a)}}(a,this));else{var _=this[$].read();if(null!==_)return Promise.resolve(createIterResult(_,!1));i=new Promise(this[B])}return this[L]=i,i}},Symbol.asyncIterator,(function(){return this})),_defineProperty(a,"return",(function _return(){var s=this;return new Promise((function(o,i){s[$].destroy(null,(function(s){s?i(s):o(createIterResult(void 0,!0))}))}))})),a),U);s.exports=function createReadableStreamAsyncIterator(s){var o,i=Object.create(V,(_defineProperty(o={},$,{value:s,writable:!0}),_defineProperty(o,w,{value:null,writable:!0}),_defineProperty(o,x,{value:null,writable:!0}),_defineProperty(o,C,{value:null,writable:!0}),_defineProperty(o,j,{value:s._readableState.endEmitted,writable:!0}),_defineProperty(o,B,{value:function value(s,o){var a=i[$].read();a?(i[L]=null,i[w]=null,i[x]=null,s(createIterResult(a,!1))):(i[w]=s,i[x]=o)},writable:!0}),o));return i[L]=null,_(s,(function(s){if(s&&"ERR_STREAM_PREMATURE_CLOSE"!==s.code){var o=i[x];return null!==o&&(i[L]=null,i[w]=null,i[x]=null,o(s)),void(i[C]=s)}var a=i[w];null!==a&&(i[L]=null,i[w]=null,i[x]=null,a(createIterResult(void 0,!0))),i[j]=!0})),s.on("readable",onReadable.bind(null,i)),i}},3110:(s,o,i)=>{const a=i(5187),u=i(85015),_=i(98023),w=i(53812),x=i(23805),C=i(85105),j=i(86804);class Namespace{constructor(s){this.elementMap={},this.elementDetection=[],this.Element=j.Element,this.KeyValuePair=j.KeyValuePair,s&&s.noDefault||this.useDefault(),this._attributeElementKeys=[],this._attributeElementArrayKeys=[]}use(s){return s.namespace&&s.namespace({base:this}),s.load&&s.load({base:this}),this}useDefault(){return 
this.register("null",j.NullElement).register("string",j.StringElement).register("number",j.NumberElement).register("boolean",j.BooleanElement).register("array",j.ArrayElement).register("object",j.ObjectElement).register("member",j.MemberElement).register("ref",j.RefElement).register("link",j.LinkElement),this.detect(a,j.NullElement,!1).detect(u,j.StringElement,!1).detect(_,j.NumberElement,!1).detect(w,j.BooleanElement,!1).detect(Array.isArray,j.ArrayElement,!1).detect(x,j.ObjectElement,!1),this}register(s,o){return this._elements=void 0,this.elementMap[s]=o,this}unregister(s){return this._elements=void 0,delete this.elementMap[s],this}detect(s,o,i){return void 0===i||i?this.elementDetection.unshift([s,o]):this.elementDetection.push([s,o]),this}toElement(s){if(s instanceof this.Element)return s;let o;for(let i=0;i{const o=s[0].toUpperCase()+s.substr(1);this._elements[o]=this.elementMap[s]}))),this._elements}get serialiser(){return new C(this)}}C.prototype.Namespace=Namespace,s.exports=Namespace},3121:(s,o,i)=>{"use strict";var a=i(65482),u=Math.min;s.exports=function(s){var o=a(s);return o>0?u(o,9007199254740991):0}},3209:(s,o,i)=>{var a=i(91596),u=i(53320),_=i(36306),w="__lodash_placeholder__",x=128,C=Math.min;s.exports=function mergeData(s,o){var i=s[1],j=o[1],L=i|j,B=L<131,$=j==x&&8==i||j==x&&256==i&&s[7].length<=o[8]||384==j&&o[7].length<=o[8]&&8==i;if(!B&&!$)return s;1&j&&(s[2]=o[2],L|=1&i?0:4);var U=o[3];if(U){var V=s[3];s[3]=V?a(V,U,o[4]):U,s[4]=V?_(s[3],w):o[4]}return(U=o[5])&&(V=s[5],s[5]=V?u(V,U,o[6]):U,s[6]=V?_(s[5],w):o[6]),(U=o[7])&&(s[7]=U),j&x&&(s[8]=null==s[8]?o[8]:C(s[8],o[8])),null==s[9]&&(s[9]=o[9]),s[0]=o[0],s[1]=L,s}},3650:(s,o,i)=>{var a=i(74335)(Object.keys,Object);s.exports=a},3656:(s,o,i)=>{s=i.nmd(s);var a=i(9325),u=i(89935),_=o&&!o.nodeType&&o,w=_&&s&&!s.nodeType&&s,x=w&&w.exports===_?a.Buffer:void 0,C=(x?x.isBuffer:void 0)||u;s.exports=C},4509:(s,o,i)=>{var a=i(12651);s.exports=function mapCacheHas(s){return 
a(this,s).has(s)}},4640:s=>{"use strict";var o=String;s.exports=function(s){try{return o(s)}catch(s){return"Object"}}},4664:(s,o,i)=>{var a=i(79770),u=i(63345),_=Object.prototype.propertyIsEnumerable,w=Object.getOwnPropertySymbols,x=w?function(s){return null==s?[]:(s=Object(s),a(w(s),(function(o){return _.call(s,o)})))}:u;s.exports=x},4901:(s,o,i)=>{var a=i(72552),u=i(30294),_=i(40346),w={};w["[object Float32Array]"]=w["[object Float64Array]"]=w["[object Int8Array]"]=w["[object Int16Array]"]=w["[object Int32Array]"]=w["[object Uint8Array]"]=w["[object Uint8ClampedArray]"]=w["[object Uint16Array]"]=w["[object Uint32Array]"]=!0,w["[object Arguments]"]=w["[object Array]"]=w["[object ArrayBuffer]"]=w["[object Boolean]"]=w["[object DataView]"]=w["[object Date]"]=w["[object Error]"]=w["[object Function]"]=w["[object Map]"]=w["[object Number]"]=w["[object Object]"]=w["[object RegExp]"]=w["[object Set]"]=w["[object String]"]=w["[object WeakMap]"]=!1,s.exports=function baseIsTypedArray(s){return _(s)&&u(s.length)&&!!w[a(s)]}},4993:(s,o,i)=>{"use strict";var a=i(16946),u=i(74239);s.exports=function(s){return a(u(s))}},5187:s=>{s.exports=function isNull(s){return null===s}},5419:s=>{s.exports=function(s,o,i,a){var u=new Blob(void 0!==a?[a,s]:[s],{type:i||"application/octet-stream"});if(void 0!==window.navigator.msSaveBlob)window.navigator.msSaveBlob(u,o);else{var _=window.URL&&window.URL.createObjectURL?window.URL.createObjectURL(u):window.webkitURL.createObjectURL(u),w=document.createElement("a");w.style.display="none",w.href=_,w.setAttribute("download",o),void 0===w.download&&w.setAttribute("target","_blank"),document.body.appendChild(w),w.click(),setTimeout((function(){document.body.removeChild(w),window.URL.revokeObjectURL(_)}),200)}}},5556:(s,o,i)=>{s.exports=i(2694)()},5861:(s,o,i)=>{var a=i(55580),u=i(68223),_=i(32804),w=i(76545),x=i(28303),C=i(72552),j=i(47473),L="[object Map]",B="[object Promise]",$="[object Set]",U="[object WeakMap]",V="[object 
DataView]",z=j(a),Y=j(u),Z=j(_),ee=j(w),ie=j(x),ae=C;(a&&ae(new a(new ArrayBuffer(1)))!=V||u&&ae(new u)!=L||_&&ae(_.resolve())!=B||w&&ae(new w)!=$||x&&ae(new x)!=U)&&(ae=function(s){var o=C(s),i="[object Object]"==o?s.constructor:void 0,a=i?j(i):"";if(a)switch(a){case z:return V;case Y:return L;case Z:return B;case ee:return $;case ie:return U}return o}),s.exports=ae},6048:s=>{s.exports=function negate(s){if("function"!=typeof s)throw new TypeError("Expected a function");return function(){var o=arguments;switch(o.length){case 0:return!s.call(this);case 1:return!s.call(this,o[0]);case 2:return!s.call(this,o[0],o[1]);case 3:return!s.call(this,o[0],o[1],o[2])}return!s.apply(this,o)}}},6188:s=>{"use strict";s.exports=Math.max},6205:s=>{s.exports={ROOT:0,GROUP:1,POSITION:2,SET:3,RANGE:4,REPETITION:5,REFERENCE:6,CHAR:7}},6233:(s,o,i)=>{const a=i(6048),u=i(10316),_=i(92340);class ArrayElement extends u{constructor(s,o,i){super(s||[],o,i),this.element="array"}primitive(){return"array"}get(s){return this.content[s]}getValue(s){const o=this.get(s);if(o)return o.toValue()}getIndex(s){return this.content[s]}set(s,o){return this.content[s]=this.refract(o),this}remove(s){const o=this.content.splice(s,1);return o.length?o[0]:null}map(s,o){return this.content.map(s,o)}flatMap(s,o){return this.map(s,o).reduce(((s,o)=>s.concat(o)),[])}compactMap(s,o){const i=[];return this.forEach((a=>{const u=s.bind(o)(a);u&&i.push(u)})),i}filter(s,o){return new _(this.content.filter(s,o))}reject(s,o){return this.filter(a(s),o)}reduce(s,o){let i,a;void 0!==o?(i=0,a=this.refract(o)):(i=1,a="object"===this.primitive()?this.first.value:this.first);for(let o=i;o{s.bind(o)(i,this.refract(a))}))}shift(){return this.content.shift()}unshift(s){this.content.unshift(this.refract(s))}push(s){return this.content.push(this.refract(s)),this}add(s){this.push(s)}findElements(s,o){const i=o||{},a=!!i.recursive,u=void 0===i.results?[]:i.results;return this.forEach(((o,i,_)=>{a&&void 
0!==o.findElements&&o.findElements(s,{results:u,recursive:a}),s(o,i,_)&&u.push(o)})),u}find(s){return new _(this.findElements(s,{recursive:!0}))}findByElement(s){return this.find((o=>o.element===s))}findByClass(s){return this.find((o=>o.classes.includes(s)))}getById(s){return this.find((o=>o.id.toValue()===s)).first}includes(s){return this.content.some((o=>o.equals(s)))}contains(s){return this.includes(s)}empty(){return new this.constructor([])}"fantasy-land/empty"(){return this.empty()}concat(s){return new this.constructor(this.content.concat(s.content))}"fantasy-land/concat"(s){return this.concat(s)}"fantasy-land/map"(s){return new this.constructor(this.map(s))}"fantasy-land/chain"(s){return this.map((o=>s(o)),this).reduce(((s,o)=>s.concat(o)),this.empty())}"fantasy-land/filter"(s){return new this.constructor(this.content.filter(s))}"fantasy-land/reduce"(s,o){return this.content.reduce(s,o)}get length(){return this.content.length}get isEmpty(){return 0===this.content.length}get first(){return this.getIndex(0)}get second(){return this.getIndex(1)}get last(){return this.getIndex(this.length-1)}}ArrayElement.empty=function empty(){return new this},ArrayElement["fantasy-land/empty"]=ArrayElement.empty,"undefined"!=typeof Symbol&&(ArrayElement.prototype[Symbol.iterator]=function symbol(){return this.content[Symbol.iterator]()}),s.exports=ArrayElement},6499:(s,o,i)=>{"use strict";var a=i(1907),u=0,_=Math.random(),w=a(1..toString);s.exports=function(s){return"Symbol("+(void 0===s?"":s)+")_"+w(++u+_,36)}},6549:s=>{"use strict";s.exports=Object.getOwnPropertyDescriptor},6925:s=>{"use strict";s.exports="SECRET_DO_NOT_PASS_THIS_OR_YOU_WILL_BE_FIRED"},7057:(s,o,i)=>{"use strict";var a=i(11470).charAt,u=i(90160),_=i(64932),w=i(60183),x=i(59550),C="String Iterator",j=_.set,L=_.getterFor(C);w(String,"String",(function(s){j(this,{type:C,string:u(s),index:0})}),(function next(){var s,o=L(this),i=o.string,u=o.index;return u>=i.length?x(void 
0,!0):(s=a(i,u),o.index+=s.length,x(s,!1))}))},7176:(s,o,i)=>{"use strict";var a,u=i(73126),_=i(75795);try{a=[].__proto__===Array.prototype}catch(s){if(!s||"object"!=typeof s||!("code"in s)||"ERR_PROTO_ACCESS"!==s.code)throw s}var w=!!a&&_&&_(Object.prototype,"__proto__"),x=Object,C=x.getPrototypeOf;s.exports=w&&"function"==typeof w.get?u([w.get]):"function"==typeof C&&function getDunder(s){return C(null==s?s:x(s))}},7309:(s,o,i)=>{var a=i(62006)(i(24713));s.exports=a},7376:s=>{"use strict";s.exports=!0},7463:(s,o,i)=>{"use strict";var a=i(98828),u=i(62250),_=/#|\.prototype\./,isForced=function(s,o){var i=x[w(s)];return i===j||i!==C&&(u(o)?a(o):!!o)},w=isForced.normalize=function(s){return String(s).replace(_,".").toLowerCase()},x=isForced.data={},C=isForced.NATIVE="N",j=isForced.POLYFILL="P";s.exports=isForced},7666:(s,o,i)=>{var a=i(84851),u=i(953);function _extends(){var o;return s.exports=_extends=a?u(o=a).call(o):function(s){for(var o=1;o{const a=i(6205);o.wordBoundary=()=>({type:a.POSITION,value:"b"}),o.nonWordBoundary=()=>({type:a.POSITION,value:"B"}),o.begin=()=>({type:a.POSITION,value:"^"}),o.end=()=>({type:a.POSITION,value:"$"})},8068:s=>{"use strict";var o=(()=>{var s=Object.defineProperty,o=Object.getOwnPropertyDescriptor,i=Object.getOwnPropertyNames,a=Object.getOwnPropertySymbols,u=Object.prototype.hasOwnProperty,_=Object.prototype.propertyIsEnumerable,__defNormalProp=(o,i,a)=>i in o?s(o,i,{enumerable:!0,configurable:!0,writable:!0,value:a}):o[i]=a,__spreadValues=(s,o)=>{for(var i in o||(o={}))u.call(o,i)&&__defNormalProp(s,i,o[i]);if(a)for(var i of a(o))_.call(o,i)&&__defNormalProp(s,i,o[i]);return s},__publicField=(s,o,i)=>__defNormalProp(s,"symbol"!=typeof o?o+"":o,i),w={};((o,i)=>{for(var a in i)s(o,a,{get:i[a],enumerable:!0})})(w,{DEFAULT_OPTIONS:()=>C,DEFAULT_UUID_LENGTH:()=>x,default:()=>B});var x=6,C={dictionary:"alphanum",shuffle:!0,debug:!1,length:x,counter:0},j=class 
_ShortUniqueId{constructor(s={}){__publicField(this,"counter"),__publicField(this,"debug"),__publicField(this,"dict"),__publicField(this,"version"),__publicField(this,"dictIndex",0),__publicField(this,"dictRange",[]),__publicField(this,"lowerBound",0),__publicField(this,"upperBound",0),__publicField(this,"dictLength",0),__publicField(this,"uuidLength"),__publicField(this,"_digit_first_ascii",48),__publicField(this,"_digit_last_ascii",58),__publicField(this,"_alpha_lower_first_ascii",97),__publicField(this,"_alpha_lower_last_ascii",123),__publicField(this,"_hex_last_ascii",103),__publicField(this,"_alpha_upper_first_ascii",65),__publicField(this,"_alpha_upper_last_ascii",91),__publicField(this,"_number_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii]}),__publicField(this,"_alpha_dict_ranges",{lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii],upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_alpha_lower_dict_ranges",{lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii]}),__publicField(this,"_alpha_upper_dict_ranges",{upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_alphanum_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii],lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii],upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_alphanum_lower_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii],lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii]}),__publicField(this,"_alphanum_upper_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii],upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_hex_dict_ranges",{decDigits:[this._digit_first_ascii,this._digit_last_ascii],alphaDigits:[this._alpha_lower_first_ascii,this._hex_last_ascii]}),__publicField(this,"_dict
_ranges",{_number_dict_ranges:this._number_dict_ranges,_alpha_dict_ranges:this._alpha_dict_ranges,_alpha_lower_dict_ranges:this._alpha_lower_dict_ranges,_alpha_upper_dict_ranges:this._alpha_upper_dict_ranges,_alphanum_dict_ranges:this._alphanum_dict_ranges,_alphanum_lower_dict_ranges:this._alphanum_lower_dict_ranges,_alphanum_upper_dict_ranges:this._alphanum_upper_dict_ranges,_hex_dict_ranges:this._hex_dict_ranges}),__publicField(this,"log",((...s)=>{const o=[...s];o[0]="[short-unique-id] ".concat(s[0]),!0!==this.debug||"undefined"==typeof console||null===console||console.log(...o)})),__publicField(this,"_normalizeDictionary",((s,o)=>{let i;if(s&&Array.isArray(s)&&s.length>1)i=s;else{i=[],this.dictIndex=0;const o="_".concat(s,"_dict_ranges"),a=this._dict_ranges[o];let u=0;for(const[,s]of Object.entries(a)){const[o,i]=s;u+=Math.abs(i-o)}i=new Array(u);let _=0;for(const[,s]of Object.entries(a)){this.dictRange=s,this.lowerBound=this.dictRange[0],this.upperBound=this.dictRange[1];const o=this.lowerBound<=this.upperBound,a=this.lowerBound,u=this.upperBound;if(o)for(let s=a;su;s--)i[_++]=String.fromCharCode(s),this.dictIndex=s}i.length=_}if(o){for(let s=i.length-1;s>0;s--){const o=Math.floor(Math.random()*(s+1));[i[s],i[o]]=[i[o],i[s]]}}return i})),__publicField(this,"setDictionary",((s,o)=>{this.dict=this._normalizeDictionary(s,o),this.dictLength=this.dict.length,this.setCounter(0)})),__publicField(this,"seq",(()=>this.sequentialUUID())),__publicField(this,"sequentialUUID",(()=>{const s=this.dictLength,o=this.dict;let i=this.counter;const a=[];do{const u=i%s;i=Math.trunc(i/s),a.push(o[u])}while(0!==i);const u=a.join("");return this.counter+=1,u})),__publicField(this,"rnd",((s=this.uuidLength||x)=>this.randomUUID(s))),__publicField(this,"randomUUID",((s=this.uuidLength||x)=>{if(null==s||s<1)throw new Error("Invalid UUID Length Provided");const o=new Array(s),i=this.dictLength,a=this.dict;for(let 
u=0;uthis.formattedUUID(s,o))),__publicField(this,"formattedUUID",((s,o)=>{const i={$r:this.randomUUID,$s:this.sequentialUUID,$t:this.stamp};return s.replace(/\$[rs]\d{0,}|\$t0|\$t[1-9]\d{1,}/g,(s=>{const a=s.slice(0,2),u=Number.parseInt(s.slice(2),10);return"$s"===a?i[a]().padStart(u,"0"):"$t"===a&&o?i[a](u,o):i[a](u)}))})),__publicField(this,"availableUUIDs",((s=this.uuidLength)=>Number.parseFloat(([...new Set(this.dict)].length**s).toFixed(0)))),__publicField(this,"_collisionCache",new Map),__publicField(this,"approxMaxBeforeCollision",((s=this.availableUUIDs(this.uuidLength))=>{const o=s,i=this._collisionCache.get(o);if(void 0!==i)return i;const a=Number.parseFloat(Math.sqrt(Math.PI/2*s).toFixed(20));return this._collisionCache.set(o,a),a})),__publicField(this,"collisionProbability",((s=this.availableUUIDs(this.uuidLength),o=this.uuidLength)=>Number.parseFloat((this.approxMaxBeforeCollision(s)/this.availableUUIDs(o)).toFixed(20)))),__publicField(this,"uniqueness",((s=this.availableUUIDs(this.uuidLength))=>{const o=Number.parseFloat((1-this.approxMaxBeforeCollision(s)/s).toFixed(20));return o>1?1:o<0?0:o})),__publicField(this,"getVersion",(()=>this.version)),__publicField(this,"stamp",((s,o)=>{const i=Math.floor(+(o||new Date)/1e3).toString(16);if("number"==typeof s&&0===s)return i;if("number"!=typeof s||s<10)throw new Error(["Param finalLength must be a number greater than or equal to 10,","or 0 if you want the raw hexadecimal timestamp"].join("\n"));const a=s-9,u=Math.round(Math.random()*(a>15?15:a)),_=this.randomUUID(a);return"".concat(_.substring(0,u)).concat(i).concat(_.substring(u)).concat(u.toString(16))})),__publicField(this,"parseStamp",((s,o)=>{if(o&&!/t0|t[1-9]\d{1,}/.test(o))throw new Error("Cannot extract date from a formated UUID with no timestamp in the format");const i=o?o.replace(/\$[rs]\d{0,}|\$t0|\$t[1-9]\d{1,}/g,(s=>{const 
o={$r:s=>[...Array(s)].map((()=>"r")).join(""),$s:s=>[...Array(s)].map((()=>"s")).join(""),$t:s=>[...Array(s)].map((()=>"t")).join("")},i=s.slice(0,2),a=Number.parseInt(s.slice(2),10);return o[i](a)})).replace(/^(.*?)(t{8,})(.*)$/g,((o,i,a)=>s.substring(i.length,i.length+a.length))):s;if(8===i.length)return new Date(1e3*Number.parseInt(i,16));if(i.length<10)throw new Error("Stamp length invalid");const a=Number.parseInt(i.substring(i.length-1),16);return new Date(1e3*Number.parseInt(i.substring(a,a+8),16))})),__publicField(this,"setCounter",(s=>{this.counter=s})),__publicField(this,"validate",((s,o)=>{const i=o?this._normalizeDictionary(o):this.dict;return s.split("").every((s=>i.includes(s)))}));const o=__spreadValues(__spreadValues({},C),s);this.counter=0,this.debug=!1,this.dict=[],this.version="5.3.2";const{dictionary:i,shuffle:a,length:u,counter:_}=o;this.uuidLength=u,this.setDictionary(i,a),this.setCounter(_),this.debug=o.debug,this.log(this.dict),this.log("Generator instantiated with Dictionary Size ".concat(this.dictLength," and counter set to ").concat(this.counter)),this.log=this.log.bind(this),this.setDictionary=this.setDictionary.bind(this),this.setCounter=this.setCounter.bind(this),this.seq=this.seq.bind(this),this.sequentialUUID=this.sequentialUUID.bind(this),this.rnd=this.rnd.bind(this),this.randomUUID=this.randomUUID.bind(this),this.fmt=this.fmt.bind(this),this.formattedUUID=this.formattedUUID.bind(this),this.availableUUIDs=this.availableUUIDs.bind(this),this.approxMaxBeforeCollision=this.approxMaxBeforeCollision.bind(this),this.collisionProbability=this.collisionProbability.bind(this),this.uniqueness=this.uniqueness.bind(this),this.getVersion=this.getVersion.bind(this),this.stamp=this.stamp.bind(this),this.parseStamp=this.parseStamp.bind(this)}};__publicField(j,"default",j);var L,B=j;return L=w,((a,_,w,x)=>{if(_&&"object"==typeof _||"function"==typeof _)for(let C of 
i(_))u.call(a,C)||C===w||s(a,C,{get:()=>_[C],enumerable:!(x=o(_,C))||x.enumerable});return a})(s({},"__esModule",{value:!0}),L)})();s.exports=o.default,"undefined"!=typeof window&&(o=o.default)},9325:(s,o,i)=>{var a=i(34840),u="object"==typeof self&&self&&self.Object===Object&&self,_=a||u||Function("return this")();s.exports=_},9404:function(s){s.exports=function(){"use strict";var s=Array.prototype.slice;function createClass(s,o){o&&(s.prototype=Object.create(o.prototype)),s.prototype.constructor=s}function Iterable(s){return isIterable(s)?s:Seq(s)}function KeyedIterable(s){return isKeyed(s)?s:KeyedSeq(s)}function IndexedIterable(s){return isIndexed(s)?s:IndexedSeq(s)}function SetIterable(s){return isIterable(s)&&!isAssociative(s)?s:SetSeq(s)}function isIterable(s){return!(!s||!s[o])}function isKeyed(s){return!(!s||!s[i])}function isIndexed(s){return!(!s||!s[a])}function isAssociative(s){return isKeyed(s)||isIndexed(s)}function isOrdered(s){return!(!s||!s[u])}createClass(KeyedIterable,Iterable),createClass(IndexedIterable,Iterable),createClass(SetIterable,Iterable),Iterable.isIterable=isIterable,Iterable.isKeyed=isKeyed,Iterable.isIndexed=isIndexed,Iterable.isAssociative=isAssociative,Iterable.isOrdered=isOrdered,Iterable.Keyed=KeyedIterable,Iterable.Indexed=IndexedIterable,Iterable.Set=SetIterable;var o="@@__IMMUTABLE_ITERABLE__@@",i="@@__IMMUTABLE_KEYED__@@",a="@@__IMMUTABLE_INDEXED__@@",u="@@__IMMUTABLE_ORDERED__@@",_="delete",w=5,x=1<>>0;if(""+i!==o||4294967295===i)return NaN;o=i}return o<0?ensureSize(s)+o:o}function returnTrue(){return!0}function wholeSlice(s,o,i){return(0===s||void 0!==i&&s<=-i)&&(void 0===o||void 0!==i&&o>=i)}function resolveBegin(s,o){return resolveIndex(s,o,0)}function resolveEnd(s,o){return resolveIndex(s,o,o)}function resolveIndex(s,o,i){return void 0===s?i:s<0?Math.max(0,o+s):void 0===o?s:Math.min(o,s)}var $=0,U=1,V=2,z="function"==typeof Symbol&&Symbol.iterator,Y="@@iterator",Z=z||Y;function Iterator(s){this.next=s}function 
iteratorValue(s,o,i,a){var u=0===s?o:1===s?i:[o,i];return a?a.value=u:a={value:u,done:!1},a}function iteratorDone(){return{value:void 0,done:!0}}function hasIterator(s){return!!getIteratorFn(s)}function isIterator(s){return s&&"function"==typeof s.next}function getIterator(s){var o=getIteratorFn(s);return o&&o.call(s)}function getIteratorFn(s){var o=s&&(z&&s[z]||s[Y]);if("function"==typeof o)return o}function isArrayLike(s){return s&&"number"==typeof s.length}function Seq(s){return null==s?emptySequence():isIterable(s)?s.toSeq():seqFromValue(s)}function KeyedSeq(s){return null==s?emptySequence().toKeyedSeq():isIterable(s)?isKeyed(s)?s.toSeq():s.fromEntrySeq():keyedSeqFromValue(s)}function IndexedSeq(s){return null==s?emptySequence():isIterable(s)?isKeyed(s)?s.entrySeq():s.toIndexedSeq():indexedSeqFromValue(s)}function SetSeq(s){return(null==s?emptySequence():isIterable(s)?isKeyed(s)?s.entrySeq():s:indexedSeqFromValue(s)).toSetSeq()}Iterator.prototype.toString=function(){return"[Iterator]"},Iterator.KEYS=$,Iterator.VALUES=U,Iterator.ENTRIES=V,Iterator.prototype.inspect=Iterator.prototype.toSource=function(){return this.toString()},Iterator.prototype[Z]=function(){return this},createClass(Seq,Iterable),Seq.of=function(){return Seq(arguments)},Seq.prototype.toSeq=function(){return this},Seq.prototype.toString=function(){return this.__toString("Seq {","}")},Seq.prototype.cacheResult=function(){return!this._cache&&this.__iterateUncached&&(this._cache=this.entrySeq().toArray(),this.size=this._cache.length),this},Seq.prototype.__iterate=function(s,o){return seqIterate(this,s,o,!0)},Seq.prototype.__iterator=function(s,o){return seqIterator(this,s,o,!0)},createClass(KeyedSeq,Seq),KeyedSeq.prototype.toKeyedSeq=function(){return this},createClass(IndexedSeq,Seq),IndexedSeq.of=function(){return IndexedSeq(arguments)},IndexedSeq.prototype.toIndexedSeq=function(){return this},IndexedSeq.prototype.toString=function(){return this.__toString("Seq 
[","]")},IndexedSeq.prototype.__iterate=function(s,o){return seqIterate(this,s,o,!1)},IndexedSeq.prototype.__iterator=function(s,o){return seqIterator(this,s,o,!1)},createClass(SetSeq,Seq),SetSeq.of=function(){return SetSeq(arguments)},SetSeq.prototype.toSetSeq=function(){return this},Seq.isSeq=isSeq,Seq.Keyed=KeyedSeq,Seq.Set=SetSeq,Seq.Indexed=IndexedSeq;var ee,ie,ae,ce="@@__IMMUTABLE_SEQ__@@";function ArraySeq(s){this._array=s,this.size=s.length}function ObjectSeq(s){var o=Object.keys(s);this._object=s,this._keys=o,this.size=o.length}function IterableSeq(s){this._iterable=s,this.size=s.length||s.size}function IteratorSeq(s){this._iterator=s,this._iteratorCache=[]}function isSeq(s){return!(!s||!s[ce])}function emptySequence(){return ee||(ee=new ArraySeq([]))}function keyedSeqFromValue(s){var o=Array.isArray(s)?new ArraySeq(s).fromEntrySeq():isIterator(s)?new IteratorSeq(s).fromEntrySeq():hasIterator(s)?new IterableSeq(s).fromEntrySeq():"object"==typeof s?new ObjectSeq(s):void 0;if(!o)throw new TypeError("Expected Array or iterable object of [k, v] entries, or keyed object: "+s);return o}function indexedSeqFromValue(s){var o=maybeIndexedSeqFromValue(s);if(!o)throw new TypeError("Expected Array or iterable object of values: "+s);return o}function seqFromValue(s){var o=maybeIndexedSeqFromValue(s)||"object"==typeof s&&new ObjectSeq(s);if(!o)throw new TypeError("Expected Array or iterable object of values, or keyed object: "+s);return o}function maybeIndexedSeqFromValue(s){return isArrayLike(s)?new ArraySeq(s):isIterator(s)?new IteratorSeq(s):hasIterator(s)?new IterableSeq(s):void 0}function seqIterate(s,o,i,a){var u=s._cache;if(u){for(var _=u.length-1,w=0;w<=_;w++){var x=u[i?_-w:w];if(!1===o(x[1],a?x[0]:w,s))return w+1}return w}return s.__iterateUncached(o,i)}function seqIterator(s,o,i,a){var u=s._cache;if(u){var _=u.length-1,w=0;return new Iterator((function(){var s=u[i?_-w:w];return w++>_?iteratorDone():iteratorValue(o,a?s[0]:w-1,s[1])}))}return 
s.__iteratorUncached(o,i)}function fromJS(s,o){return o?fromJSWith(o,s,"",{"":s}):fromJSDefault(s)}function fromJSWith(s,o,i,a){return Array.isArray(o)?s.call(a,i,IndexedSeq(o).map((function(i,a){return fromJSWith(s,i,a,o)}))):isPlainObj(o)?s.call(a,i,KeyedSeq(o).map((function(i,a){return fromJSWith(s,i,a,o)}))):o}function fromJSDefault(s){return Array.isArray(s)?IndexedSeq(s).map(fromJSDefault).toList():isPlainObj(s)?KeyedSeq(s).map(fromJSDefault).toMap():s}function isPlainObj(s){return s&&(s.constructor===Object||void 0===s.constructor)}function is(s,o){if(s===o||s!=s&&o!=o)return!0;if(!s||!o)return!1;if("function"==typeof s.valueOf&&"function"==typeof o.valueOf){if((s=s.valueOf())===(o=o.valueOf())||s!=s&&o!=o)return!0;if(!s||!o)return!1}return!("function"!=typeof s.equals||"function"!=typeof o.equals||!s.equals(o))}function deepEqual(s,o){if(s===o)return!0;if(!isIterable(o)||void 0!==s.size&&void 0!==o.size&&s.size!==o.size||void 0!==s.__hash&&void 0!==o.__hash&&s.__hash!==o.__hash||isKeyed(s)!==isKeyed(o)||isIndexed(s)!==isIndexed(o)||isOrdered(s)!==isOrdered(o))return!1;if(0===s.size&&0===o.size)return!0;var i=!isAssociative(s);if(isOrdered(s)){var a=s.entries();return o.every((function(s,o){var u=a.next().value;return u&&is(u[1],s)&&(i||is(u[0],o))}))&&a.next().done}var u=!1;if(void 0===s.size)if(void 0===o.size)"function"==typeof s.cacheResult&&s.cacheResult();else{u=!0;var _=s;s=o,o=_}var w=!0,x=o.__iterate((function(o,a){if(i?!s.has(o):u?!is(o,s.get(a,j)):!is(s.get(a,j),o))return w=!1,!1}));return w&&s.size===x}function Repeat(s,o){if(!(this instanceof Repeat))return new Repeat(s,o);if(this._value=s,this.size=void 0===o?1/0:Math.max(0,o),0===this.size){if(ie)return ie;ie=this}}function invariant(s,o){if(!s)throw new Error(o)}function Range(s,o,i){if(!(this instanceof Range))return new Range(s,o,i);if(invariant(0!==i,"Cannot step a Range by 0"),s=s||0,void 0===o&&(o=1/0),i=void 
0===i?1:Math.abs(i),oa?iteratorDone():iteratorValue(s,u,i[o?a-u++:u++])}))},createClass(ObjectSeq,KeyedSeq),ObjectSeq.prototype.get=function(s,o){return void 0===o||this.has(s)?this._object[s]:o},ObjectSeq.prototype.has=function(s){return this._object.hasOwnProperty(s)},ObjectSeq.prototype.__iterate=function(s,o){for(var i=this._object,a=this._keys,u=a.length-1,_=0;_<=u;_++){var w=a[o?u-_:_];if(!1===s(i[w],w,this))return _+1}return _},ObjectSeq.prototype.__iterator=function(s,o){var i=this._object,a=this._keys,u=a.length-1,_=0;return new Iterator((function(){var w=a[o?u-_:_];return _++>u?iteratorDone():iteratorValue(s,w,i[w])}))},ObjectSeq.prototype[u]=!0,createClass(IterableSeq,IndexedSeq),IterableSeq.prototype.__iterateUncached=function(s,o){if(o)return this.cacheResult().__iterate(s,o);var i=getIterator(this._iterable),a=0;if(isIterator(i))for(var u;!(u=i.next()).done&&!1!==s(u.value,a++,this););return a},IterableSeq.prototype.__iteratorUncached=function(s,o){if(o)return this.cacheResult().__iterator(s,o);var i=getIterator(this._iterable);if(!isIterator(i))return new Iterator(iteratorDone);var a=0;return new Iterator((function(){var o=i.next();return o.done?o:iteratorValue(s,a++,o.value)}))},createClass(IteratorSeq,IndexedSeq),IteratorSeq.prototype.__iterateUncached=function(s,o){if(o)return this.cacheResult().__iterate(s,o);for(var i,a=this._iterator,u=this._iteratorCache,_=0;_=a.length){var o=i.next();if(o.done)return o;a[u]=o.value}return iteratorValue(s,u,a[u++])}))},createClass(Repeat,IndexedSeq),Repeat.prototype.toString=function(){return 0===this.size?"Repeat []":"Repeat [ "+this._value+" "+this.size+" times ]"},Repeat.prototype.get=function(s,o){return this.has(s)?this._value:o},Repeat.prototype.includes=function(s){return is(this._value,s)},Repeat.prototype.slice=function(s,o){var i=this.size;return wholeSlice(s,o,i)?this:new Repeat(this._value,resolveEnd(o,i)-resolveBegin(s,i))},Repeat.prototype.reverse=function(){return 
this},Repeat.prototype.indexOf=function(s){return is(this._value,s)?0:-1},Repeat.prototype.lastIndexOf=function(s){return is(this._value,s)?this.size:-1},Repeat.prototype.__iterate=function(s,o){for(var i=0;i=0&&o=0&&ii?iteratorDone():iteratorValue(s,_++,w)}))},Range.prototype.equals=function(s){return s instanceof Range?this._start===s._start&&this._end===s._end&&this._step===s._step:deepEqual(this,s)},createClass(Collection,Iterable),createClass(KeyedCollection,Collection),createClass(IndexedCollection,Collection),createClass(SetCollection,Collection),Collection.Keyed=KeyedCollection,Collection.Indexed=IndexedCollection,Collection.Set=SetCollection;var le="function"==typeof Math.imul&&-2===Math.imul(4294967295,2)?Math.imul:function imul(s,o){var i=65535&(s|=0),a=65535&(o|=0);return i*a+((s>>>16)*a+i*(o>>>16)<<16>>>0)|0};function smi(s){return s>>>1&1073741824|3221225471&s}function hash(s){if(!1===s||null==s)return 0;if("function"==typeof s.valueOf&&(!1===(s=s.valueOf())||null==s))return 0;if(!0===s)return 1;var o=typeof s;if("number"===o){if(s!=s||s===1/0)return 0;var i=0|s;for(i!==s&&(i^=4294967295*s);s>4294967295;)i^=s/=4294967295;return smi(i)}if("string"===o)return s.length>Se?cachedHashString(s):hashString(s);if("function"==typeof s.hashCode)return s.hashCode();if("object"===o)return hashJSObj(s);if("function"==typeof s.toString)return hashString(s.toString());throw new Error("Value type "+o+" cannot be hashed.")}function cachedHashString(s){var o=Pe[s];return void 0===o&&(o=hashString(s),xe===we&&(xe=0,Pe={}),xe++,Pe[s]=o),o}function hashString(s){for(var o=0,i=0;i0)switch(s.nodeType){case 1:return s.uniqueID;case 9:return s.documentElement&&s.documentElement.uniqueID}}var fe,ye="function"==typeof WeakMap;ye&&(fe=new WeakMap);var be=0,_e="__immutablehash__";"function"==typeof Symbol&&(_e=Symbol(_e));var Se=16,we=255,xe=0,Pe={};function assertNotInfinite(s){invariant(s!==1/0,"Cannot perform this action with an infinite size.")}function Map(s){return 
null==s?emptyMap():isMap(s)&&!isOrdered(s)?s:emptyMap().withMutations((function(o){var i=KeyedIterable(s);assertNotInfinite(i.size),i.forEach((function(s,i){return o.set(i,s)}))}))}function isMap(s){return!(!s||!s[Re])}createClass(Map,KeyedCollection),Map.of=function(){var o=s.call(arguments,0);return emptyMap().withMutations((function(s){for(var i=0;i=o.length)throw new Error("Missing value for key: "+o[i]);s.set(o[i],o[i+1])}}))},Map.prototype.toString=function(){return this.__toString("Map {","}")},Map.prototype.get=function(s,o){return this._root?this._root.get(0,void 0,s,o):o},Map.prototype.set=function(s,o){return updateMap(this,s,o)},Map.prototype.setIn=function(s,o){return this.updateIn(s,j,(function(){return o}))},Map.prototype.remove=function(s){return updateMap(this,s,j)},Map.prototype.deleteIn=function(s){return this.updateIn(s,(function(){return j}))},Map.prototype.update=function(s,o,i){return 1===arguments.length?s(this):this.updateIn([s],o,i)},Map.prototype.updateIn=function(s,o,i){i||(i=o,o=void 0);var a=updateInDeepMap(this,forceIterator(s),o,i);return a===j?void 0:a},Map.prototype.clear=function(){return 0===this.size?this:this.__ownerID?(this.size=0,this._root=null,this.__hash=void 0,this.__altered=!0,this):emptyMap()},Map.prototype.merge=function(){return mergeIntoMapWith(this,void 0,arguments)},Map.prototype.mergeWith=function(o){return mergeIntoMapWith(this,o,s.call(arguments,1))},Map.prototype.mergeIn=function(o){var i=s.call(arguments,1);return this.updateIn(o,emptyMap(),(function(s){return"function"==typeof s.merge?s.merge.apply(s,i):i[i.length-1]}))},Map.prototype.mergeDeep=function(){return mergeIntoMapWith(this,deepMerger,arguments)},Map.prototype.mergeDeepWith=function(o){var i=s.call(arguments,1);return mergeIntoMapWith(this,deepMergerWith(o),i)},Map.prototype.mergeDeepIn=function(o){var i=s.call(arguments,1);return this.updateIn(o,emptyMap(),(function(s){return"function"==typeof 
s.mergeDeep?s.mergeDeep.apply(s,i):i[i.length-1]}))},Map.prototype.sort=function(s){return OrderedMap(sortFactory(this,s))},Map.prototype.sortBy=function(s,o){return OrderedMap(sortFactory(this,o,s))},Map.prototype.withMutations=function(s){var o=this.asMutable();return s(o),o.wasAltered()?o.__ensureOwner(this.__ownerID):this},Map.prototype.asMutable=function(){return this.__ownerID?this:this.__ensureOwner(new OwnerID)},Map.prototype.asImmutable=function(){return this.__ensureOwner()},Map.prototype.wasAltered=function(){return this.__altered},Map.prototype.__iterator=function(s,o){return new MapIterator(this,s,o)},Map.prototype.__iterate=function(s,o){var i=this,a=0;return this._root&&this._root.iterate((function(o){return a++,s(o[1],o[0],i)}),o),a},Map.prototype.__ensureOwner=function(s){return s===this.__ownerID?this:s?makeMap(this.size,this._root,s,this.__hash):(this.__ownerID=s,this.__altered=!1,this)},Map.isMap=isMap;var Te,Re="@@__IMMUTABLE_MAP__@@",$e=Map.prototype;function ArrayMapNode(s,o){this.ownerID=s,this.entries=o}function BitmapIndexedNode(s,o,i){this.ownerID=s,this.bitmap=o,this.nodes=i}function HashArrayMapNode(s,o,i){this.ownerID=s,this.count=o,this.nodes=i}function HashCollisionNode(s,o,i){this.ownerID=s,this.keyHash=o,this.entries=i}function ValueNode(s,o,i){this.ownerID=s,this.keyHash=o,this.entry=i}function MapIterator(s,o,i){this._type=o,this._reverse=i,this._stack=s._root&&mapIteratorFrame(s._root)}function mapIteratorValue(s,o){return iteratorValue(s,o[0],o[1])}function mapIteratorFrame(s,o){return{node:s,index:0,__prev:o}}function makeMap(s,o,i,a){var u=Object.create($e);return u.size=s,u._root=o,u.__ownerID=i,u.__hash=a,u.__altered=!1,u}function emptyMap(){return Te||(Te=makeMap(0))}function updateMap(s,o,i){var a,u;if(s._root){var _=MakeRef(L),w=MakeRef(B);if(a=updateNode(s._root,s.__ownerID,0,void 0,o,i,_,w),!w.value)return s;u=s.size+(_.value?i===j?-1:1:0)}else{if(i===j)return s;u=1,a=new ArrayMapNode(s.__ownerID,[[o,i]])}return 
s.__ownerID?(s.size=u,s._root=a,s.__hash=void 0,s.__altered=!0,s):a?makeMap(u,a):emptyMap()}function updateNode(s,o,i,a,u,_,w,x){return s?s.update(o,i,a,u,_,w,x):_===j?s:(SetRef(x),SetRef(w),new ValueNode(o,a,[u,_]))}function isLeafNode(s){return s.constructor===ValueNode||s.constructor===HashCollisionNode}function mergeIntoNode(s,o,i,a,u){if(s.keyHash===a)return new HashCollisionNode(o,a,[s.entry,u]);var _,x=(0===i?s.keyHash:s.keyHash>>>i)&C,j=(0===i?a:a>>>i)&C;return new BitmapIndexedNode(o,1<>>=1)w[C]=1&i?o[_++]:void 0;return w[a]=u,new HashArrayMapNode(s,_+1,w)}function mergeIntoMapWith(s,o,i){for(var a=[],u=0;u>1&1431655765))+(s>>2&858993459))+(s>>4)&252645135,s+=s>>8,127&(s+=s>>16)}function setIn(s,o,i,a){var u=a?s:arrCopy(s);return u[o]=i,u}function spliceIn(s,o,i,a){var u=s.length+1;if(a&&o+1===u)return s[o]=i,s;for(var _=new Array(u),w=0,x=0;x=qe)return createNodes(s,C,a,u);var U=s&&s===this.ownerID,V=U?C:arrCopy(C);return $?x?L===B-1?V.pop():V[L]=V.pop():V[L]=[a,u]:V.push([a,u]),U?(this.entries=V,this):new ArrayMapNode(s,V)}},BitmapIndexedNode.prototype.get=function(s,o,i,a){void 0===o&&(o=hash(i));var u=1<<((0===s?o:o>>>s)&C),_=this.bitmap;return _&u?this.nodes[popCount(_&u-1)].get(s+w,o,i,a):a},BitmapIndexedNode.prototype.update=function(s,o,i,a,u,_,x){void 0===i&&(i=hash(a));var L=(0===o?i:i>>>o)&C,B=1<=ze)return expandNodes(s,z,$,L,Z);if(U&&!Z&&2===z.length&&isLeafNode(z[1^V]))return z[1^V];if(U&&Z&&1===z.length&&isLeafNode(Z))return Z;var ee=s&&s===this.ownerID,ie=U?Z?$:$^B:$|B,ae=U?Z?setIn(z,V,Z,ee):spliceOut(z,V,ee):spliceIn(z,V,Z,ee);return ee?(this.bitmap=ie,this.nodes=ae,this):new BitmapIndexedNode(s,ie,ae)},HashArrayMapNode.prototype.get=function(s,o,i,a){void 0===o&&(o=hash(i));var u=(0===s?o:o>>>s)&C,_=this.nodes[u];return _?_.get(s+w,o,i,a):a},HashArrayMapNode.prototype.update=function(s,o,i,a,u,_,x){void 0===i&&(i=hash(a));var L=(0===o?i:i>>>o)&C,B=u===j,$=this.nodes,U=$[L];if(B&&!U)return this;var 
V=updateNode(U,s,o+w,i,a,u,_,x);if(V===U)return this;var z=this.count;if(U){if(!V&&--z0&&a=0&&s>>o&C;if(a>=this.array.length)return new VNode([],s);var u,_=0===a;if(o>0){var x=this.array[a];if((u=x&&x.removeBefore(s,o-w,i))===x&&_)return this}if(_&&!u)return this;var j=editableVNode(this,s);if(!_)for(var L=0;L>>o&C;if(u>=this.array.length)return this;if(o>0){var _=this.array[u];if((a=_&&_.removeAfter(s,o-w,i))===_&&u===this.array.length-1)return this}var x=editableVNode(this,s);return x.array.splice(u+1),a&&(x.array[u]=a),x};var Xe,Qe,et={};function iterateList(s,o){var i=s._origin,a=s._capacity,u=getTailOffset(a),_=s._tail;return iterateNodeOrLeaf(s._root,s._level,0);function iterateNodeOrLeaf(s,o,i){return 0===o?iterateLeaf(s,i):iterateNode(s,o,i)}function iterateLeaf(s,w){var C=w===u?_&&_.array:s&&s.array,j=w>i?0:i-w,L=a-w;return L>x&&(L=x),function(){if(j===L)return et;var s=o?--L:j++;return C&&C[s]}}function iterateNode(s,u,_){var C,j=s&&s.array,L=_>i?0:i-_>>u,B=1+(a-_>>u);return B>x&&(B=x),function(){for(;;){if(C){var s=C();if(s!==et)return s;C=null}if(L===B)return et;var i=o?--B:L++;C=iterateNodeOrLeaf(j&&j[i],u-w,_+(i<=s.size||o<0)return s.withMutations((function(s){o<0?setListBounds(s,o).set(0,i):setListBounds(s,0,o+1).set(o,i)}));o+=s._origin;var a=s._tail,u=s._root,_=MakeRef(B);return o>=getTailOffset(s._capacity)?a=updateVNode(a,s.__ownerID,0,o,i,_):u=updateVNode(u,s.__ownerID,s._level,o,i,_),_.value?s.__ownerID?(s._root=u,s._tail=a,s.__hash=void 0,s.__altered=!0,s):makeList(s._origin,s._capacity,s._level,u,a):s}function updateVNode(s,o,i,a,u,_){var x,j=a>>>i&C,L=s&&j0){var B=s&&s.array[j],$=updateVNode(B,o,i-w,a,u,_);return $===B?s:((x=editableVNode(s,o)).array[j]=$,x)}return L&&s.array[j]===u?s:(SetRef(_),x=editableVNode(s,o),void 0===u&&j===x.array.length-1?x.array.pop():x.array[j]=u,x)}function editableVNode(s,o){return o&&s&&o===s.ownerID?s:new VNode(s?s.array.slice():[],o)}function listNodeFor(s,o){if(o>=getTailOffset(s._capacity))return 
s._tail;if(o<1<0;)i=i.array[o>>>a&C],a-=w;return i}}function setListBounds(s,o,i){void 0!==o&&(o|=0),void 0!==i&&(i|=0);var a=s.__ownerID||new OwnerID,u=s._origin,_=s._capacity,x=u+o,j=void 0===i?_:i<0?_+i:u+i;if(x===u&&j===_)return s;if(x>=j)return s.clear();for(var L=s._level,B=s._root,$=0;x+$<0;)B=new VNode(B&&B.array.length?[void 0,B]:[],a),$+=1<<(L+=w);$&&(x+=$,u+=$,j+=$,_+=$);for(var U=getTailOffset(_),V=getTailOffset(j);V>=1<U?new VNode([],a):z;if(z&&V>U&&x<_&&z.array.length){for(var Z=B=editableVNode(B,a),ee=L;ee>w;ee-=w){var ie=U>>>ee&C;Z=Z.array[ie]=editableVNode(Z.array[ie],a)}Z.array[U>>>w&C]=z}if(j<_&&(Y=Y&&Y.removeAfter(a,0,j)),x>=V)x-=V,j-=V,L=w,B=null,Y=Y&&Y.removeBefore(a,0,x);else if(x>u||V>>L&C;if(ae!==V>>>L&C)break;ae&&($+=(1<u&&(B=B.removeBefore(a,L,x-$)),B&&Vu&&(u=x.size),isIterable(w)||(x=x.map((function(s){return fromJS(s)}))),a.push(x)}return u>s.size&&(s=s.setSize(u)),mergeIntoCollectionWith(s,o,a)}function getTailOffset(s){return s>>w<=x&&w.size>=2*_.size?(a=(u=w.filter((function(s,o){return void 0!==s&&C!==o}))).toKeyedSeq().map((function(s){return s[0]})).flip().toMap(),s.__ownerID&&(a.__ownerID=u.__ownerID=s.__ownerID)):(a=_.remove(o),u=C===w.size-1?w.pop():w.set(C,void 0))}else if(L){if(i===w.get(C)[1])return s;a=_,u=w.set(C,[o,i])}else a=_.set(o,w.size),u=w.set(w.size,[o,i]);return s.__ownerID?(s.size=a.size,s._map=a,s._list=u,s.__hash=void 0,s):makeOrderedMap(a,u)}function ToKeyedSequence(s,o){this._iter=s,this._useKeys=o,this.size=s.size}function ToIndexedSequence(s){this._iter=s,this.size=s.size}function ToSetSequence(s){this._iter=s,this.size=s.size}function FromEntriesSequence(s){this._iter=s,this.size=s.size}function flipFactory(s){var o=makeSequence(s);return o._iter=s,o.size=s.size,o.flip=function(){return s},o.reverse=function(){var o=s.reverse.apply(this);return o.flip=function(){return s.reverse()},o},o.has=function(o){return s.includes(o)},o.includes=function(o){return 
s.has(o)},o.cacheResult=cacheResultThrough,o.__iterateUncached=function(o,i){var a=this;return s.__iterate((function(s,i){return!1!==o(i,s,a)}),i)},o.__iteratorUncached=function(o,i){if(o===V){var a=s.__iterator(o,i);return new Iterator((function(){var s=a.next();if(!s.done){var o=s.value[0];s.value[0]=s.value[1],s.value[1]=o}return s}))}return s.__iterator(o===U?$:U,i)},o}function mapFactory(s,o,i){var a=makeSequence(s);return a.size=s.size,a.has=function(o){return s.has(o)},a.get=function(a,u){var _=s.get(a,j);return _===j?u:o.call(i,_,a,s)},a.__iterateUncached=function(a,u){var _=this;return s.__iterate((function(s,u,w){return!1!==a(o.call(i,s,u,w),u,_)}),u)},a.__iteratorUncached=function(a,u){var _=s.__iterator(V,u);return new Iterator((function(){var u=_.next();if(u.done)return u;var w=u.value,x=w[0];return iteratorValue(a,x,o.call(i,w[1],x,s),u)}))},a}function reverseFactory(s,o){var i=makeSequence(s);return i._iter=s,i.size=s.size,i.reverse=function(){return s},s.flip&&(i.flip=function(){var o=flipFactory(s);return o.reverse=function(){return s.flip()},o}),i.get=function(i,a){return s.get(o?i:-1-i,a)},i.has=function(i){return s.has(o?i:-1-i)},i.includes=function(o){return s.includes(o)},i.cacheResult=cacheResultThrough,i.__iterate=function(o,i){var a=this;return s.__iterate((function(s,i){return o(s,i,a)}),!i)},i.__iterator=function(o,i){return s.__iterator(o,!i)},i}function filterFactory(s,o,i,a){var u=makeSequence(s);return a&&(u.has=function(a){var u=s.get(a,j);return u!==j&&!!o.call(i,u,a,s)},u.get=function(a,u){var _=s.get(a,j);return _!==j&&o.call(i,_,a,s)?_:u}),u.__iterateUncached=function(u,_){var w=this,x=0;return s.__iterate((function(s,_,C){if(o.call(i,s,_,C))return x++,u(s,a?_:x-1,w)}),_),x},u.__iteratorUncached=function(u,_){var w=s.__iterator(V,_),x=0;return new Iterator((function(){for(;;){var _=w.next();if(_.done)return _;var C=_.value,j=C[0],L=C[1];if(o.call(i,L,j,s))return iteratorValue(u,a?j:x++,L,_)}}))},u}function 
countByFactory(s,o,i){var a=Map().asMutable();return s.__iterate((function(u,_){a.update(o.call(i,u,_,s),0,(function(s){return s+1}))})),a.asImmutable()}function groupByFactory(s,o,i){var a=isKeyed(s),u=(isOrdered(s)?OrderedMap():Map()).asMutable();s.__iterate((function(_,w){u.update(o.call(i,_,w,s),(function(s){return(s=s||[]).push(a?[w,_]:_),s}))}));var _=iterableClass(s);return u.map((function(o){return reify(s,_(o))}))}function sliceFactory(s,o,i,a){var u=s.size;if(void 0!==o&&(o|=0),void 0!==i&&(i===1/0?i=u:i|=0),wholeSlice(o,i,u))return s;var _=resolveBegin(o,u),w=resolveEnd(i,u);if(_!=_||w!=w)return sliceFactory(s.toSeq().cacheResult(),o,i,a);var x,C=w-_;C==C&&(x=C<0?0:C);var j=makeSequence(s);return j.size=0===x?x:s.size&&x||void 0,!a&&isSeq(s)&&x>=0&&(j.get=function(o,i){return(o=wrapIndex(this,o))>=0&&ox)return iteratorDone();var s=u.next();return a||o===U?s:iteratorValue(o,C-1,o===$?void 0:s.value[1],s)}))},j}function takeWhileFactory(s,o,i){var a=makeSequence(s);return a.__iterateUncached=function(a,u){var _=this;if(u)return this.cacheResult().__iterate(a,u);var w=0;return s.__iterate((function(s,u,x){return o.call(i,s,u,x)&&++w&&a(s,u,_)})),w},a.__iteratorUncached=function(a,u){var _=this;if(u)return this.cacheResult().__iterator(a,u);var w=s.__iterator(V,u),x=!0;return new Iterator((function(){if(!x)return iteratorDone();var s=w.next();if(s.done)return s;var u=s.value,C=u[0],j=u[1];return o.call(i,j,C,_)?a===V?s:iteratorValue(a,C,j,s):(x=!1,iteratorDone())}))},a}function skipWhileFactory(s,o,i,a){var u=makeSequence(s);return u.__iterateUncached=function(u,_){var w=this;if(_)return this.cacheResult().__iterate(u,_);var x=!0,C=0;return s.__iterate((function(s,_,j){if(!x||!(x=o.call(i,s,_,j)))return C++,u(s,a?_:C-1,w)})),C},u.__iteratorUncached=function(u,_){var w=this;if(_)return this.cacheResult().__iterator(u,_);var x=s.__iterator(V,_),C=!0,j=0;return new Iterator((function(){var s,_,L;do{if((s=x.next()).done)return 
a||u===U?s:iteratorValue(u,j++,u===$?void 0:s.value[1],s);var B=s.value;_=B[0],L=B[1],C&&(C=o.call(i,L,_,w))}while(C);return u===V?s:iteratorValue(u,_,L,s)}))},u}function concatFactory(s,o){var i=isKeyed(s),a=[s].concat(o).map((function(s){return isIterable(s)?i&&(s=KeyedIterable(s)):s=i?keyedSeqFromValue(s):indexedSeqFromValue(Array.isArray(s)?s:[s]),s})).filter((function(s){return 0!==s.size}));if(0===a.length)return s;if(1===a.length){var u=a[0];if(u===s||i&&isKeyed(u)||isIndexed(s)&&isIndexed(u))return u}var _=new ArraySeq(a);return i?_=_.toKeyedSeq():isIndexed(s)||(_=_.toSetSeq()),(_=_.flatten(!0)).size=a.reduce((function(s,o){if(void 0!==s){var i=o.size;if(void 0!==i)return s+i}}),0),_}function flattenFactory(s,o,i){var a=makeSequence(s);return a.__iterateUncached=function(a,u){var _=0,w=!1;function flatDeep(s,x){var C=this;s.__iterate((function(s,u){return(!o||x0}function zipWithFactory(s,o,i){var a=makeSequence(s);return a.size=new ArraySeq(i).map((function(s){return s.size})).min(),a.__iterate=function(s,o){for(var i,a=this.__iterator(U,o),u=0;!(i=a.next()).done&&!1!==s(i.value,u++,this););return u},a.__iteratorUncached=function(s,a){var u=i.map((function(s){return s=Iterable(s),getIterator(a?s.reverse():s)})),_=0,w=!1;return new Iterator((function(){var i;return w||(i=u.map((function(s){return s.next()})),w=i.some((function(s){return s.done}))),w?iteratorDone():iteratorValue(s,_++,o.apply(null,i.map((function(s){return s.value}))))}))},a}function reify(s,o){return isSeq(s)?o:s.constructor(o)}function validateEntry(s){if(s!==Object(s))throw new TypeError("Expected [K, V] tuple: "+s)}function resolveSize(s){return assertNotInfinite(s.size),ensureSize(s)}function iterableClass(s){return isKeyed(s)?KeyedIterable:isIndexed(s)?IndexedIterable:SetIterable}function makeSequence(s){return Object.create((isKeyed(s)?KeyedSeq:isIndexed(s)?IndexedSeq:SetSeq).prototype)}function cacheResultThrough(){return 
this._iter.cacheResult?(this._iter.cacheResult(),this.size=this._iter.size,this):Seq.prototype.cacheResult.call(this)}function defaultComparator(s,o){return s>o?1:s=0;i--)o={value:arguments[i],next:o};return this.__ownerID?(this.size=s,this._head=o,this.__hash=void 0,this.__altered=!0,this):makeStack(s,o)},Stack.prototype.pushAll=function(s){if(0===(s=IndexedIterable(s)).size)return this;assertNotInfinite(s.size);var o=this.size,i=this._head;return s.reverse().forEach((function(s){o++,i={value:s,next:i}})),this.__ownerID?(this.size=o,this._head=i,this.__hash=void 0,this.__altered=!0,this):makeStack(o,i)},Stack.prototype.pop=function(){return this.slice(1)},Stack.prototype.unshift=function(){return this.push.apply(this,arguments)},Stack.prototype.unshiftAll=function(s){return this.pushAll(s)},Stack.prototype.shift=function(){return this.pop.apply(this,arguments)},Stack.prototype.clear=function(){return 0===this.size?this:this.__ownerID?(this.size=0,this._head=void 0,this.__hash=void 0,this.__altered=!0,this):emptyStack()},Stack.prototype.slice=function(s,o){if(wholeSlice(s,o,this.size))return this;var i=resolveBegin(s,this.size);if(resolveEnd(o,this.size)!==this.size)return IndexedCollection.prototype.slice.call(this,s,o);for(var a=this.size-i,u=this._head;i--;)u=u.next;return this.__ownerID?(this.size=a,this._head=u,this.__hash=void 0,this.__altered=!0,this):makeStack(a,u)},Stack.prototype.__ensureOwner=function(s){return s===this.__ownerID?this:s?makeStack(this.size,this._head,s,this.__hash):(this.__ownerID=s,this.__altered=!1,this)},Stack.prototype.__iterate=function(s,o){if(o)return this.reverse().__iterate(s);for(var i=0,a=this._head;a&&!1!==s(a.value,i++,this);)a=a.next;return i},Stack.prototype.__iterator=function(s,o){if(o)return this.reverse().__iterator(s);var i=0,a=this._head;return new Iterator((function(){if(a){var o=a.value;return a=a.next,iteratorValue(s,i++,o)}return iteratorDone()}))},Stack.isStack=isStack;var 
at,ct="@@__IMMUTABLE_STACK__@@",lt=Stack.prototype;function makeStack(s,o,i,a){var u=Object.create(lt);return u.size=s,u._head=o,u.__ownerID=i,u.__hash=a,u.__altered=!1,u}function emptyStack(){return at||(at=makeStack(0))}function mixin(s,o){var keyCopier=function(i){s.prototype[i]=o[i]};return Object.keys(o).forEach(keyCopier),Object.getOwnPropertySymbols&&Object.getOwnPropertySymbols(o).forEach(keyCopier),s}lt[ct]=!0,lt.withMutations=$e.withMutations,lt.asMutable=$e.asMutable,lt.asImmutable=$e.asImmutable,lt.wasAltered=$e.wasAltered,Iterable.Iterator=Iterator,mixin(Iterable,{toArray:function(){assertNotInfinite(this.size);var s=new Array(this.size||0);return this.valueSeq().__iterate((function(o,i){s[i]=o})),s},toIndexedSeq:function(){return new ToIndexedSequence(this)},toJS:function(){return this.toSeq().map((function(s){return s&&"function"==typeof s.toJS?s.toJS():s})).__toJS()},toJSON:function(){return this.toSeq().map((function(s){return s&&"function"==typeof s.toJSON?s.toJSON():s})).__toJS()},toKeyedSeq:function(){return new ToKeyedSequence(this,!0)},toMap:function(){return Map(this.toKeyedSeq())},toObject:function(){assertNotInfinite(this.size);var s={};return this.__iterate((function(o,i){s[i]=o})),s},toOrderedMap:function(){return OrderedMap(this.toKeyedSeq())},toOrderedSet:function(){return OrderedSet(isKeyed(this)?this.valueSeq():this)},toSet:function(){return Set(isKeyed(this)?this.valueSeq():this)},toSetSeq:function(){return new ToSetSequence(this)},toSeq:function(){return isIndexed(this)?this.toIndexedSeq():isKeyed(this)?this.toKeyedSeq():this.toSetSeq()},toStack:function(){return Stack(isKeyed(this)?this.valueSeq():this)},toList:function(){return List(isKeyed(this)?this.valueSeq():this)},toString:function(){return"[Iterable]"},__toString:function(s,o){return 0===this.size?s+o:s+" "+this.toSeq().map(this.__toStringMapper).join(", ")+" "+o},concat:function(){return reify(this,concatFactory(this,s.call(arguments,0)))},includes:function(s){return 
this.some((function(o){return is(o,s)}))},entries:function(){return this.__iterator(V)},every:function(s,o){assertNotInfinite(this.size);var i=!0;return this.__iterate((function(a,u,_){if(!s.call(o,a,u,_))return i=!1,!1})),i},filter:function(s,o){return reify(this,filterFactory(this,s,o,!0))},find:function(s,o,i){var a=this.findEntry(s,o);return a?a[1]:i},forEach:function(s,o){return assertNotInfinite(this.size),this.__iterate(o?s.bind(o):s)},join:function(s){assertNotInfinite(this.size),s=void 0!==s?""+s:",";var o="",i=!0;return this.__iterate((function(a){i?i=!1:o+=s,o+=null!=a?a.toString():""})),o},keys:function(){return this.__iterator($)},map:function(s,o){return reify(this,mapFactory(this,s,o))},reduce:function(s,o,i){var a,u;return assertNotInfinite(this.size),arguments.length<2?u=!0:a=o,this.__iterate((function(o,_,w){u?(u=!1,a=o):a=s.call(i,a,o,_,w)})),a},reduceRight:function(s,o,i){var a=this.toKeyedSeq().reverse();return a.reduce.apply(a,arguments)},reverse:function(){return reify(this,reverseFactory(this,!0))},slice:function(s,o){return reify(this,sliceFactory(this,s,o,!0))},some:function(s,o){return!this.every(not(s),o)},sort:function(s){return reify(this,sortFactory(this,s))},values:function(){return this.__iterator(U)},butLast:function(){return this.slice(0,-1)},isEmpty:function(){return void 0!==this.size?0===this.size:!this.some((function(){return!0}))},count:function(s,o){return ensureSize(s?this.toSeq().filter(s,o):this)},countBy:function(s,o){return countByFactory(this,s,o)},equals:function(s){return deepEqual(this,s)},entrySeq:function(){var s=this;if(s._cache)return new ArraySeq(s._cache);var o=s.toSeq().map(entryMapper).toIndexedSeq();return o.fromEntrySeq=function(){return s.toSeq()},o},filterNot:function(s,o){return this.filter(not(s),o)},findEntry:function(s,o,i){var a=i;return this.__iterate((function(i,u,_){if(s.call(o,i,u,_))return a=[u,i],!1})),a},findKey:function(s,o){var i=this.findEntry(s,o);return 
i&&i[0]},findLast:function(s,o,i){return this.toKeyedSeq().reverse().find(s,o,i)},findLastEntry:function(s,o,i){return this.toKeyedSeq().reverse().findEntry(s,o,i)},findLastKey:function(s,o){return this.toKeyedSeq().reverse().findKey(s,o)},first:function(){return this.find(returnTrue)},flatMap:function(s,o){return reify(this,flatMapFactory(this,s,o))},flatten:function(s){return reify(this,flattenFactory(this,s,!0))},fromEntrySeq:function(){return new FromEntriesSequence(this)},get:function(s,o){return this.find((function(o,i){return is(i,s)}),void 0,o)},getIn:function(s,o){for(var i,a=this,u=forceIterator(s);!(i=u.next()).done;){var _=i.value;if((a=a&&a.get?a.get(_,j):j)===j)return o}return a},groupBy:function(s,o){return groupByFactory(this,s,o)},has:function(s){return this.get(s,j)!==j},hasIn:function(s){return this.getIn(s,j)!==j},isSubset:function(s){return s="function"==typeof s.includes?s:Iterable(s),this.every((function(o){return s.includes(o)}))},isSuperset:function(s){return(s="function"==typeof s.isSubset?s:Iterable(s)).isSubset(this)},keyOf:function(s){return this.findKey((function(o){return is(o,s)}))},keySeq:function(){return this.toSeq().map(keyMapper).toIndexedSeq()},last:function(){return this.toSeq().reverse().first()},lastKeyOf:function(s){return this.toKeyedSeq().reverse().keyOf(s)},max:function(s){return maxFactory(this,s)},maxBy:function(s,o){return maxFactory(this,o,s)},min:function(s){return maxFactory(this,s?neg(s):defaultNegComparator)},minBy:function(s,o){return maxFactory(this,o?neg(o):defaultNegComparator,s)},rest:function(){return this.slice(1)},skip:function(s){return this.slice(Math.max(0,s))},skipLast:function(s){return reify(this,this.toSeq().reverse().skip(s).reverse())},skipWhile:function(s,o){return reify(this,skipWhileFactory(this,s,o,!0))},skipUntil:function(s,o){return this.skipWhile(not(s),o)},sortBy:function(s,o){return reify(this,sortFactory(this,o,s))},take:function(s){return 
this.slice(0,Math.max(0,s))},takeLast:function(s){return reify(this,this.toSeq().reverse().take(s).reverse())},takeWhile:function(s,o){return reify(this,takeWhileFactory(this,s,o))},takeUntil:function(s,o){return this.takeWhile(not(s),o)},valueSeq:function(){return this.toIndexedSeq()},hashCode:function(){return this.__hash||(this.__hash=hashIterable(this))}});var ut=Iterable.prototype;ut[o]=!0,ut[Z]=ut.values,ut.__toJS=ut.toArray,ut.__toStringMapper=quoteString,ut.inspect=ut.toSource=function(){return this.toString()},ut.chain=ut.flatMap,ut.contains=ut.includes,mixin(KeyedIterable,{flip:function(){return reify(this,flipFactory(this))},mapEntries:function(s,o){var i=this,a=0;return reify(this,this.toSeq().map((function(u,_){return s.call(o,[_,u],a++,i)})).fromEntrySeq())},mapKeys:function(s,o){var i=this;return reify(this,this.toSeq().flip().map((function(a,u){return s.call(o,a,u,i)})).flip())}});var pt=KeyedIterable.prototype;function keyMapper(s,o){return o}function entryMapper(s,o){return[o,s]}function not(s){return function(){return!s.apply(this,arguments)}}function neg(s){return function(){return-s.apply(this,arguments)}}function quoteString(s){return"string"==typeof s?JSON.stringify(s):String(s)}function defaultZipper(){return arrCopy(arguments)}function defaultNegComparator(s,o){return so?-1:0}function hashIterable(s){if(s.size===1/0)return 0;var o=isOrdered(s),i=isKeyed(s),a=o?1:0;return murmurHashOfSize(s.__iterate(i?o?function(s,o){a=31*a+hashMerge(hash(s),hash(o))|0}:function(s,o){a=a+hashMerge(hash(s),hash(o))|0}:o?function(s){a=31*a+hash(s)|0}:function(s){a=a+hash(s)|0}),a)}function murmurHashOfSize(s,o){return o=le(o,3432918353),o=le(o<<15|o>>>-15,461845907),o=le(o<<13|o>>>-13,5),o=le((o=o+3864292196^s)^o>>>16,2246822507),o=smi((o=le(o^o>>>13,3266489909))^o>>>16)}function hashMerge(s,o){return s^o+2654435769+(s<<6)+(s>>2)}return pt[i]=!0,pt[Z]=ut.entries,pt.__toJS=ut.toObject,pt.__toStringMapper=function(s,o){return JSON.stringify(o)+": 
"+quoteString(s)},mixin(IndexedIterable,{toKeyedSeq:function(){return new ToKeyedSequence(this,!1)},filter:function(s,o){return reify(this,filterFactory(this,s,o,!1))},findIndex:function(s,o){var i=this.findEntry(s,o);return i?i[0]:-1},indexOf:function(s){var o=this.keyOf(s);return void 0===o?-1:o},lastIndexOf:function(s){var o=this.lastKeyOf(s);return void 0===o?-1:o},reverse:function(){return reify(this,reverseFactory(this,!1))},slice:function(s,o){return reify(this,sliceFactory(this,s,o,!1))},splice:function(s,o){var i=arguments.length;if(o=Math.max(0|o,0),0===i||2===i&&!o)return this;s=resolveBegin(s,s<0?this.count():this.size);var a=this.slice(0,s);return reify(this,1===i?a:a.concat(arrCopy(arguments,2),this.slice(s+o)))},findLastIndex:function(s,o){var i=this.findLastEntry(s,o);return i?i[0]:-1},first:function(){return this.get(0)},flatten:function(s){return reify(this,flattenFactory(this,s,!1))},get:function(s,o){return(s=wrapIndex(this,s))<0||this.size===1/0||void 0!==this.size&&s>this.size?o:this.find((function(o,i){return i===s}),void 0,o)},has:function(s){return(s=wrapIndex(this,s))>=0&&(void 0!==this.size?this.size===1/0||s{"use strict";i(71340);var a=i(92046);s.exports=a.Object.assign},9957:(s,o,i)=>{"use strict";var a=Function.prototype.call,u=Object.prototype.hasOwnProperty,_=i(66743);s.exports=_.call(a,u)},9999:(s,o,i)=>{var a=i(37217),u=i(83729),_=i(16547),w=i(74733),x=i(43838),C=i(93290),j=i(23007),L=i(92271),B=i(48948),$=i(50002),U=i(83349),V=i(5861),z=i(76189),Y=i(77199),Z=i(35529),ee=i(56449),ie=i(3656),ae=i(87730),ce=i(23805),le=i(38440),pe=i(95950),de=i(37241),fe="[object Arguments]",ye="[object Function]",be="[object Object]",_e={};_e[fe]=_e["[object Array]"]=_e["[object ArrayBuffer]"]=_e["[object DataView]"]=_e["[object Boolean]"]=_e["[object Date]"]=_e["[object Float32Array]"]=_e["[object Float64Array]"]=_e["[object Int8Array]"]=_e["[object Int16Array]"]=_e["[object Int32Array]"]=_e["[object Map]"]=_e["[object Number]"]=_e[be]=_e["[object 
RegExp]"]=_e["[object Set]"]=_e["[object String]"]=_e["[object Symbol]"]=_e["[object Uint8Array]"]=_e["[object Uint8ClampedArray]"]=_e["[object Uint16Array]"]=_e["[object Uint32Array]"]=!0,_e["[object Error]"]=_e[ye]=_e["[object WeakMap]"]=!1,s.exports=function baseClone(s,o,i,Se,we,xe){var Pe,Te=1&o,Re=2&o,$e=4&o;if(i&&(Pe=we?i(s,Se,we,xe):i(s)),void 0!==Pe)return Pe;if(!ce(s))return s;var qe=ee(s);if(qe){if(Pe=z(s),!Te)return j(s,Pe)}else{var ze=V(s),We=ze==ye||"[object GeneratorFunction]"==ze;if(ie(s))return C(s,Te);if(ze==be||ze==fe||We&&!we){if(Pe=Re||We?{}:Z(s),!Te)return Re?B(s,x(Pe,s)):L(s,w(Pe,s))}else{if(!_e[ze])return we?s:{};Pe=Y(s,ze,Te)}}xe||(xe=new a);var He=xe.get(s);if(He)return He;xe.set(s,Pe),le(s)?s.forEach((function(a){Pe.add(baseClone(a,o,i,a,s,xe))})):ae(s)&&s.forEach((function(a,u){Pe.set(u,baseClone(a,o,i,u,s,xe))}));var Ye=qe?void 0:($e?Re?U:$:Re?de:pe)(s);return u(Ye||s,(function(a,u){Ye&&(a=s[u=a]),_(Pe,u,baseClone(a,o,i,u,s,xe))})),Pe}},10023:(s,o,i)=>{const 
a=i(6205),INTS=()=>[{type:a.RANGE,from:48,to:57}],WORDS=()=>[{type:a.CHAR,value:95},{type:a.RANGE,from:97,to:122},{type:a.RANGE,from:65,to:90}].concat(INTS()),WHITESPACE=()=>[{type:a.CHAR,value:9},{type:a.CHAR,value:10},{type:a.CHAR,value:11},{type:a.CHAR,value:12},{type:a.CHAR,value:13},{type:a.CHAR,value:32},{type:a.CHAR,value:160},{type:a.CHAR,value:5760},{type:a.RANGE,from:8192,to:8202},{type:a.CHAR,value:8232},{type:a.CHAR,value:8233},{type:a.CHAR,value:8239},{type:a.CHAR,value:8287},{type:a.CHAR,value:12288},{type:a.CHAR,value:65279}];o.words=()=>({type:a.SET,set:WORDS(),not:!1}),o.notWords=()=>({type:a.SET,set:WORDS(),not:!0}),o.ints=()=>({type:a.SET,set:INTS(),not:!1}),o.notInts=()=>({type:a.SET,set:INTS(),not:!0}),o.whitespace=()=>({type:a.SET,set:WHITESPACE(),not:!1}),o.notWhitespace=()=>({type:a.SET,set:WHITESPACE(),not:!0}),o.anyChar=()=>({type:a.SET,set:[{type:a.CHAR,value:10},{type:a.CHAR,value:13},{type:a.CHAR,value:8232},{type:a.CHAR,value:8233}],not:!0})},10043:(s,o,i)=>{"use strict";var a=i(54018),u=String,_=TypeError;s.exports=function(s){if(a(s))return s;throw new _("Can't set "+u(s)+" as a prototype")}},10076:s=>{"use strict";s.exports=Function.prototype.call},10124:(s,o,i)=>{var a=i(9325);s.exports=function(){return a.Date.now()}},10300:(s,o,i)=>{"use strict";var a=i(13930),u=i(82159),_=i(36624),w=i(4640),x=i(73448),C=TypeError;s.exports=function(s,o){var i=arguments.length<2?x(s):o;if(u(i))return _(a(i,s));throw new C(w(s)+" is not iterable")}},10316:(s,o,i)=>{const a=i(2404),u=i(55973),_=i(92340);class Element{constructor(s,o,i){o&&(this.meta=o),i&&(this.attributes=i),this.content=s}freeze(){Object.isFrozen(this)||(this._meta&&(this.meta.parent=this,this.meta.freeze()),this._attributes&&(this.attributes.parent=this,this.attributes.freeze()),this.children.forEach((s=>{s.parent=this,s.freeze()}),this),this.content&&Array.isArray(this.content)&&Object.freeze(this.content),Object.freeze(this))}primitive(){}clone(){const s=new 
this.constructor;return s.element=this.element,this.meta.length&&(s._meta=this.meta.clone()),this.attributes.length&&(s._attributes=this.attributes.clone()),this.content?this.content.clone?s.content=this.content.clone():Array.isArray(this.content)?s.content=this.content.map((s=>s.clone())):s.content=this.content:s.content=this.content,s}toValue(){return this.content instanceof Element?this.content.toValue():this.content instanceof u?{key:this.content.key.toValue(),value:this.content.value?this.content.value.toValue():void 0}:this.content&&this.content.map?this.content.map((s=>s.toValue()),this):this.content}toRef(s){if(""===this.id.toValue())throw Error("Cannot create reference to an element that does not contain an ID");const o=new this.RefElement(this.id.toValue());return s&&(o.path=s),o}findRecursive(...s){if(arguments.length>1&&!this.isFrozen)throw new Error("Cannot find recursive with multiple element names without first freezing the element. Call `element.freeze()`");const o=s.pop();let i=new _;const append=(s,o)=>(s.push(o),s),checkElement=(s,i)=>{i.element===o&&s.push(i);const a=i.findRecursive(o);return a&&a.reduce(append,s),i.content instanceof u&&(i.content.key&&checkElement(s,i.content.key),i.content.value&&checkElement(s,i.content.value)),s};return this.content&&(this.content.element&&checkElement(i,this.content),Array.isArray(this.content)&&this.content.reduce(checkElement,i)),s.isEmpty||(i=i.filter((o=>{let i=o.parents.map((s=>s.element));for(const o in s){const a=s[o],u=i.indexOf(a);if(-1===u)return!1;i=i.splice(0,u)}return!0}))),i}set(s){return this.content=s,this}equals(s){return a(this.toValue(),s)}getMetaProperty(s,o){if(!this.meta.hasKey(s)){if(this.isFrozen){const s=this.refract(o);return s.freeze(),s}this.meta.set(s,o)}return this.meta.get(s)}setMetaProperty(s,o){this.meta.set(s,o)}get element(){return this._storedElement||"element"}set element(s){this._storedElement=s}get content(){return this._content}set content(s){if(s instanceof 
Element)this._content=s;else if(s instanceof _)this.content=s.elements;else if("string"==typeof s||"number"==typeof s||"boolean"==typeof s||"null"===s||null==s)this._content=s;else if(s instanceof u)this._content=s;else if(Array.isArray(s))this._content=s.map(this.refract);else{if("object"!=typeof s)throw new Error("Cannot set content to given value");this._content=Object.keys(s).map((o=>new this.MemberElement(o,s[o])))}}get meta(){if(!this._meta){if(this.isFrozen){const s=new this.ObjectElement;return s.freeze(),s}this._meta=new this.ObjectElement}return this._meta}set meta(s){s instanceof this.ObjectElement?this._meta=s:this.meta.set(s||{})}get attributes(){if(!this._attributes){if(this.isFrozen){const s=new this.ObjectElement;return s.freeze(),s}this._attributes=new this.ObjectElement}return this._attributes}set attributes(s){s instanceof this.ObjectElement?this._attributes=s:this.attributes.set(s||{})}get id(){return this.getMetaProperty("id","")}set id(s){this.setMetaProperty("id",s)}get classes(){return this.getMetaProperty("classes",[])}set classes(s){this.setMetaProperty("classes",s)}get title(){return this.getMetaProperty("title","")}set title(s){this.setMetaProperty("title",s)}get description(){return this.getMetaProperty("description","")}set description(s){this.setMetaProperty("description",s)}get links(){return this.getMetaProperty("links",[])}set links(s){this.setMetaProperty("links",s)}get isFrozen(){return Object.isFrozen(this)}get parents(){let{parent:s}=this;const o=new _;for(;s;)o.push(s),s=s.parent;return o}get children(){if(Array.isArray(this.content))return new _(this.content);if(this.content instanceof u){const s=new _([this.content.key]);return this.content.value&&s.push(this.content.value),s}return this.content instanceof Element?new _([this.content]):new _}get recursiveChildren(){const s=new _;return this.children.forEach((o=>{s.push(o),o.recursiveChildren.forEach((o=>{s.push(o)}))})),s}}s.exports=Element},10392:s=>{s.exports=function 
getValue(s,o){return null==s?void 0:s[o]}},10487:(s,o,i)=>{"use strict";var a=i(96897),u=i(30655),_=i(73126),w=i(12205);s.exports=function callBind(s){var o=_(arguments),i=s.length-(arguments.length-1);return a(o,1+(i>0?i:0),!0)},u?u(s.exports,"apply",{value:w}):s.exports.apply=w},10776:(s,o,i)=>{var a=i(30756),u=i(95950);s.exports=function getMatchData(s){for(var o=u(s),i=o.length;i--;){var _=o[i],w=s[_];o[i]=[_,w,a(w)]}return o}},10866:(s,o,i)=>{const a=i(6048),u=i(92340);class ObjectSlice extends u{map(s,o){return this.elements.map((i=>s.bind(o)(i.value,i.key,i)))}filter(s,o){return new ObjectSlice(this.elements.filter((i=>s.bind(o)(i.value,i.key,i))))}reject(s,o){return this.filter(a(s.bind(o)))}forEach(s,o){return this.elements.forEach(((i,a)=>{s.bind(o)(i.value,i.key,i,a)}))}keys(){return this.map(((s,o)=>o.toValue()))}values(){return this.map((s=>s.toValue()))}}s.exports=ObjectSlice},11002:s=>{"use strict";s.exports=Function.prototype.apply},11042:(s,o,i)=>{"use strict";var a=i(85582),u=i(1907),_=i(24443),w=i(87170),x=i(36624),C=u([].concat);s.exports=a("Reflect","ownKeys")||function ownKeys(s){var o=_.f(x(s)),i=w.f;return i?C(o,i(s)):o}},11091:(s,o,i)=>{"use strict";var a=i(45951),u=i(76024),_=i(92361),w=i(62250),x=i(13846).f,C=i(7463),j=i(92046),L=i(28311),B=i(61626),$=i(49724);i(36128);var wrapConstructor=function(s){var Wrapper=function(o,i,a){if(this instanceof Wrapper){switch(arguments.length){case 0:return new s;case 1:return new s(o);case 2:return new s(o,i)}return new s(o,i,a)}return u(s,this,arguments)};return Wrapper.prototype=s.prototype,Wrapper};s.exports=function(s,o){var i,u,U,V,z,Y,Z,ee,ie,ae=s.target,ce=s.global,le=s.stat,pe=s.proto,de=ce?a:le?a[ae]:a[ae]&&a[ae].prototype,fe=ce?j:j[ae]||B(j,ae,{})[ae],ye=fe.prototype;for(V in o)u=!(i=C(ce?V:ae+(le?".":"#")+V,s.forced))&&de&&$(de,V),Y=fe[V],u&&(Z=s.dontCallGetSet?(ie=x(de,V))&&ie.value:de[V]),z=u&&Z?Z:o[V],(i||pe||typeof Y!=typeof 
z)&&(ee=s.bind&&u?L(z,a):s.wrap&&u?wrapConstructor(z):pe&&w(z)?_(z):z,(s.sham||z&&z.sham||Y&&Y.sham)&&B(ee,"sham",!0),B(fe,V,ee),pe&&($(j,U=ae+"Prototype")||B(j,U,{}),B(j[U],V,z),s.real&&ye&&(i||!ye[V])&&B(ye,V,z)))}},11287:s=>{s.exports=function getHolder(s){return s.placeholder}},11331:(s,o,i)=>{var a=i(72552),u=i(28879),_=i(40346),w=Function.prototype,x=Object.prototype,C=w.toString,j=x.hasOwnProperty,L=C.call(Object);s.exports=function isPlainObject(s){if(!_(s)||"[object Object]"!=a(s))return!1;var o=u(s);if(null===o)return!0;var i=j.call(o,"constructor")&&o.constructor;return"function"==typeof i&&i instanceof i&&C.call(i)==L}},11470:(s,o,i)=>{"use strict";var a=i(1907),u=i(65482),_=i(90160),w=i(74239),x=a("".charAt),C=a("".charCodeAt),j=a("".slice),createMethod=function(s){return function(o,i){var a,L,B=_(w(o)),$=u(i),U=B.length;return $<0||$>=U?s?"":void 0:(a=C(B,$))<55296||a>56319||$+1===U||(L=C(B,$+1))<56320||L>57343?s?x(B,$):a:s?j(B,$,$+2):L-56320+(a-55296<<10)+65536}};s.exports={codeAt:createMethod(!1),charAt:createMethod(!0)}},11842:(s,o,i)=>{var a=i(82819),u=i(9325);s.exports=function createBind(s,o,i){var _=1&o,w=a(s);return function wrapper(){return(this&&this!==u&&this instanceof wrapper?w:s).apply(_?i:this,arguments)}}},12205:(s,o,i)=>{"use strict";var a=i(66743),u=i(11002),_=i(13144);s.exports=function applyBind(){return _(a,u,arguments)}},12242:(s,o,i)=>{const a=i(10316);s.exports=class BooleanElement extends a{constructor(s,o,i){super(s,o,i),this.element="boolean"}primitive(){return"boolean"}}},12507:(s,o,i)=>{var a=i(28754),u=i(49698),_=i(63912),w=i(13222);s.exports=function createCaseFirst(s){return function(o){o=w(o);var i=u(o)?_(o):void 0,x=i?i[0]:o.charAt(0),C=i?a(i,1).join(""):o.slice(1);return x[s]()+C}}},12560:(s,o,i)=>{"use strict";i(99363);var a=i(19287),u=i(45951),_=i(14840),w=i(93742);for(var x in a)_(u[x],x),w[x]=w.Array},12651:(s,o,i)=>{var a=i(74218);s.exports=function getMapData(s,o){var i=s.__data__;return a(o)?i["string"==typeof 
o?"string":"hash"]:i.map}},12749:(s,o,i)=>{var a=i(81042),u=Object.prototype.hasOwnProperty;s.exports=function hashHas(s){var o=this.__data__;return a?void 0!==o[s]:u.call(o,s)}},13144:(s,o,i)=>{"use strict";var a=i(66743),u=i(11002),_=i(10076),w=i(47119);s.exports=w||a.call(_,u)},13222:(s,o,i)=>{var a=i(77556);s.exports=function toString(s){return null==s?"":a(s)}},13846:(s,o,i)=>{"use strict";var a=i(39447),u=i(13930),_=i(22574),w=i(75817),x=i(4993),C=i(70470),j=i(49724),L=i(73648),B=Object.getOwnPropertyDescriptor;o.f=a?B:function getOwnPropertyDescriptor(s,o){if(s=x(s),o=C(o),L)try{return B(s,o)}catch(s){}if(j(s,o))return w(!u(_.f,s,o),s[o])}},13930:(s,o,i)=>{"use strict";var a=i(41505),u=Function.prototype.call;s.exports=a?u.bind(u):function(){return u.apply(u,arguments)}},14248:s=>{s.exports=function arraySome(s,o){for(var i=-1,a=null==s?0:s.length;++i{s.exports=function arrayPush(s,o){for(var i=-1,a=o.length,u=s.length;++i{const a=i(10316);s.exports=class RefElement extends a{constructor(s,o,i){super(s||[],o,i),this.element="ref",this.path||(this.path="element")}get path(){return this.attributes.get("path")}set path(s){this.attributes.set("path",s)}}},14744:s=>{"use strict";var o=function isMergeableObject(s){return function isNonNullObject(s){return!!s&&"object"==typeof s}(s)&&!function isSpecial(s){var o=Object.prototype.toString.call(s);return"[object RegExp]"===o||"[object Date]"===o||function isReactElement(s){return s.$$typeof===i}(s)}(s)};var i="function"==typeof Symbol&&Symbol.for?Symbol.for("react.element"):60103;function cloneUnlessOtherwiseSpecified(s,o){return!1!==o.clone&&o.isMergeableObject(s)?deepmerge(function emptyTarget(s){return Array.isArray(s)?[]:{}}(s),s,o):s}function defaultArrayMerge(s,o,i){return s.concat(o).map((function(s){return cloneUnlessOtherwiseSpecified(s,i)}))}function getKeys(s){return Object.keys(s).concat(function getEnumerableOwnPropertySymbols(s){return 
Object.getOwnPropertySymbols?Object.getOwnPropertySymbols(s).filter((function(o){return Object.propertyIsEnumerable.call(s,o)})):[]}(s))}function propertyIsOnObject(s,o){try{return o in s}catch(s){return!1}}function mergeObject(s,o,i){var a={};return i.isMergeableObject(s)&&getKeys(s).forEach((function(o){a[o]=cloneUnlessOtherwiseSpecified(s[o],i)})),getKeys(o).forEach((function(u){(function propertyIsUnsafe(s,o){return propertyIsOnObject(s,o)&&!(Object.hasOwnProperty.call(s,o)&&Object.propertyIsEnumerable.call(s,o))})(s,u)||(propertyIsOnObject(s,u)&&i.isMergeableObject(o[u])?a[u]=function getMergeFunction(s,o){if(!o.customMerge)return deepmerge;var i=o.customMerge(s);return"function"==typeof i?i:deepmerge}(u,i)(s[u],o[u],i):a[u]=cloneUnlessOtherwiseSpecified(o[u],i))})),a}function deepmerge(s,i,a){(a=a||{}).arrayMerge=a.arrayMerge||defaultArrayMerge,a.isMergeableObject=a.isMergeableObject||o,a.cloneUnlessOtherwiseSpecified=cloneUnlessOtherwiseSpecified;var u=Array.isArray(i);return u===Array.isArray(s)?u?a.arrayMerge(s,i,a):mergeObject(s,i,a):cloneUnlessOtherwiseSpecified(i,a)}deepmerge.all=function deepmergeAll(s,o){if(!Array.isArray(s))throw new Error("first argument should be an array");return s.reduce((function(s,i){return deepmerge(s,i,o)}),{})};var a=deepmerge;s.exports=a},14792:(s,o,i)=>{var a=i(13222),u=i(55808);s.exports=function capitalize(s){return u(a(s).toLowerCase())}},14840:(s,o,i)=>{"use strict";var a=i(52623),u=i(74284).f,_=i(61626),w=i(49724),x=i(54878),C=i(76264)("toStringTag");s.exports=function(s,o,i,j){var L=i?s:s&&s.prototype;L&&(w(L,C)||u(L,C,{configurable:!0,value:o}),j&&!a&&_(L,"toString",x))}},14974:s=>{s.exports=function safeGet(s,o){if(("constructor"!==o||"function"!=typeof s[o])&&"__proto__"!=o)return s[o]}},15287:(s,o)=>{"use strict";var 
i=Symbol.for("react.element"),a=Symbol.for("react.portal"),u=Symbol.for("react.fragment"),_=Symbol.for("react.strict_mode"),w=Symbol.for("react.profiler"),x=Symbol.for("react.provider"),C=Symbol.for("react.context"),j=Symbol.for("react.forward_ref"),L=Symbol.for("react.suspense"),B=Symbol.for("react.memo"),$=Symbol.for("react.lazy"),U=Symbol.iterator;var V={isMounted:function(){return!1},enqueueForceUpdate:function(){},enqueueReplaceState:function(){},enqueueSetState:function(){}},z=Object.assign,Y={};function E(s,o,i){this.props=s,this.context=o,this.refs=Y,this.updater=i||V}function F(){}function G(s,o,i){this.props=s,this.context=o,this.refs=Y,this.updater=i||V}E.prototype.isReactComponent={},E.prototype.setState=function(s,o){if("object"!=typeof s&&"function"!=typeof s&&null!=s)throw Error("setState(...): takes an object of state variables to update or a function which returns an object of state variables.");this.updater.enqueueSetState(this,s,o,"setState")},E.prototype.forceUpdate=function(s){this.updater.enqueueForceUpdate(this,s,"forceUpdate")},F.prototype=E.prototype;var Z=G.prototype=new F;Z.constructor=G,z(Z,E.prototype),Z.isPureReactComponent=!0;var ee=Array.isArray,ie=Object.prototype.hasOwnProperty,ae={current:null},ce={key:!0,ref:!0,__self:!0,__source:!0};function M(s,o,a){var u,_={},w=null,x=null;if(null!=o)for(u in void 0!==o.ref&&(x=o.ref),void 0!==o.key&&(w=""+o.key),o)ie.call(o,u)&&!ce.hasOwnProperty(u)&&(_[u]=o[u]);var C=arguments.length-2;if(1===C)_.children=a;else if(1{var a=i(96131);s.exports=function arrayIncludes(s,o){return!!(null==s?0:s.length)&&a(s,o,0)>-1}},15340:()=>{},15377:(s,o,i)=>{"use strict";var a=i(92861).Buffer,u=i(64634),_=i(74372),w=ArrayBuffer.isView||function isView(s){try{return _(s),!0}catch(s){return!1}},x="undefined"!=typeof Uint8Array,C="undefined"!=typeof ArrayBuffer&&"undefined"!=typeof Uint8Array,j=C&&(a.prototype instanceof Uint8Array||a.TYPED_ARRAY_SUPPORT);s.exports=function toBuffer(s,o){if(s instanceof a)return 
s;if("string"==typeof s)return a.from(s,o);if(C&&w(s)){if(0===s.byteLength)return a.alloc(0);if(j){var i=a.from(s.buffer,s.byteOffset,s.byteLength);if(i.byteLength===s.byteLength)return i}var _=s instanceof Uint8Array?s:new Uint8Array(s.buffer,s.byteOffset,s.byteLength),L=a.from(_);if(L.length===s.byteLength)return L}if(x&&s instanceof Uint8Array)return a.from(s);var B=u(s);if(B)for(var $=0;$255||~~U!==U)throw new RangeError("Array items must be numbers in the range 0-255.")}if(B||a.isBuffer(s)&&s.constructor&&"function"==typeof s.constructor.isBuffer&&s.constructor.isBuffer(s))return a.from(s);throw new TypeError('The "data" argument must be a string, an Array, a Buffer, a Uint8Array, or a DataView.')}},15389:(s,o,i)=>{var a=i(93663),u=i(87978),_=i(83488),w=i(56449),x=i(50583);s.exports=function baseIteratee(s){return"function"==typeof s?s:null==s?_:"object"==typeof s?w(s)?u(s[0],s[1]):a(s):x(s)}},15972:(s,o,i)=>{"use strict";var a=i(49724),u=i(62250),_=i(39298),w=i(92522),x=i(57382),C=w("IE_PROTO"),j=Object,L=j.prototype;s.exports=x?j.getPrototypeOf:function(s){var o=_(s);if(a(o,C))return o[C];var i=o.constructor;return u(i)&&o instanceof i?i.prototype:o instanceof j?L:null}},16038:(s,o,i)=>{var a=i(5861),u=i(40346);s.exports=function baseIsSet(s){return u(s)&&"[object Set]"==a(s)}},16426:s=>{s.exports=function(){var s=document.getSelection();if(!s.rangeCount)return function(){};for(var o=document.activeElement,i=[],a=0;a{var a=i(43360),u=i(75288),_=Object.prototype.hasOwnProperty;s.exports=function assignValue(s,o,i){var w=s[o];_.call(s,o)&&u(w,i)&&(void 0!==i||o in s)||a(s,o,i)}},16708:(s,o,i)=>{"use strict";var a,u=i(65606);function CorkedRequest(s){var o=this;this.next=null,this.entry=null,this.finish=function(){!function onCorkedFinish(s,o,i){var a=s.entry;s.entry=null;for(;a;){var u=a.callback;o.pendingcb--,u(i),a=a.next}o.corkedRequestsFree.next=s}(o,s)}}s.exports=Writable,Writable.WritableState=WritableState;var 
_={deprecate:i(94643)},w=i(40345),x=i(48287).Buffer,C=(void 0!==i.g?i.g:"undefined"!=typeof window?window:"undefined"!=typeof self?self:{}).Uint8Array||function(){};var j,L=i(75896),B=i(65291).getHighWaterMark,$=i(86048).F,U=$.ERR_INVALID_ARG_TYPE,V=$.ERR_METHOD_NOT_IMPLEMENTED,z=$.ERR_MULTIPLE_CALLBACK,Y=$.ERR_STREAM_CANNOT_PIPE,Z=$.ERR_STREAM_DESTROYED,ee=$.ERR_STREAM_NULL_VALUES,ie=$.ERR_STREAM_WRITE_AFTER_END,ae=$.ERR_UNKNOWN_ENCODING,ce=L.errorOrDestroy;function nop(){}function WritableState(s,o,_){a=a||i(25382),s=s||{},"boolean"!=typeof _&&(_=o instanceof a),this.objectMode=!!s.objectMode,_&&(this.objectMode=this.objectMode||!!s.writableObjectMode),this.highWaterMark=B(this,s,"writableHighWaterMark",_),this.finalCalled=!1,this.needDrain=!1,this.ending=!1,this.ended=!1,this.finished=!1,this.destroyed=!1;var w=!1===s.decodeStrings;this.decodeStrings=!w,this.defaultEncoding=s.defaultEncoding||"utf8",this.length=0,this.writing=!1,this.corked=0,this.sync=!0,this.bufferProcessing=!1,this.onwrite=function(s){!function onwrite(s,o){var i=s._writableState,a=i.sync,_=i.writecb;if("function"!=typeof _)throw new z;if(function onwriteStateUpdate(s){s.writing=!1,s.writecb=null,s.length-=s.writelen,s.writelen=0}(i),o)!function onwriteError(s,o,i,a,_){--o.pendingcb,i?(u.nextTick(_,a),u.nextTick(finishMaybe,s,o),s._writableState.errorEmitted=!0,ce(s,a)):(_(a),s._writableState.errorEmitted=!0,ce(s,a),finishMaybe(s,o))}(s,i,a,o,_);else{var w=needFinish(i)||s.destroyed;w||i.corked||i.bufferProcessing||!i.bufferedRequest||clearBuffer(s,i),a?u.nextTick(afterWrite,s,i,w,_):afterWrite(s,i,w,_)}}(o,s)},this.writecb=null,this.writelen=0,this.bufferedRequest=null,this.lastBufferedRequest=null,this.pendingcb=0,this.prefinished=!1,this.errorEmitted=!1,this.emitClose=!1!==s.emitClose,this.autoDestroy=!!s.autoDestroy,this.bufferedRequestCount=0,this.corkedRequestsFree=new CorkedRequest(this)}function Writable(s){var o=this instanceof(a=a||i(25382));if(!o&&!j.call(Writable,this))return new 
Writable(s);this._writableState=new WritableState(s,this,o),this.writable=!0,s&&("function"==typeof s.write&&(this._write=s.write),"function"==typeof s.writev&&(this._writev=s.writev),"function"==typeof s.destroy&&(this._destroy=s.destroy),"function"==typeof s.final&&(this._final=s.final)),w.call(this)}function doWrite(s,o,i,a,u,_,w){o.writelen=a,o.writecb=w,o.writing=!0,o.sync=!0,o.destroyed?o.onwrite(new Z("write")):i?s._writev(u,o.onwrite):s._write(u,_,o.onwrite),o.sync=!1}function afterWrite(s,o,i,a){i||function onwriteDrain(s,o){0===o.length&&o.needDrain&&(o.needDrain=!1,s.emit("drain"))}(s,o),o.pendingcb--,a(),finishMaybe(s,o)}function clearBuffer(s,o){o.bufferProcessing=!0;var i=o.bufferedRequest;if(s._writev&&i&&i.next){var a=o.bufferedRequestCount,u=new Array(a),_=o.corkedRequestsFree;_.entry=i;for(var w=0,x=!0;i;)u[w]=i,i.isBuf||(x=!1),i=i.next,w+=1;u.allBuffers=x,doWrite(s,o,!0,o.length,u,"",_.finish),o.pendingcb++,o.lastBufferedRequest=null,_.next?(o.corkedRequestsFree=_.next,_.next=null):o.corkedRequestsFree=new CorkedRequest(o),o.bufferedRequestCount=0}else{for(;i;){var C=i.chunk,j=i.encoding,L=i.callback;if(doWrite(s,o,!1,o.objectMode?1:C.length,C,j,L),i=i.next,o.bufferedRequestCount--,o.writing)break}null===i&&(o.lastBufferedRequest=null)}o.bufferedRequest=i,o.bufferProcessing=!1}function needFinish(s){return s.ending&&0===s.length&&null===s.bufferedRequest&&!s.finished&&!s.writing}function callFinal(s,o){s._final((function(i){o.pendingcb--,i&&ce(s,i),o.prefinished=!0,s.emit("prefinish"),finishMaybe(s,o)}))}function finishMaybe(s,o){var i=needFinish(o);if(i&&(function prefinish(s,o){o.prefinished||o.finalCalled||("function"!=typeof s._final||o.destroyed?(o.prefinished=!0,s.emit("prefinish")):(o.pendingcb++,o.finalCalled=!0,u.nextTick(callFinal,s,o)))}(s,o),0===o.pendingcb&&(o.finished=!0,s.emit("finish"),o.autoDestroy))){var a=s._readableState;(!a||a.autoDestroy&&a.endEmitted)&&s.destroy()}return 
i}i(56698)(Writable,w),WritableState.prototype.getBuffer=function getBuffer(){for(var s=this.bufferedRequest,o=[];s;)o.push(s),s=s.next;return o},function(){try{Object.defineProperty(WritableState.prototype,"buffer",{get:_.deprecate((function writableStateBufferGetter(){return this.getBuffer()}),"_writableState.buffer is deprecated. Use _writableState.getBuffer instead.","DEP0003")})}catch(s){}}(),"function"==typeof Symbol&&Symbol.hasInstance&&"function"==typeof Function.prototype[Symbol.hasInstance]?(j=Function.prototype[Symbol.hasInstance],Object.defineProperty(Writable,Symbol.hasInstance,{value:function value(s){return!!j.call(this,s)||this===Writable&&(s&&s._writableState instanceof WritableState)}})):j=function realHasInstance(s){return s instanceof this},Writable.prototype.pipe=function(){ce(this,new Y)},Writable.prototype.write=function(s,o,i){var a=this._writableState,_=!1,w=!a.objectMode&&function _isUint8Array(s){return x.isBuffer(s)||s instanceof C}(s);return w&&!x.isBuffer(s)&&(s=function _uint8ArrayToBuffer(s){return x.from(s)}(s)),"function"==typeof o&&(i=o,o=null),w?o="buffer":o||(o=a.defaultEncoding),"function"!=typeof i&&(i=nop),a.ending?function writeAfterEnd(s,o){var i=new ie;ce(s,i),u.nextTick(o,i)}(this,i):(w||function validChunk(s,o,i,a){var _;return null===i?_=new ee:"string"==typeof i||o.objectMode||(_=new U("chunk",["string","Buffer"],i)),!_||(ce(s,_),u.nextTick(a,_),!1)}(this,a,s,i))&&(a.pendingcb++,_=function writeOrBuffer(s,o,i,a,u,_){if(!i){var w=function decodeChunk(s,o,i){s.objectMode||!1===s.decodeStrings||"string"!=typeof o||(o=x.from(o,i));return o}(o,a,u);a!==w&&(i=!0,u="buffer",a=w)}var C=o.objectMode?1:a.length;o.length+=C;var j=o.length-1))throw new ae(s);return this._writableState.defaultEncoding=s,this},Object.defineProperty(Writable.prototype,"writableBuffer",{enumerable:!1,get:function get(){return 
this._writableState&&this._writableState.getBuffer()}}),Object.defineProperty(Writable.prototype,"writableHighWaterMark",{enumerable:!1,get:function get(){return this._writableState.highWaterMark}}),Writable.prototype._write=function(s,o,i){i(new V("_write()"))},Writable.prototype._writev=null,Writable.prototype.end=function(s,o,i){var a=this._writableState;return"function"==typeof s?(i=s,s=null,o=null):"function"==typeof o&&(i=o,o=null),null!=s&&this.write(s,o),a.corked&&(a.corked=1,this.uncork()),a.ending||function endWritable(s,o,i){o.ending=!0,finishMaybe(s,o),i&&(o.finished?u.nextTick(i):s.once("finish",i));o.ended=!0,s.writable=!1}(this,a,i),this},Object.defineProperty(Writable.prototype,"writableLength",{enumerable:!1,get:function get(){return this._writableState.length}}),Object.defineProperty(Writable.prototype,"destroyed",{enumerable:!1,get:function get(){return void 0!==this._writableState&&this._writableState.destroyed},set:function set(s){this._writableState&&(this._writableState.destroyed=s)}}),Writable.prototype.destroy=L.destroy,Writable.prototype._undestroy=L.undestroy,Writable.prototype._destroy=function(s,o){o(s)}},16946:(s,o,i)=>{"use strict";var 
a=i(1907),u=i(98828),_=i(45807),w=Object,x=a("".split);s.exports=u((function(){return!w("z").propertyIsEnumerable(0)}))?function(s){return"String"===_(s)?x(s,""):w(s)}:w},16962:(s,o)=>{o.aliasToReal={each:"forEach",eachRight:"forEachRight",entries:"toPairs",entriesIn:"toPairsIn",extend:"assignIn",extendAll:"assignInAll",extendAllWith:"assignInAllWith",extendWith:"assignInWith",first:"head",conforms:"conformsTo",matches:"isMatch",property:"get",__:"placeholder",F:"stubFalse",T:"stubTrue",all:"every",allPass:"overEvery",always:"constant",any:"some",anyPass:"overSome",apply:"spread",assoc:"set",assocPath:"set",complement:"negate",compose:"flowRight",contains:"includes",dissoc:"unset",dissocPath:"unset",dropLast:"dropRight",dropLastWhile:"dropRightWhile",equals:"isEqual",identical:"eq",indexBy:"keyBy",init:"initial",invertObj:"invert",juxt:"over",omitAll:"omit",nAry:"ary",path:"get",pathEq:"matchesProperty",pathOr:"getOr",paths:"at",pickAll:"pick",pipe:"flow",pluck:"map",prop:"get",propEq:"matchesProperty",propOr:"getOr",props:"at",symmetricDifference:"xor",symmetricDifferenceBy:"xorBy",symmetricDifferenceWith:"xorWith",takeLast:"takeRight",takeLastWhile:"takeRightWhile",unapply:"rest",unnest:"flatten",useWith:"overArgs",where:"conformsTo",whereEq:"isMatch",zipObj:"zipObject"},o.aryMethod={1:["assignAll","assignInAll","attempt","castArray","ceil","create","curry","curryRight","defaultsAll","defaultsDeepAll","floor","flow","flowRight","fromPairs","invert","iteratee","memoize","method","mergeAll","methodOf","mixin","nthArg","over","overEvery","overSome","rest","reverse","round","runInContext","spread","template","trim","trimEnd","trimStart","uniqueId","words","zipAll"],2:["add","after","ary","assign","assignAllWith","assignIn","assignInAllWith","at","before","bind","bindAll","bindKey","chunk","cloneDeepWith","cloneWith","concat","conformsTo","countBy","curryN","curryRightN","debounce","defaults","defaultsDeep","defaultTo","delay","difference","divide","drop","dropRight","
dropRightWhile","dropWhile","endsWith","eq","every","filter","find","findIndex","findKey","findLast","findLastIndex","findLastKey","flatMap","flatMapDeep","flattenDepth","forEach","forEachRight","forIn","forInRight","forOwn","forOwnRight","get","groupBy","gt","gte","has","hasIn","includes","indexOf","intersection","invertBy","invoke","invokeMap","isEqual","isMatch","join","keyBy","lastIndexOf","lt","lte","map","mapKeys","mapValues","matchesProperty","maxBy","meanBy","merge","mergeAllWith","minBy","multiply","nth","omit","omitBy","overArgs","pad","padEnd","padStart","parseInt","partial","partialRight","partition","pick","pickBy","propertyOf","pull","pullAll","pullAt","random","range","rangeRight","rearg","reject","remove","repeat","restFrom","result","sampleSize","some","sortBy","sortedIndex","sortedIndexOf","sortedLastIndex","sortedLastIndexOf","sortedUniqBy","split","spreadFrom","startsWith","subtract","sumBy","take","takeRight","takeRightWhile","takeWhile","tap","throttle","thru","times","trimChars","trimCharsEnd","trimCharsStart","truncate","union","uniqBy","uniqWith","unset","unzipWith","without","wrap","xor","zip","zipObject","zipObjectDeep"],3:["assignInWith","assignWith","clamp","differenceBy","differenceWith","findFrom","findIndexFrom","findLastFrom","findLastIndexFrom","getOr","includesFrom","indexOfFrom","inRange","intersectionBy","intersectionWith","invokeArgs","invokeArgsMap","isEqualWith","isMatchWith","flatMapDepth","lastIndexOfFrom","mergeWith","orderBy","padChars","padCharsEnd","padCharsStart","pullAllBy","pullAllWith","rangeStep","rangeStepRight","reduce","reduceRight","replace","set","slice","sortedIndexBy","sortedLastIndexBy","transform","unionBy","unionWith","update","xorBy","xorWith","zipWith"],4:["fill","setWith","updateWith"]},o.aryRearg={2:[1,0],3:[2,0,1],4:[3,2,0,1]},o.iterateeAry={dropRightWhile:1,dropWhile:1,every:1,filter:1,find:1,findFrom:1,findIndex:1,findIndexFrom:1,findKey:1,findLast:1,findLastFrom:1,findLastIndex:1,findLastIndexFrom:
1,findLastKey:1,flatMap:1,flatMapDeep:1,flatMapDepth:1,forEach:1,forEachRight:1,forIn:1,forInRight:1,forOwn:1,forOwnRight:1,map:1,mapKeys:1,mapValues:1,partition:1,reduce:2,reduceRight:2,reject:1,remove:1,some:1,takeRightWhile:1,takeWhile:1,times:1,transform:2},o.iterateeRearg={mapKeys:[1],reduceRight:[1,0]},o.methodRearg={assignInAllWith:[1,0],assignInWith:[1,2,0],assignAllWith:[1,0],assignWith:[1,2,0],differenceBy:[1,2,0],differenceWith:[1,2,0],getOr:[2,1,0],intersectionBy:[1,2,0],intersectionWith:[1,2,0],isEqualWith:[1,2,0],isMatchWith:[2,1,0],mergeAllWith:[1,0],mergeWith:[1,2,0],padChars:[2,1,0],padCharsEnd:[2,1,0],padCharsStart:[2,1,0],pullAllBy:[2,1,0],pullAllWith:[2,1,0],rangeStep:[1,2,0],rangeStepRight:[1,2,0],setWith:[3,1,2,0],sortedIndexBy:[2,1,0],sortedLastIndexBy:[2,1,0],unionBy:[1,2,0],unionWith:[1,2,0],updateWith:[3,1,2,0],xorBy:[1,2,0],xorWith:[1,2,0],zipWith:[1,2,0]},o.methodSpread={assignAll:{start:0},assignAllWith:{start:0},assignInAll:{start:0},assignInAllWith:{start:0},defaultsAll:{start:0},defaultsDeepAll:{start:0},invokeArgs:{start:2},invokeArgsMap:{start:2},mergeAll:{start:0},mergeAllWith:{start:0},partial:{start:1},partialRight:{start:1},without:{start:1},zipAll:{start:0}},o.mutate={array:{fill:!0,pull:!0,pullAll:!0,pullAllBy:!0,pullAllWith:!0,pullAt:!0,remove:!0,reverse:!0},object:{assign:!0,assignAll:!0,assignAllWith:!0,assignIn:!0,assignInAll:!0,assignInAllWith:!0,assignInWith:!0,assignWith:!0,defaults:!0,defaultsAll:!0,defaultsDeep:!0,defaultsDeepAll:!0,merge:!0,mergeAll:!0,mergeAllWith:!0,mergeWith:!0},set:{set:!0,setWith:!0,unset:!0,update:!0,updateWith:!0}},o.realToAlias=function(){var s=Object.prototype.hasOwnProperty,i=o.aliasToReal,a={};for(var u in i){var _=i[u];s.call(a,_)?a[_].push(u):a[_]=[u]}return 
a}(),o.remap={assignAll:"assign",assignAllWith:"assignWith",assignInAll:"assignIn",assignInAllWith:"assignInWith",curryN:"curry",curryRightN:"curryRight",defaultsAll:"defaults",defaultsDeepAll:"defaultsDeep",findFrom:"find",findIndexFrom:"findIndex",findLastFrom:"findLast",findLastIndexFrom:"findLastIndex",getOr:"get",includesFrom:"includes",indexOfFrom:"indexOf",invokeArgs:"invoke",invokeArgsMap:"invokeMap",lastIndexOfFrom:"lastIndexOf",mergeAll:"merge",mergeAllWith:"mergeWith",padChars:"pad",padCharsEnd:"padEnd",padCharsStart:"padStart",propertyOf:"get",rangeStep:"range",rangeStepRight:"rangeRight",restFrom:"rest",spreadFrom:"spread",trimChars:"trim",trimCharsEnd:"trimEnd",trimCharsStart:"trimStart",zipAll:"zip"},o.skipFixed={castArray:!0,flow:!0,flowRight:!0,iteratee:!0,mixin:!0,rearg:!0,runInContext:!0},o.skipRearg={add:!0,assign:!0,assignIn:!0,bind:!0,bindKey:!0,concat:!0,difference:!0,divide:!0,eq:!0,gt:!0,gte:!0,isEqual:!0,lt:!0,lte:!0,matchesProperty:!0,merge:!0,multiply:!0,overArgs:!0,partial:!0,partialRight:!0,propertyOf:!0,random:!0,range:!0,rangeRight:!0,subtract:!0,zip:!0,zipObject:!0,zipObjectDeep:!0}},17255:(s,o,i)=>{var a=i(47422);s.exports=function basePropertyDeep(s){return function(o){return a(o,s)}}},17285:s=>{function source(s){return s?"string"==typeof s?s:s.source:null}function lookahead(s){return concat("(?=",s,")")}function concat(...s){return s.map((s=>source(s))).join("")}function either(...s){return"("+s.map((s=>source(s))).join("|")+")"}s.exports=function xml(s){const o=concat(/[A-Z_]/,function optional(s){return concat("(",s,")?")}(/[A-Z0-9_.-]*:/),/[A-Z0-9_.-]*/),i={className:"symbol",begin:/&[a-z]+;|&#[0-9]+;|&#x[a-f0-9]+;/},a={begin:/\s/,contains:[{className:"meta-keyword",begin:/#?[a-z_][a-z1-9_-]+/,illegal:/\n/}]},u=s.inherit(a,{begin:/\(/,end:/\)/}),_=s.inherit(s.APOS_STRING_MODE,{className:"meta-string"}),w=s.inherit(s.QUOTE_STRING_MODE,{className:"meta-string"}),x={endsWithParent:!0,illegal:/`]+/}]}]}]};return{name:"HTML, 
XML",aliases:["html","xhtml","rss","atom","xjb","xsd","xsl","plist","wsf","svg"],case_insensitive:!0,contains:[{className:"meta",begin://,relevance:10,contains:[a,w,_,u,{begin:/\[/,end:/\]/,contains:[{className:"meta",begin://,contains:[a,u,w,_]}]}]},s.COMMENT(//,{relevance:10}),{begin://,relevance:10},i,{className:"meta",begin:/<\?xml/,end:/\?>/,relevance:10},{className:"tag",begin:/)/,end:/>/,keywords:{name:"style"},contains:[x],starts:{end:/<\/style>/,returnEnd:!0,subLanguage:["css","xml"]}},{className:"tag",begin:/)/,end:/>/,keywords:{name:"script"},contains:[x],starts:{end:/<\/script>/,returnEnd:!0,subLanguage:["javascript","handlebars","xml"]}},{className:"tag",begin:/<>|<\/>/},{className:"tag",begin:concat(//,/>/,/\s/)))),end:/\/?>/,contains:[{className:"name",begin:o,relevance:0,starts:x}]},{className:"tag",begin:concat(/<\//,lookahead(concat(o,/>/))),contains:[{className:"name",begin:o,relevance:0},{begin:/>/,relevance:0,endsParent:!0}]}]}}},17400:(s,o,i)=>{var a=i(99374),u=1/0;s.exports=function toFinite(s){return s?(s=a(s))===u||s===-1/0?17976931348623157e292*(s<0?-1:1):s==s?s:0:0===s?s:0}},17533:s=>{s.exports=function yaml(s){var o="true false yes no null",i="[\\w#;/?:@&=+$,.~*'()[\\]]+",a={className:"string",relevance:0,variants:[{begin:/'/,end:/'/},{begin:/"/,end:/"/},{begin:/\S+/}],contains:[s.BACKSLASH_ESCAPE,{className:"template-variable",variants:[{begin:/\{\{/,end:/\}\}/},{begin:/%\{/,end:/\}/}]}]},u=s.inherit(a,{variants:[{begin:/'/,end:/'/},{begin:/"/,end:/"/},{begin:/[^\s,{}[\]]+/}]}),_={className:"number",begin:"\\b[0-9]{4}(-[0-9][0-9]){0,2}([Tt \\t][0-9][0-9]?(:[0-9][0-9]){2})?(\\.[0-9]*)?([ \\t])*(Z|[-+][0-9][0-9]?(:[0-9][0-9])?)?\\b"},w={end:",",endsWithParent:!0,excludeEnd:!0,keywords:o,relevance:0},x={begin:/\{/,end:/\}/,contains:[w],illegal:"\\n",relevance:0},C={begin:"\\[",end:"\\]",contains:[w],illegal:"\\n",relevance:0},j=[{className:"attr",variants:[{begin:"\\w[\\w :\\/.-]*:(?=[ \t]|$)"},{begin:'"\\w[\\w :\\/.-]*":(?=[ 
\t]|$)'},{begin:"'\\w[\\w :\\/.-]*':(?=[ \t]|$)"}]},{className:"meta",begin:"^---\\s*$",relevance:10},{className:"string",begin:"[\\|>]([1-9]?[+-])?[ ]*\\n( +)[^ ][^\\n]*\\n(\\2[^\\n]+\\n?)*"},{begin:"<%[%=-]?",end:"[%-]?%>",subLanguage:"ruby",excludeBegin:!0,excludeEnd:!0,relevance:0},{className:"type",begin:"!\\w+!"+i},{className:"type",begin:"!<"+i+">"},{className:"type",begin:"!"+i},{className:"type",begin:"!!"+i},{className:"meta",begin:"&"+s.UNDERSCORE_IDENT_RE+"$"},{className:"meta",begin:"\\*"+s.UNDERSCORE_IDENT_RE+"$"},{className:"bullet",begin:"-(?=[ ]|$)",relevance:0},s.HASH_COMMENT_MODE,{beginKeywords:o,keywords:{literal:o}},_,{className:"number",begin:s.C_NUMBER_RE+"\\b",relevance:0},x,C,a],L=[...j];return L.pop(),L.push(u),w.contains=L,{name:"YAML",case_insensitive:!0,aliases:["yml"],contains:j}}},17670:(s,o,i)=>{var a=i(12651);s.exports=function mapCacheDelete(s){var o=a(this,s).delete(s);return this.size-=o?1:0,o}},17965:(s,o,i)=>{"use strict";var a=i(16426),u={"text/plain":"Text","text/html":"Url",default:"Text"};s.exports=function copy(s,o){var i,_,w,x,C,j,L=!1;o||(o={}),i=o.debug||!1;try{if(w=a(),x=document.createRange(),C=document.getSelection(),(j=document.createElement("span")).textContent=s,j.ariaHidden="true",j.style.all="unset",j.style.position="fixed",j.style.top=0,j.style.clip="rect(0, 0, 0, 0)",j.style.whiteSpace="pre",j.style.webkitUserSelect="text",j.style.MozUserSelect="text",j.style.msUserSelect="text",j.style.userSelect="text",j.addEventListener("copy",(function(a){if(a.stopPropagation(),o.format)if(a.preventDefault(),void 0===a.clipboardData){i&&console.warn("unable to use e.clipboardData"),i&&console.warn("trying IE specific stuff"),window.clipboardData.clearData();var _=u[o.format]||u.default;window.clipboardData.setData(_,s)}else 
a.clipboardData.clearData(),a.clipboardData.setData(o.format,s);o.onCopy&&(a.preventDefault(),o.onCopy(a.clipboardData))})),document.body.appendChild(j),x.selectNodeContents(j),C.addRange(x),!document.execCommand("copy"))throw new Error("copy command was unsuccessful");L=!0}catch(a){i&&console.error("unable to copy using execCommand: ",a),i&&console.warn("trying IE specific stuff");try{window.clipboardData.setData(o.format||"text",s),o.onCopy&&o.onCopy(window.clipboardData),L=!0}catch(a){i&&console.error("unable to copy using clipboardData: ",a),i&&console.error("falling back to prompt"),_=function format(s){var o=(/mac os x/i.test(navigator.userAgent)?"⌘":"Ctrl")+"+C";return s.replace(/#{\s*key\s*}/g,o)}("message"in o?o.message:"Copy to clipboard: #{key}, Enter"),window.prompt(_,s)}}finally{C&&("function"==typeof C.removeRange?C.removeRange(x):C.removeAllRanges()),j&&document.body.removeChild(j),w()}return L}},18073:(s,o,i)=>{var a=i(85087),u=i(54641),_=i(70981);s.exports=function createRecurry(s,o,i,w,x,C,j,L,B,$){var U=8&o;o|=U?32:64,4&(o&=~(U?64:32))||(o&=-4);var V=[s,o,x,U?C:void 0,U?j:void 0,U?void 0:C,U?void 0:j,L,B,$],z=i.apply(void 0,V);return a(s)&&u(z,V),z.placeholder=w,_(z,s,o)}},19123:(s,o,i)=>{var a=i(65606),u=i(31499),_=i(88310).Stream;function resolve(s,o,i){var a,_=function create_indent(s,o){return new Array(o||0).join(s||"")}(o,i=i||0),w=s;if("object"==typeof s&&((w=s[a=Object.keys(s)[0]])&&w._elem))return w._elem.name=a,w._elem.icount=i,w._elem.indent=o,w._elem.indents=_,w._elem.interrupt=w,w._elem;var x,C=[],j=[];function get_attributes(s){Object.keys(s).forEach((function(o){C.push(function attribute(s,o){return s+'="'+u(o)+'"'}(o,s[o]))}))}switch(typeof w){case"object":if(null===w)break;w._attr&&get_attributes(w._attr),w._cdata&&j.push(("/g,"]]]]>")+"]]>"),w.forEach&&(x=!1,j.push(""),w.forEach((function(s){"object"==typeof 
s?"_attr"==Object.keys(s)[0]?get_attributes(s._attr):j.push(resolve(s,o,i+1)):(j.pop(),x=!0,j.push(u(s)))})),x||j.push(""));break;default:j.push(u(w))}return{name:a,interrupt:!1,attributes:C,content:j,icount:i,indents:_,indent:o}}function format(s,o,i){if("object"!=typeof o)return s(!1,o);var a=o.interrupt?1:o.content.length;function proceed(){for(;o.content.length;){var u=o.content.shift();if(void 0!==u){if(interrupt(u))return;format(s,u)}}s(!1,(a>1?o.indents:"")+(o.name?"":"")+(o.indent&&!i?"\n":"")),i&&i()}function interrupt(o){return!!o.interrupt&&(o.interrupt.append=s,o.interrupt.end=proceed,o.interrupt=!1,s(!0),!0)}if(s(!1,o.indents+(o.name?"<"+o.name:"")+(o.attributes.length?" "+o.attributes.join(" "):"")+(a?o.name?">":"":o.name?"/>":"")+(o.indent&&a>1?"\n":"")),!a)return s(!1,o.indent?"\n":"");interrupt(o)||proceed()}s.exports=function xml(s,o){"object"!=typeof o&&(o={indent:o});var i=o.stream?new _:null,u="",w=!1,x=o.indent?!0===o.indent?" ":o.indent:"",C=!0;function delay(s){C?a.nextTick(s):s()}function append(s,o){if(void 0!==o&&(u+=o),s&&!w&&(i=i||new _,w=!0),s&&w){var a=u;delay((function(){i.emit("data",a)})),u=""}}function add(s,o){format(append,resolve(s,x,x?1:0),o)}function end(){if(i){var s=u;delay((function(){i.emit("data",s),i.emit("end"),i.readable=!1,i.emit("close")}))}}return delay((function(){C=!1})),o.declaration&&function addXmlDeclaration(s){var o={version:"1.0",encoding:s.encoding||"UTF-8"};s.standalone&&(o.standalone=s.standalone),add({"?xml":{_attr:o}}),u=u.replace("/>","?>")}(o.declaration),s&&s.forEach?s.forEach((function(o,i){var a;i+1===s.length&&(a=end),add(o,a)})):add(s,end),i?(i.readable=!0,i):u},s.exports.element=s.exports.Element=function element(){var s={_elem:resolve(Array.prototype.slice.call(arguments)),push:function(s){if(!this.append)throw new Error("not assigned to a parent!");var o=this,i=this._elem.indent;format(this.append,resolve(s,i,this._elem.icount+(i?1:0)),(function(){o.append(!0)}))},close:function(s){void 
0!==s&&this.push(s),this.end&&this.end()}};return s}},19219:s=>{s.exports=function cacheHas(s,o){return s.has(o)}},19287:s=>{"use strict";s.exports={CSSRuleList:0,CSSStyleDeclaration:0,CSSValueList:0,ClientRectList:0,DOMRectList:0,DOMStringList:0,DOMTokenList:1,DataTransferItemList:0,FileList:0,HTMLAllCollection:0,HTMLCollection:0,HTMLFormElement:0,HTMLSelectElement:0,MediaList:0,MimeTypeArray:0,NamedNodeMap:0,NodeList:1,PaintRequestList:0,Plugin:0,PluginArray:0,SVGLengthList:0,SVGNumberList:0,SVGPathSegList:0,SVGPointList:0,SVGStringList:0,SVGTransformList:0,SourceBufferList:0,StyleSheetList:0,TextTrackCueList:0,TextTrackList:0,TouchList:0}},19358:(s,o,i)=>{"use strict";var a=i(85582),u=i(49724),_=i(61626),w=i(88280),x=i(79192),C=i(19595),j=i(54829),L=i(34084),B=i(32096),$=i(39259),U=i(85884),V=i(39447),z=i(7376);s.exports=function(s,o,i,Y){var Z="stackTraceLimit",ee=Y?2:1,ie=s.split("."),ae=ie[ie.length-1],ce=a.apply(null,ie);if(ce){var le=ce.prototype;if(!z&&u(le,"cause")&&delete le.cause,!i)return ce;var pe=a("Error"),de=o((function(s,o){var i=B(Y?o:s,void 0),a=Y?new ce(s):new ce;return void 0!==i&&_(a,"message",i),U(a,de,a.stack,2),this&&w(le,this)&&L(a,this,de),arguments.length>ee&&$(a,arguments[ee]),a}));if(de.prototype=le,"Error"!==ae?x?x(de,pe):C(de,pe,{name:!0}):V&&Z in ce&&(j(de,ce,Z),j(de,ce,"prepareStackTrace")),C(de,ce),!z)try{le.name!==ae&&_(le,"name",ae),le.constructor=de}catch(s){}return de}}},19570:(s,o,i)=>{var a=i(37334),u=i(93243),_=i(83488),w=u?function(s,o){return u(s,"toString",{configurable:!0,enumerable:!1,value:a(o),writable:!0})}:_;s.exports=w},19595:(s,o,i)=>{"use strict";var a=i(49724),u=i(11042),_=i(13846),w=i(74284);s.exports=function(s,o,i){for(var x=u(o),C=w.f,j=_.f,L=0;L{"use strict";var a=i(23034);s.exports=a},19846:(s,o,i)=>{"use strict";var a=i(20798),u=i(98828),_=i(45951).String;s.exports=!!Object.getOwnPropertySymbols&&!u((function(){var s=Symbol("symbol detection");return!_(s)||!(Object(s)instanceof 
Symbol)||!Symbol.sham&&a&&a<41}))},19931:(s,o,i)=>{var a=i(31769),u=i(68090),_=i(68969),w=i(77797);s.exports=function baseUnset(s,o){return o=a(o,s),null==(s=_(s,o))||delete s[w(u(o))]}},20181:(s,o,i)=>{var a=/^\s+|\s+$/g,u=/^[-+]0x[0-9a-f]+$/i,_=/^0b[01]+$/i,w=/^0o[0-7]+$/i,x=parseInt,C="object"==typeof i.g&&i.g&&i.g.Object===Object&&i.g,j="object"==typeof self&&self&&self.Object===Object&&self,L=C||j||Function("return this")(),B=Object.prototype.toString,$=Math.max,U=Math.min,now=function(){return L.Date.now()};function isObject(s){var o=typeof s;return!!s&&("object"==o||"function"==o)}function toNumber(s){if("number"==typeof s)return s;if(function isSymbol(s){return"symbol"==typeof s||function isObjectLike(s){return!!s&&"object"==typeof s}(s)&&"[object Symbol]"==B.call(s)}(s))return NaN;if(isObject(s)){var o="function"==typeof s.valueOf?s.valueOf():s;s=isObject(o)?o+"":o}if("string"!=typeof s)return 0===s?s:+s;s=s.replace(a,"");var i=_.test(s);return i||w.test(s)?x(s.slice(2),i?2:8):u.test(s)?NaN:+s}s.exports=function debounce(s,o,i){var a,u,_,w,x,C,j=0,L=!1,B=!1,V=!0;if("function"!=typeof s)throw new TypeError("Expected a function");function invokeFunc(o){var i=a,_=u;return a=u=void 0,j=o,w=s.apply(_,i)}function shouldInvoke(s){var i=s-C;return void 0===C||i>=o||i<0||B&&s-j>=_}function timerExpired(){var s=now();if(shouldInvoke(s))return trailingEdge(s);x=setTimeout(timerExpired,function remainingWait(s){var i=o-(s-C);return B?U(i,_-(s-j)):i}(s))}function trailingEdge(s){return x=void 0,V&&a?invokeFunc(s):(a=u=void 0,w)}function debounced(){var s=now(),i=shouldInvoke(s);if(a=arguments,u=this,C=s,i){if(void 0===x)return function leadingEdge(s){return j=s,x=setTimeout(timerExpired,o),L?invokeFunc(s):w}(C);if(B)return x=setTimeout(timerExpired,o),invokeFunc(C)}return void 0===x&&(x=setTimeout(timerExpired,o)),w}return o=toNumber(o)||0,isObject(i)&&(L=!!i.leading,_=(B="maxWait"in i)?$(toNumber(i.maxWait)||0,o):_,V="trailing"in 
i?!!i.trailing:V),debounced.cancel=function cancel(){void 0!==x&&clearTimeout(x),j=0,a=C=u=x=void 0},debounced.flush=function flush(){return void 0===x?w:trailingEdge(now())},debounced}},20317:s=>{s.exports=function mapToArray(s){var o=-1,i=Array(s.size);return s.forEach((function(s,a){i[++o]=[a,s]})),i}},20334:(s,o,i)=>{"use strict";var a=i(48287).Buffer;class NonError extends Error{constructor(s){super(NonError._prepareSuperMessage(s)),Object.defineProperty(this,"name",{value:"NonError",configurable:!0,writable:!0}),Error.captureStackTrace&&Error.captureStackTrace(this,NonError)}static _prepareSuperMessage(s){try{return JSON.stringify(s)}catch{return String(s)}}}const u=[{property:"name",enumerable:!1},{property:"message",enumerable:!1},{property:"stack",enumerable:!1},{property:"code",enumerable:!0}],_=Symbol(".toJSON called"),destroyCircular=({from:s,seen:o,to_:i,forceEnumerable:w,maxDepth:x,depth:C})=>{const j=i||(Array.isArray(s)?[]:{});if(o.push(s),C>=x)return j;if("function"==typeof s.toJSON&&!0!==s[_])return(s=>{s[_]=!0;const o=s.toJSON();return delete s[_],o})(s);for(const[i,u]of Object.entries(s))"function"==typeof a&&a.isBuffer(u)?j[i]="[object Buffer]":"function"!=typeof u&&(u&&"object"==typeof u?o.includes(s[i])?j[i]="[Circular]":(C++,j[i]=destroyCircular({from:s[i],seen:o.slice(),forceEnumerable:w,maxDepth:x,depth:C})):j[i]=u);for(const{property:o,enumerable:i}of u)"string"==typeof s[o]&&Object.defineProperty(j,o,{value:s[o],enumerable:!!w||i,configurable:!0,writable:!0});return j};s.exports={serializeError:(s,o={})=>{const{maxDepth:i=Number.POSITIVE_INFINITY}=o;return"object"==typeof s&&null!==s?destroyCircular({from:s,seen:[],forceEnumerable:!0,maxDepth:i,depth:0}):"function"==typeof s?`[Function: ${s.name||"anonymous"}]`:s},deserializeError:(s,o={})=>{const{maxDepth:i=Number.POSITIVE_INFINITY}=o;if(s instanceof Error)return s;if("object"==typeof s&&null!==s&&!Array.isArray(s)){const o=new Error;return 
destroyCircular({from:s,seen:[],to_:o,maxDepth:i,depth:0}),o}return new NonError(s)}}},20426:s=>{var o=Object.prototype.hasOwnProperty;s.exports=function baseHas(s,i){return null!=s&&o.call(s,i)}},20575:(s,o,i)=>{"use strict";var a=i(3121);s.exports=function(s){return a(s.length)}},20798:(s,o,i)=>{"use strict";var a,u,_=i(45951),w=i(96794),x=_.process,C=_.Deno,j=x&&x.versions||C&&C.version,L=j&&j.v8;L&&(u=(a=L.split("."))[0]>0&&a[0]<4?1:+(a[0]+a[1])),!u&&w&&(!(a=w.match(/Edge\/(\d+)/))||a[1]>=74)&&(a=w.match(/Chrome\/(\d+)/))&&(u=+a[1]),s.exports=u},20850:(s,o,i)=>{"use strict";s.exports=i(46076)},20999:(s,o,i)=>{var a=i(69302),u=i(36800);s.exports=function createAssigner(s){return a((function(o,i){var a=-1,_=i.length,w=_>1?i[_-1]:void 0,x=_>2?i[2]:void 0;for(w=s.length>3&&"function"==typeof w?(_--,w):void 0,x&&u(i[0],i[1],x)&&(w=_<3?void 0:w,_=1),o=Object(o);++a<_;){var C=i[a];C&&s(o,C,a,w)}return o}))}},21549:(s,o,i)=>{var a=i(22032),u=i(63862),_=i(66721),w=i(12749),x=i(35749);function Hash(s){var o=-1,i=null==s?0:s.length;for(this.clear();++o{var a=i(16547),u=i(43360);s.exports=function copyObject(s,o,i,_){var w=!i;i||(i={});for(var x=-1,C=o.length;++x{var a=i(51873),u=i(37828),_=i(75288),w=i(25911),x=i(20317),C=i(84247),j=a?a.prototype:void 0,L=j?j.valueOf:void 0;s.exports=function equalByTag(s,o,i,a,j,B,$){switch(i){case"[object DataView]":if(s.byteLength!=o.byteLength||s.byteOffset!=o.byteOffset)return!1;s=s.buffer,o=o.buffer;case"[object ArrayBuffer]":return!(s.byteLength!=o.byteLength||!B(new u(s),new u(o)));case"[object Boolean]":case"[object Date]":case"[object Number]":return _(+s,+o);case"[object Error]":return s.name==o.name&&s.message==o.message;case"[object RegExp]":case"[object String]":return s==o+"";case"[object Map]":var U=x;case"[object Set]":var V=1&a;if(U||(U=C),s.size!=o.size&&!V)return!1;var z=$.get(s);if(z)return z==o;a|=2,$.set(s,o);var Y=w(U(s),U(o),a,j,B,$);return $.delete(s),Y;case"[object Symbol]":if(L)return 
L.call(s)==L.call(o)}return!1}},22032:(s,o,i)=>{var a=i(81042);s.exports=function hashClear(){this.__data__=a?a(null):{},this.size=0}},22225:s=>{var o="\\ud800-\\udfff",i="\\u2700-\\u27bf",a="a-z\\xdf-\\xf6\\xf8-\\xff",u="A-Z\\xc0-\\xd6\\xd8-\\xde",_="\\xac\\xb1\\xd7\\xf7\\x00-\\x2f\\x3a-\\x40\\x5b-\\x60\\x7b-\\xbf\\u2000-\\u206f \\t\\x0b\\f\\xa0\\ufeff\\n\\r\\u2028\\u2029\\u1680\\u180e\\u2000\\u2001\\u2002\\u2003\\u2004\\u2005\\u2006\\u2007\\u2008\\u2009\\u200a\\u202f\\u205f\\u3000",w="["+_+"]",x="\\d+",C="["+i+"]",j="["+a+"]",L="[^"+o+_+x+i+a+u+"]",B="(?:\\ud83c[\\udde6-\\uddff]){2}",$="[\\ud800-\\udbff][\\udc00-\\udfff]",U="["+u+"]",V="(?:"+j+"|"+L+")",z="(?:"+U+"|"+L+")",Y="(?:['’](?:d|ll|m|re|s|t|ve))?",Z="(?:['’](?:D|LL|M|RE|S|T|VE))?",ee="(?:[\\u0300-\\u036f\\ufe20-\\ufe2f\\u20d0-\\u20ff]|\\ud83c[\\udffb-\\udfff])?",ie="[\\ufe0e\\ufe0f]?",ae=ie+ee+("(?:\\u200d(?:"+["[^"+o+"]",B,$].join("|")+")"+ie+ee+")*"),ce="(?:"+[C,B,$].join("|")+")"+ae,le=RegExp([U+"?"+j+"+"+Y+"(?="+[w,U,"$"].join("|")+")",z+"+"+Z+"(?="+[w,U+V,"$"].join("|")+")",U+"?"+V+"+"+Y,U+"+"+Z,"\\d*(?:1ST|2ND|3RD|(?![123])\\dTH)(?=\\b|[a-z_])","\\d*(?:1st|2nd|3rd|(?![123])\\dth)(?=\\b|[A-Z_])",x,ce].join("|"),"g");s.exports=function unicodeWords(s){return s.match(le)||[]}},22551:(s,o,i)=>{"use strict";var a=i(96540),u=i(69982);function p(s){for(var o="https://reactjs.org/docs/error-decoder.html?invariant="+s,i=1;i